src/jabs_postprocess/compare_gt.py (200 changes: 199 additions & 1 deletion)
@@ -41,9 +41,12 @@ def evaluate_ground_truth(
filter_scan: List of filter (minimum duration in frames to consider real) values to test
iou_thresholds: List of intersection over union thresholds to scan
filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
trim_time: Limit the duration in frames of videos for performance
Returns:
None, but saves the following files to results_folder:
framewise_output: Output file to save the frame-level performance plot
scan_output: Output file to save the filter scan performance plot
bout_output: Output file to save the resulting bout performance plot
ethogram_output: Output file to save the ethogram plot comparing GT and predictions
scan_csv_output: Output file to save the scan performance data as CSV
"""
@@ -122,6 +125,14 @@ def evaluate_ground_truth(
pred_df["is_gt"] = False
all_annotations = pd.concat([gt_df, pred_df])

# Generate frame-level performance plot
framewise_plot = generate_framewise_performance_plot(gt_df, pred_df)
if output_paths["framewise_plot"] is not None:
framewise_plot.save(output_paths["framewise_plot"], height=6, width=12, dpi=300)
logging.info(
f"Frame-level performance plot saved to {output_paths['framewise_plot']}"
)

# We only want the positive examples for performance evaluation
# (but for ethogram plotting later, we'll use the full all_annotations)
performance_annotations = all_annotations[
@@ -480,4 +491,191 @@ def generate_output_paths(results_folder: Path):
"ethogram": results_folder / "ethogram.png",
"scan_plot": results_folder / "scan_performance.png",
"bout_plot": results_folder / "bout_performance.png",
"framewise_plot": results_folder / "framewise_performance.png",
}


def _compute_framewise_confusion(gt_df, pred_df):
"""Compute frame-level confusion counts (TP, TN, FP, FN) per video.

Args:
gt_df (pd.DataFrame): Ground truth intervals with columns
['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
pred_df (pd.DataFrame): Prediction intervals with the same structure.

Returns:
pd.DataFrame: Confusion matrix counts per video with columns
['video_name', 'TP', 'TN', 'FP', 'FN'].
"""

def expand_intervals_to_frames(df):
"""Expand behavior intervals into per-frame rows."""
expanded = df.copy()
expanded["frame"] = expanded.apply(
lambda row: range(row["start"], row["start"] + row["duration"]), axis=1
)
expanded = expanded.explode("frame")
expanded = expanded.sort_values(by=["animal_idx", "frame"])
return expanded

# Expand ground truth and predictions into frame-level data
gt_frames = expand_intervals_to_frames(gt_df)
pred_frames = expand_intervals_to_frames(pred_df)

# Merge to align predictions and ground truth per frame
framewise = pd.merge(
gt_frames,
pred_frames,
on=["video_name", "animal_idx", "frame"],
how="left",
suffixes=("_gt", "_pred"),
)

# Compute confusion counts per video
confusion_counts = (
framewise.groupby("video_name")
.apply(
lambda x: pd.Series(
{
"TP": (
(x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 1)
).sum(),
"TN": (
(x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 0)
).sum(),
"FP": (
(x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 1)
).sum(),
"FN": (
(x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 0)
).sum(),
}
),
include_groups=False,
)
.reset_index()
)

return confusion_counts
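
# Illustrative sketch (hypothetical data, for explanation only): one video where
# the ground truth marks frames 0-3 as behavior and frames 4-7 as not-behavior,
# while the prediction only marks frames 0-1 as behavior.
#
#   gt = pd.DataFrame({
#       "video_name": ["v1", "v1"], "animal_idx": [0, 0],
#       "start": [0, 4], "duration": [4, 4], "is_behavior": [1, 0],
#   })
#   pred = pd.DataFrame({
#       "video_name": ["v1", "v1"], "animal_idx": [0, 0],
#       "start": [0, 2], "duration": [2, 6], "is_behavior": [1, 0],
#   })
#   _compute_framewise_confusion(gt, pred)
#   #    video_name  TP  TN  FP  FN
#   # 0          v1   2   4   0   2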


def _find_outliers(melted_df: pd.DataFrame):
# Identify outliers per metric using IQR rule
outliers = []
for metric in melted_df["metric"].unique():
values = melted_df.loc[melted_df["metric"] == metric, "value"]
q1 = values.quantile(0.25)
q3 = values.quantile(0.75)
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
outliers_df = melted_df[
(melted_df["metric"] == metric)
& ((melted_df["value"] < lower_bound) | (melted_df["value"] > upper_bound))
]
outliers.append(outliers_df)

outliers = (
pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns)
)

return outliers
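
# Worked example of the 1.5*IQR rule used above (hypothetical values): for a single
# metric with per-video values [0.40, 0.90, 0.92, 0.94, 0.95], pandas gives
# Q1 = 0.90 and Q3 = 0.94, so IQR = 0.04 and the bounds are 0.84 and 1.00; only the
# 0.40 video falls outside and would be labeled on the plot.
#
#   melted = pd.DataFrame({
#       "video_name": ["v1", "v2", "v3", "v4", "v5"],
#       "metric": ["accuracy"] * 5,
#       "value": [0.40, 0.90, 0.92, 0.94, 0.95],
#   })
#   _find_outliers(melted)["video_name"].tolist()  # -> ["v1"]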


def generate_framewise_performance_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame):
"""
Generate a frame-level performance plot comparing ground truth and predicted behavior intervals.

This function:
1. Expands each interval in `gt_df` and `pred_df` to per-frame annotations.
2. Computes per-video confusion counts (TP, TN, FP, FN).
3. Calculates precision, recall, F1 score, and accuracy for each video.
4. Produces a boxplot with jitter showing the distribution of these metrics.
5. Adds an overall summary in the plot subtitle.

Args:
gt_df (pd.DataFrame): Ground truth intervals with columns
['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
pred_df (pd.DataFrame): Prediction intervals with the same structure.

Returns:
plotnine.ggplot: A ggplot object containing the frame-level performance visualization.
"""
# Compute framewise confusion counts
confusion_counts = _compute_framewise_confusion(gt_df, pred_df)
confusion_counts["frame_total"] = (
confusion_counts["TP"]
+ confusion_counts["TN"]
+ confusion_counts["FP"]
+ confusion_counts["FN"]
)

# Compute per-video metrics
confusion_counts["precision"] = confusion_counts["TP"] / (
confusion_counts["TP"] + confusion_counts["FP"]
)
confusion_counts["recall"] = confusion_counts["TP"] / (
confusion_counts["TP"] + confusion_counts["FN"]
)
confusion_counts["f1_score"] = (
2
* (confusion_counts["precision"] * confusion_counts["recall"])
/ (confusion_counts["precision"] + confusion_counts["recall"])
)
confusion_counts["accuracy"] = (
confusion_counts["TP"] + confusion_counts["TN"]
) / confusion_counts["frame_total"]

# Compute overall (global) metrics
totals = confusion_counts[["TP", "TN", "FP", "FN"]].sum()
overall_metrics = {
"precision": totals["TP"] / (totals["TP"] + totals["FP"]),
"recall": totals["TP"] / (totals["TP"] + totals["FN"]),
"accuracy": (totals["TP"] + totals["TN"])
/ (totals["TP"] + totals["TN"] + totals["FP"] + totals["FN"]),
}
overall_metrics["f1_score"] = (
2
* (overall_metrics["precision"] * overall_metrics["recall"])
/ (overall_metrics["precision"] + overall_metrics["recall"])
)

# Melt into long format for plotting
melted_df = pd.melt(
confusion_counts,
id_vars=["video_name", "frame_total"],
value_vars=["precision", "recall", "f1_score", "accuracy"],
var_name="metric",
value_name="value",
)

outliers = _find_outliers(melted_df)
# Generate plot
subtitle_text = (
f"Precision: {overall_metrics['precision']:.2f}, "
f"Recall: {overall_metrics['recall']:.2f}, "
f"F1: {overall_metrics['f1_score']:.2f}, "
f"Accuracy: {overall_metrics['accuracy']:.2f}"
)

plot = (
p9.ggplot(melted_df, p9.aes(x="metric", y="value"))
+ p9.geom_boxplot(outlier_shape=None, fill="lightblue", alpha=0.7)
+ p9.geom_jitter(p9.aes(color="frame_total"), width=0.05, height=0)
+ p9.geom_text(
p9.aes(label="video_name"), data=outliers, ha="left", nudge_x=0.1
)
+ p9.labs(
title="Frame-level Performance Metrics",
y="Score",
x="Metric",
subtitle=subtitle_text,
)
+ p9.theme_bw()
+ p9.theme(
plot_title=p9.element_text(ha="center"), # Center the main title
plot_subtitle=p9.element_text(ha="center"), # Center the subtitle too
)
)

return plot
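
For orientation, a minimal standalone use of the new plotting helper might look like the sketch below (hypothetical data; it reuses the gt/pred intervals from the confusion-count comment above, where the single video yields TP=2, TN=4, FP=0, FN=2, i.e. precision 1.00, recall 0.50, F1 of about 0.67, accuracy 0.75):

    plot = generate_framewise_performance_plot(gt, pred)  # gt/pred as in the earlier sketch
    print(plot.data)  # long-format metrics: one row per (video, metric) pair
    plot.save("framewise_performance.png", height=6, width=12, dpi=300)  # same call evaluate_ground_truth makes
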
tests/test_compare_gt.py (95 changes: 94 additions & 1 deletion)
@@ -32,9 +32,14 @@

import numpy as np
import pandas as pd
import plotnine as p9
import pytest

from jabs_postprocess.compare_gt import evaluate_ground_truth, generate_iou_scan
from jabs_postprocess.compare_gt import (
evaluate_ground_truth,
generate_iou_scan,
generate_framewise_performance_plot,
)
from jabs_postprocess.utils.project_utils import (
Bouts,
)
@@ -672,3 +677,91 @@ def test_generate_iou_scan_metrics_calculation(mock_metrics, expected_result):
assert np.isnan(row[metric])
else:
assert round(row[metric], 3) == round(expected, 3)


@pytest.fixture
def sample_data():
"""Create small sample GT and prediction DataFrames for testing."""
gt_df = pd.DataFrame(
{
"video_name": ["video1", "video1", "video2"],
"animal_idx": [0, 1, 0],
"start": [0, 5, 0],
"duration": [5, 5, 10],
"is_behavior": [1, 0, 1],
}
)

pred_df = pd.DataFrame(
{
"video_name": ["video1", "video1", "video2"],
"animal_idx": [0, 1, 0],
"start": [0, 5, 0],
"duration": [5, 5, 10],
"is_behavior": [1, 0, 0],
}
)

return gt_df, pred_df


def test_generate_plot_runs(sample_data):
"""Test that the plot function runs and returns a ggplot object."""
gt_df, pred_df = sample_data
plot = generate_framewise_performance_plot(gt_df, pred_df)
# Check that the returned object is a ggplot
assert isinstance(plot, p9.ggplot)


def test_plot_metrics(sample_data):
"""Test that generate_framewise_performance_plot correctly handles NaNs."""
gt_df, pred_df = sample_data

plot = generate_framewise_performance_plot(gt_df, pred_df)
df = plot.data.sort_values(["video_name", "metric"]).reset_index(drop=True)

# Manually compute expected metrics
expected = []
# Video 1: Perfect prediction
expected.append(
{
"video_name": "video1",
"precision": 1.0,
"recall": 1.0,
"f1_score": 1.0,
"accuracy": 1.0,
}
)
# Video 2: All wrong
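    # (GT marks every frame of video2 as behavior while the prediction marks none,
    # so TP = FP = 0: precision and F1 are 0/0 = NaN, recall and accuracy are 0.0)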
expected.append(
{
"video_name": "video2",
"precision": float("nan"),
"recall": 0.0,
"f1_score": float("nan"),
"accuracy": 0.0,
}
)

expected_df = pd.DataFrame(expected)
expected_melted = (
pd.melt(
expected_df,
id_vars=["video_name"],
value_vars=["precision", "recall", "f1_score", "accuracy"],
var_name="metric",
value_name="value",
)
.sort_values(["video_name", "metric"])
.reset_index(drop=True)
)

# Compare numeric values, treating NaNs as equal
for a, b in zip(df["value"], expected_melted["value"]):
if pd.isna(a) and pd.isna(b):
continue
else:
assert abs(a - b) < 1e-6

