@@ -41,13 +41,16 @@ def evaluate_ground_truth(
         filter_scan: List of filter (minimum duration in frames to consider real) values to test
         iou_thresholds: List of intersection over union thresholds to scan
         filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
+        trim_time: Limit the duration in frames of videos for performance
+    Returns:
+        None, but saves the following files to results_folder:
+        framewise_output: Output file to save the frame-level performance plot
         scan_output: Output file to save the filter scan performance plot
         bout_output: Output file to save the resulting bout performance plot
-        trim_time: Limit the duration in frames of videos for performance
         ethogram_output: Output file to save the ethogram plot comparing GT and predictions
         scan_csv_output: Output file to save the scan performance data as CSV
     """
-    ouput_paths = generate_output_paths(results_folder)
+    output_paths = generate_output_paths(results_folder)
 
     # Set default values if not provided
     stitch_scan = stitch_scan or np.arange(5, 46, 5).tolist()
@@ -122,6 +125,14 @@ def evaluate_ground_truth(
     pred_df["is_gt"] = False
     all_annotations = pd.concat([gt_df, pred_df])
 
+    # Generate frame-level performance plot
+    framewise_plot = generate_framewise_performance_plot(gt_df, pred_df)
+    if output_paths["framewise_plot"] is not None:
+        framewise_plot.save(output_paths["framewise_plot"], height=6, width=12, dpi=300)
+        logging.info(
+            f"Frame-level performance plot saved to {output_paths['framewise_plot']}"
+        )
+
     # We only want the positive examples for performance evaluation
     # (but for ethogram plotting later, we'll use the full all_annotations)
     performance_annotations = all_annotations[
@@ -150,9 +161,9 @@ def evaluate_ground_truth(
             "No performance data to analyze. Ensure that the ground truth and predictions are correctly formatted and contain valid bouts."
         )
 
-    if ouput_paths["scan_csv"] is not None:
-        performance_df.to_csv(ouput_paths["scan_csv"], index=False)
-        logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}")
+    if output_paths["scan_csv"] is not None:
+        performance_df.to_csv(output_paths["scan_csv"], index=False)
+        logging.info(f"Scan performance data saved to {output_paths['scan_csv']}")
 
     _melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"])
 
@@ -172,8 +183,8 @@ def evaluate_ground_truth(
             + p9.theme_bw()
             + p9.labs(title=f"No performance data for {middle_threshold} IoU")
         )
-        if ouput_paths["scan_plot"]:
-            plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
+        if output_paths["scan_plot"]:
+            plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)
         # Create default winning filters with first values from scan parameters
         winning_filters = pd.DataFrame(
             {
@@ -231,8 +242,8 @@ def evaluate_ground_truth(
         + p9.scale_fill_continuous(na_value=0)
     )
 
-    if ouput_paths["scan_plot"]:
-        plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300)
+    if output_paths["scan_plot"]:
+        plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300)
 
     # Handle case where all f1_plot values are NaN or empty
     if subset_df["f1_plot"].isna().all() or len(subset_df) == 0:
@@ -254,9 +265,9 @@ def evaluate_ground_truth(
     ).T.reset_index(drop=True)[["stitch", "filter"]]
 
     winning_bout_df = pd.merge(performance_df, winning_filters, on=["stitch", "filter"])
-    if ouput_paths["bout_csv"] is not None:
-        winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False)
-        logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}")
+    if output_paths["bout_csv"] is not None:
+        winning_bout_df.to_csv(output_paths["bout_csv"], index=False)
+        logging.info(f"Bout performance data saved to {output_paths['bout_csv']}")
 
     melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"])
 
@@ -268,9 +279,9 @@ def evaluate_ground_truth(
         + p9.geom_line()
         + p9.theme_bw()
         + p9.scale_y_continuous(limits=(0, 1))
-    ).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300)
+    ).save(output_paths["bout_plot"], height=6, width=12, dpi=300)
 
-    if ouput_paths["ethogram"] is not None:
+    if output_paths["ethogram"] is not None:
         # Prepare data for ethogram plot
         # Use all_annotations to include both behavior (1) and not-behavior (0) states
         plot_df = all_annotations.copy()
@@ -327,14 +338,14 @@ def evaluate_ground_truth(
             )
             # Adjust height based on the number of unique animal-video combinations
             ethogram_plot.save(
-                ouput_paths["ethogram"],
+                output_paths["ethogram"],
                 height=1.5 * num_unique_combos + 2,
                 width=12,
                 dpi=300,
                 limitsize=False,
                 verbose=False,
             )
-            logging.info(f"Ethogram plot saved to {ouput_paths['ethogram']}")
+            logging.info(f"Ethogram plot saved to {output_paths['ethogram']}")
         else:
             logger.warning(
                 f"No behavior instances found for behavior {behavior} after filtering for ethogram."
@@ -480,4 +491,201 @@ def generate_output_paths(results_folder: Path):
         "ethogram": results_folder / "ethogram.png",
         "scan_plot": results_folder / "scan_performance.png",
         "bout_plot": results_folder / "bout_performance.png",
+        "framewise_plot": results_folder / "framewise_performance.png",
     }
+
+
+def _expand_intervals_to_frames(df):
+    """Expand behavior intervals into per-frame rows."""
+    expanded = df.copy()
+    expanded["frame"] = expanded.apply(
+        lambda row: range(row["start"], row["start"] + row["duration"]), axis=1
+    )
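+    # explode() below turns each per-row range into one row per frame index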
+    expanded = expanded.explode("frame")
+    expanded = expanded.sort_values(by=["animal_idx", "frame"])
+    return expanded
+
+
+def _compute_framewise_confusion(gt_df, pred_df):
+    """Compute frame-level confusion counts (TP, TN, FP, FN) per video.
+
+    Args:
+        gt_df (pd.DataFrame): Ground truth intervals with columns
+            ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
+        pred_df (pd.DataFrame): Prediction intervals with the same structure.
+
+    Returns:
+        pd.DataFrame: Confusion matrix counts per video with columns
+            ['video_name', 'TP', 'TN', 'FP', 'FN'].
+    """
+
+    # Expand ground truth and predictions into frame-level data
+    gt_frames = _expand_intervals_to_frames(gt_df)
+    pred_frames = _expand_intervals_to_frames(pred_df)
+
+    # Merge to align predictions and ground truth per frame
+    framewise = pd.merge(
+        gt_frames,
+        pred_frames,
+        on=["video_name", "animal_idx", "frame"],
+        how="left",
+        suffixes=("_gt", "_pred"),
+    )
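+    # Note: with this left join, ground-truth frames that have no matching
+    # prediction row get NaN prediction values and therefore satisfy none of
+    # the four conditions counted below.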
+
+    # Compute confusion counts per video
+    confusion_counts = (
+        framewise.groupby("video_name")
+        .apply(
+            lambda x: pd.Series(
+                {
+                    "TP": (
+                        (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 1)
+                    ).sum(),
+                    "TN": (
+                        (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 0)
+                    ).sum(),
+                    "FP": (
+                        (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 1)
+                    ).sum(),
+                    "FN": (
+                        (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 0)
+                    ).sum(),
+                }
+            ),
+            include_groups=False,
+        )
+        .reset_index()
+    )
+
+    return confusion_counts
+
+
+def _find_outliers(melted_df: pd.DataFrame):
+    """
+    Return rows flagged as outliers per metric using the IQR rule.
+
+    Args:
+        melted_df: long-form DataFrame with at least 'metric' and 'value' columns.
+
+    Returns:
+        DataFrame containing the outlier rows from the input DataFrame.
+        Returns an empty DataFrame with the same columns if no outliers found.
+    """
+    outliers = []
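+    # Tukey's rule: values more than 1.5 * IQR beyond the quartiles are flagged.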
+    for metric in melted_df["metric"].unique():
+        values = melted_df.loc[melted_df["metric"] == metric, "value"]
+        q1 = values.quantile(0.25)
+        q3 = values.quantile(0.75)
+        iqr = q3 - q1
+        lower_bound = q1 - 1.5 * iqr
+        upper_bound = q3 + 1.5 * iqr
+        outliers_df = melted_df[
+            (melted_df["metric"] == metric)
+            & ((melted_df["value"] < lower_bound) | (melted_df["value"] > upper_bound))
+        ]
+        outliers.append(outliers_df)
+
+    outliers = (
+        pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns)
+    )
+
+    return outliers
+
+
+def generate_framewise_performance_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame):
+    """
+    Generate a frame-level performance plot comparing ground truth and predicted behavior intervals.
+
+    This function:
+    1. Expands each interval in `gt_df` and `pred_df` to per-frame annotations.
+    2. Computes per-video confusion counts (TP, TN, FP, FN).
+    3. Calculates precision, recall, F1 score, and accuracy for each video.
+    4. Produces a boxplot with jitter showing the distribution of these metrics.
+    5. Adds an overall summary in the plot subtitle.
+
+    Args:
+        gt_df (pd.DataFrame): Ground truth intervals with columns
+            ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior'].
+        pred_df (pd.DataFrame): Prediction intervals with the same structure.
+
+    Returns:
+        plotnine.ggplot: A ggplot object containing the frame-level performance visualization.
+    """
+    # Compute framewise confusion counts
+    confusion_counts = _compute_framewise_confusion(gt_df, pred_df)
+    confusion_counts["frame_total"] = (
+        confusion_counts["TP"]
+        + confusion_counts["TN"]
+        + confusion_counts["FP"]
+        + confusion_counts["FN"]
+    )
+
+    # Compute per-video metrics
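+    # Videos with no predicted or no ground-truth positive frames produce 0/0
+    # divisions here, which pandas evaluates to NaN rather than raising.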
+    confusion_counts["precision"] = confusion_counts["TP"] / (
+        confusion_counts["TP"] + confusion_counts["FP"]
+    )
+    confusion_counts["recall"] = confusion_counts["TP"] / (
+        confusion_counts["TP"] + confusion_counts["FN"]
+    )
+    confusion_counts["f1_score"] = (
+        2
+        * (confusion_counts["precision"] * confusion_counts["recall"])
+        / (confusion_counts["precision"] + confusion_counts["recall"])
+    )
+    confusion_counts["accuracy"] = (
+        confusion_counts["TP"] + confusion_counts["TN"]
+    ) / confusion_counts["frame_total"]
+
+    # Compute overall (global) metrics
+    totals = confusion_counts[["TP", "TN", "FP", "FN"]].sum()
+    overall_metrics = {
+        "precision": totals["TP"] / (totals["TP"] + totals["FP"]),
+        "recall": totals["TP"] / (totals["TP"] + totals["FN"]),
+        "accuracy": (totals["TP"] + totals["TN"])
+        / (totals["TP"] + totals["TN"] + totals["FP"] + totals["FN"]),
+    }
+    overall_metrics["f1_score"] = (
+        2
+        * (overall_metrics["precision"] * overall_metrics["recall"])
+        / (overall_metrics["precision"] + overall_metrics["recall"])
+    )
+
+    # Melt into long format for plotting
+    melted_df = pd.melt(
+        confusion_counts,
+        id_vars=["video_name", "frame_total"],
+        value_vars=["precision", "recall", "f1_score", "accuracy"],
+        var_name="metric",
+        value_name="value",
+    )
+
+    outliers = _find_outliers(melted_df)
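+    # Only videos flagged as outliers by _find_outliers() receive text labels on the plot.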
+    # Generate plot
+    subtitle_text = (
+        f"Precision: {overall_metrics['precision']:.2f}, "
+        f"Recall: {overall_metrics['recall']:.2f}, "
+        f"F1: {overall_metrics['f1_score']:.2f}, "
+        f"Accuracy: {overall_metrics['accuracy']:.2f}"
+    )
+
+    plot = (
+        p9.ggplot(melted_df, p9.aes(x="metric", y="value"))
+        + p9.geom_boxplot(outlier_shape=None, fill="lightblue", alpha=0.7)
+        + p9.geom_jitter(p9.aes(color="frame_total"), width=0.05, height=0)
+        + p9.geom_text(
+            p9.aes(label="video_name"), data=outliers, ha="left", nudge_x=0.1
+        )
+        + p9.labs(
+            title="Frame-level Performance Metrics",
+            y="Score",
+            x="Metric",
+            subtitle=subtitle_text,
+        )
+        + p9.theme_bw()
+        + p9.theme(
+            plot_title=p9.element_text(ha="center"),  # Center the main title
+            plot_subtitle=p9.element_text(ha="center"),  # Center the subtitle too
+        )
+    )
+
+    return plot
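
A minimal usage sketch of the new plot helper (not part of the commit): toy_gt, toy_pred, and framewise_plot are hypothetical names, the toy intervals are illustrative only, and generate_framewise_performance_plot is assumed to be importable from this module with a pandas version that supports the include_groups keyword used above.

    import pandas as pd

    # generate_framewise_performance_plot is assumed to be imported from the module edited above.
    # Two intervals per source: a behavior bout followed by a not-behavior bout on one video.
    toy_gt = pd.DataFrame(
        {
            "video_name": ["vid_a", "vid_a"],
            "animal_idx": [0, 0],
            "start": [0, 50],
            "duration": [50, 50],
            "is_behavior": [1, 0],
        }
    )
    toy_pred = pd.DataFrame(
        {
            "video_name": ["vid_a", "vid_a"],
            "animal_idx": [0, 0],
            "start": [0, 40],
            "duration": [40, 60],
            "is_behavior": [1, 0],
        }
    )

    # Build the plot and save it, mirroring how evaluate_ground_truth() uses the helper.
    framewise_plot = generate_framewise_performance_plot(toy_gt, toy_pred)
    framewise_plot.save("framewise_performance.png", height=6, width=12, dpi=300)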