
Commit de4de72

Merge pull request #52 from KumarLabJax/frame_level

Frame level performance

2 parents: b40ad37 + b46d865

2 files changed: +315 -17 lines

src/jabs_postprocess/compare_gt.py

Lines changed: 224 additions & 16 deletions
@@ -41,13 +41,16 @@ def evaluate_ground_truth(
         filter_scan: List of filter (minimum duration in frames to consider real) values to test
         iou_thresholds: List of intersection over union thresholds to scan
         filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
+        trim_time: Limit the duration in frames of videos for performance
+    Returns:
+        None, but saves the following files to results_folder:
+        framewise_output: Output file to save the frame-level performance plot
         scan_output: Output file to save the filter scan performance plot
         bout_output: Output file to save the resulting bout performance plot
-        trim_time: Limit the duration in frames of videos for performance
         ethogram_output: Output file to save the ethogram plot comparing GT and predictions
         scan_csv_output: Output file to save the scan performance data as CSV
     """
-    ouput_paths = generate_output_paths(results_folder)
+    output_paths = generate_output_paths(results_folder)

     # Set default values if not provided
     stitch_scan = stitch_scan or np.arange(5, 46, 5).tolist()
@@ -122,6 +125,14 @@ def evaluate_ground_truth(
     pred_df["is_gt"] = False
     all_annotations = pd.concat([gt_df, pred_df])

+    # Generate frame-level performance plot
+    framewise_plot = generate_framewise_performance_plot(gt_df, pred_df)
+    if output_paths["framewise_plot"] is not None:
+        framewise_plot.save(output_paths["framewise_plot"], height=6, width=12, dpi=300)
+        logging.info(
+            f"Frame-level performance plot saved to {output_paths['framewise_plot']}"
+        )
+
     # We only want the positive examples for performance evaluation
     # (but for ethogram plotting later, we'll use the full all_annotations)
     performance_annotations = all_annotations[
@@ -150,9 +161,9 @@ def evaluate_ground_truth(
             "No performance data to analyze. Ensure that the ground truth and predictions are correctly formatted and contain valid bouts."
         )

-    if ouput_paths["scan_csv"] is not None:
-        performance_df.to_csv(ouput_paths["scan_csv"], index=False)
-        logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}")
+    if output_paths["scan_csv"] is not None:
+        performance_df.to_csv(output_paths["scan_csv"], index=False)
+        logging.info(f"Scan performance data saved to {output_paths['scan_csv']}")

     _melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"])

@@ -172,8 +183,8 @@ def evaluate_ground_truth(
             + p9.theme_bw()
             + p9.labs(title=f"No performance data for {middle_threshold} IoU")
         )
-        if ouput_paths["scan_plot"]:
-            plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
+        if output_paths["scan_plot"]:
+            plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)
         # Create default winning filters with first values from scan parameters
         winning_filters = pd.DataFrame(
             {
@@ -231,8 +242,8 @@ def evaluate_ground_truth(
         + p9.scale_fill_continuous(na_value=0)
     )

-    if ouput_paths["scan_plot"]:
-        plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
+    if output_paths["scan_plot"]:
+        plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)

     # Handle case where all f1_plot values are NaN or empty
     if subset_df["f1_plot"].isna().all() or len(subset_df) == 0:
@@ -254,9 +265,9 @@ def evaluate_ground_truth(
     ).T.reset_index(drop=True)[["stitch", "filter"]]

     winning_bout_df = pd.merge(performance_df, winning_filters, on=["stitch", "filter"])
-    if ouput_paths["bout_csv"] is not None:
-        winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False)
-        logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}")
+    if output_paths["bout_csv"] is not None:
+        winning_bout_df.to_csv(output_paths["bout_csv"], index=False)
+        logging.info(f"Bout performance data saved to {output_paths['bout_csv']}")

     melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"])

@@ -268,9 +279,9 @@ def evaluate_ground_truth(
         + p9.geom_line()
         + p9.theme_bw()
         + p9.scale_y_continuous(limits=(0, 1))
-    ).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300)
+    ).save(output_paths["bout_plot"], height=6, width=12, dpi=300)

-    if ouput_paths["ethogram"] is not None:
+    if output_paths["ethogram"] is not None:
         # Prepare data for ethogram plot
         # Use all_annotations to include both behavior (1) and not-behavior (0) states
         plot_df = all_annotations.copy()
@@ -327,14 +338,14 @@ def evaluate_ground_truth(
             )
             # Adjust height based on the number of unique animal-video combinations
             ethogram_plot.save(
-                ouput_paths["ethogram"],
+                output_paths["ethogram"],
                 height=1.5 * num_unique_combos + 2,
                 width=12,
                 dpi=300,
                 limitsize=False,
                 verbose=False,
             )
-            logging.info(f"Ethogram plot saved to {ouput_paths['ethogram']}")
+            logging.info(f"Ethogram plot saved to {output_paths['ethogram']}")
         else:
             logger.warning(
                 f"No behavior instances found for behavior {behavior} after filtering for ethogram."
@@ -480,4 +491,201 @@ def generate_output_paths(results_folder: Path):
         "ethogram": results_folder / "ethogram.png",
         "scan_plot": results_folder / "scan_performance.png",
         "bout_plot": results_folder / "bout_performance.png",
+        "framewise_plot": results_folder / "framewise_performance.png",
     }
+
+
+def _expand_intervals_to_frames(df):
+    """Expand behavior intervals into per-frame rows."""
+    expanded = df.copy()
+    expanded["frame"] = expanded.apply(
+        lambda row: range(row["start"], row["start"] + row["duration"]), axis=1
+    )
+    expanded = expanded.explode("frame")
+    expanded = expanded.sort_values(by=["animal_idx", "frame"])
+    return expanded
+
+
+def _compute_framewise_confusion(gt_df, pred_df):
+    """Compute frame-level confusion counts (TP, TN, FP, FN) per video.
+
+    Args:
+        gt_df (pd.DataFrame): Ground truth intervals with columns
+            ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
+        pred_df (pd.DataFrame): Prediction intervals with the same structure.
+
+    Returns:
+        pd.DataFrame: Confusion matrix counts per video with columns
+            ['video_name', 'TP', 'TN', 'FP', 'FN'].
+    """
+
+    # Expand ground truth and predictions into frame-level data
+    gt_frames = _expand_intervals_to_frames(gt_df)
+    pred_frames = _expand_intervals_to_frames(pred_df)
+
+    # Merge to align predictions and ground truth per frame
+    framewise = pd.merge(
+        gt_frames,
+        pred_frames,
+        on=["video_name", "animal_idx", "frame"],
+        how="left",
+        suffixes=("_gt", "_pred"),
+    )
+
+    # Compute confusion counts per video
+    confusion_counts = (
+        framewise.groupby("video_name")
+        .apply(
+            lambda x: pd.Series(
+                {
+                    "TP": (
+                        (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 1)
+                    ).sum(),
+                    "TN": (
+                        (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 0)
+                    ).sum(),
+                    "FP": (
+                        (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 1)
+                    ).sum(),
+                    "FN": (
+                        (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 0)
+                    ).sum(),
+                }
+            ),
+            include_groups=False,
+        )
+        .reset_index()
+    )
+
+    return confusion_counts
+
+
+def _find_outliers(melted_df: pd.DataFrame):
+    """
+    Return rows flagged as outliers per metric using the IQR rule.
+
+    Args:
+        melted_df: long-form DataFrame with at least 'metric' and 'value' columns.
+
+    Returns:
+        DataFrame containing the outliers rows from the input DataFrame.
+        Returns an empty DataFrame with the same columns if no outliers found.
+    """
+    outliers = []
+    for metric in melted_df["metric"].unique():
+        values = melted_df.loc[melted_df["metric"] == metric, "value"]
+        q1 = values.quantile(0.25)
+        q3 = values.quantile(0.75)
+        iqr = q3 - q1
+        lower_bound = q1 - 1.5 * iqr
+        upper_bound = q3 + 1.5 * iqr
+        outliers_df = melted_df[
+            (melted_df["metric"] == metric)
+            & ((melted_df["value"] < lower_bound) | (melted_df["value"] > upper_bound))
+        ]
+        outliers.append(outliers_df)
+
+    outliers = (
+        pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns)
+    )
+
+    return outliers
+
+
+def generate_framewise_performance_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame):
+    """
+    Generate and save a frame-level performance plot comparing ground truth and predicted behavior intervals.
+
+    This function:
+    1. Expands each interval in `gt_df` and `pred_df` to per-frame annotations.
+    2. Computes per-video confusion counts (TP, TN, FP, FN).
+    3. Calculates precision, recall, F1 score, and accuracy for each video.
+    4. Produces a boxplot with jitter showing the distribution of these metrics.
+    5. Adds an overall summary in the plot subtitle.
+
+    Args:
+        gt_df (pd.DataFrame): Ground truth intervals with columns
+            ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
+        pred_df (pd.DataFrame): Prediction intervals with the same structure.
+
+    Returns:
+        plotnine.ggplot: A ggplot object containing the frame-level performance visualization.
+    """
+    # Compute framewise confusion counts
+    confusion_counts = _compute_framewise_confusion(gt_df, pred_df)
+    confusion_counts["frame_total"] = (
+        confusion_counts["TP"]
+        + confusion_counts["TN"]
+        + confusion_counts["FP"]
+        + confusion_counts["FN"]
+    )
+
+    # Compute per-video metrics
+    confusion_counts["precision"] = confusion_counts["TP"] / (
+        confusion_counts["TP"] + confusion_counts["FP"]
+    )
+    confusion_counts["recall"] = confusion_counts["TP"] / (
+        confusion_counts["TP"] + confusion_counts["FN"]
+    )
+    confusion_counts["f1_score"] = (
+        2
+        * (confusion_counts["precision"] * confusion_counts["recall"])
+        / (confusion_counts["precision"] + confusion_counts["recall"])
+    )
+    confusion_counts["accuracy"] = (
+        confusion_counts["TP"] + confusion_counts["TN"]
+    ) / confusion_counts["frame_total"]
+
+    # Compute overall (global) metrics
+    totals = confusion_counts[["TP", "TN", "FP", "FN"]].sum()
+    overall_metrics = {
+        "precision": totals["TP"] / (totals["TP"] + totals["FP"]),
+        "recall": totals["TP"] / (totals["TP"] + totals["FN"]),
+        "accuracy": (totals["TP"] + totals["TN"])
+        / (totals["TP"] + totals["TN"] + totals["FP"] + totals["FN"]),
+    }
+    overall_metrics["f1_score"] = (
+        2
+        * (overall_metrics["precision"] * overall_metrics["recall"])
+        / (overall_metrics["precision"] + overall_metrics["recall"])
+    )
+
+    # Melt into long format for plotting
+    melted_df = pd.melt(
+        confusion_counts,
+        id_vars=["video_name", "frame_total"],
+        value_vars=["precision", "recall", "f1_score", "accuracy"],
+        var_name="metric",
+        value_name="value",
+    )
+
+    outliers = _find_outliers(melted_df)
+    # Generate plot
+    subtitle_text = (
+        f"Precision: {overall_metrics['precision']:.2f}, "
+        f"Recall: {overall_metrics['recall']:.2f}, "
+        f"F1: {overall_metrics['f1_score']:.2f}, "
+        f"Accuracy: {overall_metrics['accuracy']:.2f}"
+    )
+
+    plot = (
+        p9.ggplot(melted_df, p9.aes(x="metric", y="value"))
+        + p9.geom_boxplot(outlier_shape=None, fill="lightblue", alpha=0.7)
+        + p9.geom_jitter(p9.aes(color="frame_total"), width=0.05, height=0)
+        + p9.geom_text(
+            p9.aes(label="video_name"), data=outliers, ha="left", nudge_x=0.1
+        )
+        + p9.labs(
+            title="Frame-level Performance Metrics",
+            y="Score",
+            x="Metric",
+            subtitle=subtitle_text,
+        )
+        + p9.theme_bw()
+        + p9.theme(
+            plot_title=p9.element_text(ha="center"),  # Center the main title
+            plot_subtitle=p9.element_text(ha="center"),  # Center the subtitle too
+        )
+    )
+
+    return plot
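
For reference, a minimal usage sketch of the new frame-level helpers, assuming the module is importable as jabs_postprocess.compare_gt and using the interval schema given in the docstrings above ('video_name', 'animal_idx', 'start', 'duration', 'is_behavior'); the toy DataFrames and the output filename are invented for illustration and are not part of this commit:

import pandas as pd

# Hypothetical import path for the file shown in this diff.
from jabs_postprocess.compare_gt import (
    _compute_framewise_confusion,
    generate_framewise_performance_plot,
)

# Toy ground truth for one 20-frame video: behavior on frames 0-9, not-behavior on 10-19.
gt_df = pd.DataFrame(
    {
        "video_name": ["vid_1", "vid_1"],
        "animal_idx": [0, 0],
        "start": [0, 10],
        "duration": [10, 10],
        "is_behavior": [1, 0],
    }
)

# Toy predictions covering the same frames: behavior on 0-4 and 15-19, not-behavior on 5-14.
pred_df = pd.DataFrame(
    {
        "video_name": ["vid_1", "vid_1", "vid_1"],
        "animal_idx": [0, 0, 0],
        "start": [0, 5, 15],
        "duration": [5, 10, 5],
        "is_behavior": [1, 0, 1],
    }
)

# Frame-level confusion per video: with this toy input, TP=5 (frames 0-4), FN=5 (5-9),
# TN=5 (10-14), FP=5 (15-19), so precision, recall, F1, and accuracy all come out to 0.50.
confusion = _compute_framewise_confusion(gt_df, pred_df)
print(confusion)

# The plot helper derives per-video metrics from the same counts and returns a
# plotnine ggplot, which the caller saves just as evaluate_ground_truth does.
framewise_plot = generate_framewise_performance_plot(gt_df, pred_df)
framewise_plot.save("framewise_performance.png", height=6, width=12, dpi=300)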
