From d07621202460cac4748d60159a7043531a28e604 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 29 Oct 2025 11:14:16 -0400 Subject: [PATCH 1/8] generate frame-level performance plot --- src/jabs_postprocess/compare_gt.py | 168 ++++++++++++++++++++++++++++- 1 file changed, 167 insertions(+), 1 deletion(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index b2322e1..6777f14 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -8,6 +8,7 @@ import pandas as pd import plotnine as p9 + from jabs_postprocess.utils.project_utils import ( Bouts, BoutTable, @@ -18,7 +19,7 @@ logger = logging.getLogger(__name__) - +# %% def evaluate_ground_truth( behavior: str, ground_truth_folder: Path, @@ -122,6 +123,16 @@ def evaluate_ground_truth( pred_df["is_gt"] = False all_annotations = pd.concat([gt_df, pred_df]) + # Generate frame-level performance plot + framewise_plot = generate_framewise_performace_plot(gt_df, pred_df) + if ouput_paths["framewise_plot"] is not None: + framewise_plot.save( + ouput_paths["framewise_plot"], height=6, width=12, dpi=300 + ) + logging.info( + f"Frame-level performance plot saved to {ouput_paths['framewise_plot']}" + ) + # We only want the positive examples for performance evaluation # (but for ethogram plotting later, we'll use the full all_annotations) performance_annotations = all_annotations[ @@ -480,4 +491,159 @@ def generate_output_paths(results_folder: Path): "ethogram": results_folder / "ethogram.png", "scan_plot": results_folder / "scan_performance.png", "bout_plot": results_folder / "bout_performance.png", + "framewise_plot": results_folder / "framewise_performance.png", + } + +def _compute_framewise_confusion(gt_df, pred_df): + """Compute frame-level confusion counts (TP, TN, FP, FN) per video. + + Args: + gt_df (pd.DataFrame): Ground truth intervals with columns + ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior']. + pred_df (pd.DataFrame): Prediction intervals with the same structure. + + Returns: + pd.DataFrame: Confusion matrix counts per video with columns + ['video_name', 'TP', 'TN', 'FP', 'FN']. 
+ """ + + def expand_intervals_to_frames(df): + """Expand behavior intervals into per-frame rows.""" + expanded = df.copy() + expanded['frame'] = expanded.apply( + lambda row: range(row['start'], row['start'] + row['duration']), + axis = 1 + ) + expanded = expanded.explode('frame') + expanded = expanded.sort_values(by=['animal_idx', 'frame']) + return expanded + + # Expand ground truth and predictions into frame-level data + gt_frames = expand_intervals_to_frames(gt_df) + pred_frames = expand_intervals_to_frames(pred_df) + + # Merge to align predictions and ground truth per frame + framewise = pd.merge( + gt_frames, pred_frames, + on=['video_name', 'animal_idx', 'frame'], + how='left', + suffixes=('_gt', '_pred') + ) + + # Compute confusion counts per video + confusion_counts = ( + framewise.groupby('video_name') + .apply(lambda x: pd.Series({ + 'TP': ((x['is_behavior_gt'] == 1) & (x['is_behavior_pred'] == 1)).sum(), + 'TN': ((x['is_behavior_gt'] == 0) & (x['is_behavior_pred'] == 0)).sum(), + 'FP': ((x['is_behavior_gt'] == 0) & (x['is_behavior_pred'] == 1)).sum(), + 'FN': ((x['is_behavior_gt'] == 1) & (x['is_behavior_pred'] == 0)).sum(), + })) + .reset_index() + ) + + return confusion_counts + +def _find_outliers(melted_df: pd.DataFrame): + # Identify outliers per metric using IQR rule + outliers = [] + for metric in melted_df['metric'].unique(): + values = melted_df.loc[melted_df['metric'] == metric, 'value'] + q1 = values.quantile(0.25) + q3 = values.quantile(0.75) + iqr = q3 - q1 + lower_bound = q1 - 1.5 * iqr + upper_bound = q3 + 1.5 * iqr + outliers_df = melted_df[ + (melted_df['metric'] == metric) + & ((melted_df['value'] < lower_bound) | (melted_df['value'] > upper_bound)) + ] + outliers.append(outliers_df) + + outliers = pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns) + + return outliers + +def generate_framewise_performace_plot( + gt_df: pd.DataFrame, pred_df:pd.DataFrame): + """ + Generate and save a frame-level performance plot comparing ground truth and predicted behavior intervals. + + This function: + 1. Expands each interval in `gt_df` and `pred_df` to per-frame annotations. + 2. Computes per-video confusion counts (TP, TN, FP, FN). + 3. Calculates precision, recall, F1 score, and accuracy for each video. + 4. Produces a boxplot with jitter showing the distribution of these metrics. + 5. Adds an overall summary in the plot subtitle. + + Args: + gt_df (pd.DataFrame): Ground truth intervals with columns + ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior']. + pred_df (pd.DataFrame): Prediction intervals with the same structure. + + Returns: + plotnine.ggplot: A ggplot object containing the frame-level performance visualization. 
+ """ + # Compute framewise confusion counts + confusion_counts = _compute_framewise_confusion(gt_df, pred_df) + confusion_counts['frame_total'] = ( + confusion_counts['TP'] + confusion_counts['TN'] + + confusion_counts['FP'] + confusion_counts['FN'] + ) + + # Compute per-video metrics + confusion_counts['precision'] = confusion_counts['TP'] / (confusion_counts['TP'] + confusion_counts['FP']) + confusion_counts['recall'] = confusion_counts['TP'] / (confusion_counts['TP'] + confusion_counts['FN']) + confusion_counts['f1_score'] = 2 * (confusion_counts['precision'] * confusion_counts['recall']) / (confusion_counts['precision'] + confusion_counts['recall']) + confusion_counts['accuracy'] = (confusion_counts['TP'] + confusion_counts['TN']) / confusion_counts['frame_total'] + + # Compute overall (global) metrics + totals = confusion_counts[['TP', 'TN', 'FP', 'FN']].sum() + overall_metrics = { + 'precision': totals['TP'] / (totals['TP'] + totals['FP']), + 'recall': totals['TP'] / (totals['TP'] + totals['FN']), + 'accuracy': (totals['TP'] + totals['TN']) + / (totals['TP'] + totals['TN'] + totals['FP'] + totals['FN']) } + overall_metrics['f1_score'] = 2 * (overall_metrics['precision'] * overall_metrics['recall']) / (overall_metrics['precision'] + overall_metrics['recall']) + + # Melt into long format for plotting + melted_df = pd.melt( + confusion_counts, + id_vars=['video_name', 'frame_total'], + value_vars=['precision', 'recall', 'f1_score', 'accuracy'], + var_name='metric', + value_name='value') + + outliers = _find_outliers(melted_df) + # Generate plot + subtitle_text = ( + f"Precision: {overall_metrics['precision']:.2f}, " + f"Recall: {overall_metrics['recall']:.2f}, " + f"F1: {overall_metrics['f1_score']:.2f}, " + f"Accuracy: {overall_metrics['accuracy']:.2f}" + ) + + plot = ( + p9.ggplot(melted_df, p9.aes(x='metric', y='value')) + + p9.geom_boxplot(outlier_shape = None, fill='lightblue', alpha=0.7) + + p9.geom_jitter(p9.aes(color = 'frame_total'), width=0.05, height=0) + + p9.geom_text( + p9.aes(label='video_name'), + data=outliers, + ha='left', + nudge_x=0.1 + ) + + p9.labs( + title='Frame-level Performance Metrics', + y='Score', x='Metric', + subtitle=subtitle_text) + + p9.theme_bw() + + p9.theme( + plot_title=p9.element_text(ha='center'), # Center the main title + plot_subtitle=p9.element_text(ha='center') # Center the subtitle too + ) + ) + + return plot + From 328bf2cb431d08896b0ad7f44beb8e6d71e50aee Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 29 Oct 2025 13:14:13 -0400 Subject: [PATCH 2/8] fix documentation typo and add tests --- src/jabs_postprocess/compare_gt.py | 11 +++-- tests/test_compare_gt.py | 69 +++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 6777f14..45f780e 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -42,9 +42,12 @@ def evaluate_ground_truth( filter_scan: List of filter (minimum duration in frames to consider real) values to test iou_thresholds: List of intersection over union thresholds to scan filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions) + trim_time: Limit the duration in frames of videos for performance + Returns: + None, but saves the following files to results_folder: + framewise_output: Output file to save the frame-level performance plot scan_output: Output file to save the filter scan performance plot bout_output: Output file to save the 
resulting bout performance plot - trim_time: Limit the duration in frames of videos for performance ethogram_output: Output file to save the ethogram plot comparing GT and predictions scan_csv_output: Output file to save the scan performance data as CSV """ @@ -124,7 +127,7 @@ def evaluate_ground_truth( all_annotations = pd.concat([gt_df, pred_df]) # Generate frame-level performance plot - framewise_plot = generate_framewise_performace_plot(gt_df, pred_df) + framewise_plot = generate_framewise_performance_plot(gt_df, pred_df) if ouput_paths["framewise_plot"] is not None: framewise_plot.save( ouput_paths["framewise_plot"], height=6, width=12, dpi=300 @@ -538,7 +541,7 @@ def expand_intervals_to_frames(df): 'TN': ((x['is_behavior_gt'] == 0) & (x['is_behavior_pred'] == 0)).sum(), 'FP': ((x['is_behavior_gt'] == 0) & (x['is_behavior_pred'] == 1)).sum(), 'FN': ((x['is_behavior_gt'] == 1) & (x['is_behavior_pred'] == 0)).sum(), - })) + }), include_groups=False) .reset_index() ) @@ -564,7 +567,7 @@ def _find_outliers(melted_df: pd.DataFrame): return outliers -def generate_framewise_performace_plot( +def generate_framewise_performance_plot( gt_df: pd.DataFrame, pred_df:pd.DataFrame): """ Generate and save a frame-level performance plot comparing ground truth and predicted behavior intervals. diff --git a/tests/test_compare_gt.py b/tests/test_compare_gt.py index 56fdbfc..5becd39 100644 --- a/tests/test_compare_gt.py +++ b/tests/test_compare_gt.py @@ -27,14 +27,15 @@ test multiple scenarios efficiently. Fixtures are provided for common test data like mock bout tables, JABS projects, and annotation samples. """ - +# %% from unittest.mock import MagicMock, patch import numpy as np import pandas as pd +import plotnine as p9 import pytest -from jabs_postprocess.compare_gt import evaluate_ground_truth, generate_iou_scan +from jabs_postprocess.compare_gt import evaluate_ground_truth, generate_iou_scan, generate_framewise_performance_plot from jabs_postprocess.utils.project_utils import ( Bouts, ) @@ -672,3 +673,67 @@ def test_generate_iou_scan_metrics_calculation(mock_metrics, expected_result): assert np.isnan(row[metric]) else: assert round(row[metric], 3) == round(expected, 3) + +@pytest.fixture +def sample_data(): + """Create small sample GT and prediction DataFrames for testing.""" + gt_df = pd.DataFrame({ + 'video_name': ['video1', 'video1', 'video2'], + 'animal_idx': [0, 1, 0], + 'start': [0, 5, 0], + 'duration': [5, 5, 10], + 'is_behavior': [1, 0, 1] + }) + + pred_df = pd.DataFrame({ + 'video_name': ['video1', 'video1', 'video2'], + 'animal_idx': [0, 1, 0], + 'start': [0, 5, 0], + 'duration': [5, 5, 10], + 'is_behavior': [1, 0, 0] + }) + + return gt_df, pred_df + + +def test_generate_plot_runs(sample_data): + """Test that the plot function runs and returns a ggplot object.""" + gt_df, pred_df = sample_data + plot = generate_framewise_performance_plot(gt_df, pred_df) + # Check that the returned object is a ggplot + assert isinstance(plot, p9.ggplot) + +def test_plot_metrics(sample_data): + """Test that generate_framewise_performance_plot correctly handles NaNs.""" + gt_df, pred_df = sample_data + + plot = generate_framewise_performance_plot(gt_df, pred_df) + df = plot.data.sort_values(['video_name','metric']).reset_index(drop=True) + + # Manually compute expected metrics + expected = [] + # Video 1: Perfect prediction + expected.append({'video_name': 'video1', + 'precision': 1.0, 'recall': 1.0, + 'f1_score': 1.0, 'accuracy': 1.0}) + # Video 2: All wrong + expected.append({'video_name':'video2', + 
'precision': float('nan'), 'recall': 0.0, + 'f1_score': float('nan'), 'accuracy':0.0}) + + expected_df = pd.DataFrame(expected) + expected_melted = pd.melt( + expected_df, + id_vars=['video_name'], + value_vars=['precision','recall','f1_score','accuracy'], + var_name='metric', + value_name='value' + ).sort_values(['video_name','metric']).reset_index(drop=True) + + # Compare numeric values, treating NaNs as equal + for a, b in zip(df['value'], expected_melted['value']): + if pd.isna(a) and pd.isna(b): + continue + else: + assert abs(a - b) < 1e-6 +# %% From 0fbdd9a74d277e74d137feeba9e7fbdc7fe1809b Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 29 Oct 2025 13:16:18 -0400 Subject: [PATCH 3/8] remove #%% markers --- src/jabs_postprocess/compare_gt.py | 3 +-- tests/test_compare_gt.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 45f780e..1343d38 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -8,7 +8,6 @@ import pandas as pd import plotnine as p9 - from jabs_postprocess.utils.project_utils import ( Bouts, BoutTable, @@ -19,7 +18,7 @@ logger = logging.getLogger(__name__) -# %% + def evaluate_ground_truth( behavior: str, ground_truth_folder: Path, diff --git a/tests/test_compare_gt.py b/tests/test_compare_gt.py index 5becd39..42f91cd 100644 --- a/tests/test_compare_gt.py +++ b/tests/test_compare_gt.py @@ -27,7 +27,7 @@ test multiple scenarios efficiently. Fixtures are provided for common test data like mock bout tables, JABS projects, and annotation samples. """ -# %% + from unittest.mock import MagicMock, patch import numpy as np From 1858684c335b7756ea7abe0b0effe2347d29a817 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 29 Oct 2025 13:23:51 -0400 Subject: [PATCH 4/8] lint files --- src/jabs_postprocess/compare_gt.py | 166 +++++++++++++++++------------ tests/test_compare_gt.py | 100 ++++++++++------- 2 files changed, 162 insertions(+), 104 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 1343d38..ad172ea 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -41,7 +41,7 @@ def evaluate_ground_truth( filter_scan: List of filter (minimum duration in frames to consider real) values to test iou_thresholds: List of intersection over union thresholds to scan filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions) - trim_time: Limit the duration in frames of videos for performance + trim_time: Limit the duration in frames of videos for performance Returns: None, but saves the following files to results_folder: framewise_output: Output file to save the frame-level performance plot @@ -128,9 +128,7 @@ def evaluate_ground_truth( # Generate frame-level performance plot framewise_plot = generate_framewise_performance_plot(gt_df, pred_df) if ouput_paths["framewise_plot"] is not None: - framewise_plot.save( - ouput_paths["framewise_plot"], height=6, width=12, dpi=300 - ) + framewise_plot.save(ouput_paths["framewise_plot"], height=6, width=12, dpi=300) logging.info( f"Frame-level performance plot saved to {ouput_paths['framewise_plot']}" ) @@ -496,78 +494,95 @@ def generate_output_paths(results_folder: Path): "framewise_plot": results_folder / "framewise_performance.png", } + def _compute_framewise_confusion(gt_df, pred_df): """Compute frame-level confusion counts (TP, TN, FP, FN) per video. 
Args: - gt_df (pd.DataFrame): Ground truth intervals with columns + gt_df (pd.DataFrame): Ground truth intervals with columns ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior']. pred_df (pd.DataFrame): Prediction intervals with the same structure. Returns: - pd.DataFrame: Confusion matrix counts per video with columns + pd.DataFrame: Confusion matrix counts per video with columns ['video_name', 'TP', 'TN', 'FP', 'FN']. """ def expand_intervals_to_frames(df): """Expand behavior intervals into per-frame rows.""" expanded = df.copy() - expanded['frame'] = expanded.apply( - lambda row: range(row['start'], row['start'] + row['duration']), - axis = 1 + expanded["frame"] = expanded.apply( + lambda row: range(row["start"], row["start"] + row["duration"]), axis=1 ) - expanded = expanded.explode('frame') - expanded = expanded.sort_values(by=['animal_idx', 'frame']) + expanded = expanded.explode("frame") + expanded = expanded.sort_values(by=["animal_idx", "frame"]) return expanded - + # Expand ground truth and predictions into frame-level data gt_frames = expand_intervals_to_frames(gt_df) pred_frames = expand_intervals_to_frames(pred_df) # Merge to align predictions and ground truth per frame framewise = pd.merge( - gt_frames, pred_frames, - on=['video_name', 'animal_idx', 'frame'], - how='left', - suffixes=('_gt', '_pred') + gt_frames, + pred_frames, + on=["video_name", "animal_idx", "frame"], + how="left", + suffixes=("_gt", "_pred"), ) - + # Compute confusion counts per video confusion_counts = ( - framewise.groupby('video_name') - .apply(lambda x: pd.Series({ - 'TP': ((x['is_behavior_gt'] == 1) & (x['is_behavior_pred'] == 1)).sum(), - 'TN': ((x['is_behavior_gt'] == 0) & (x['is_behavior_pred'] == 0)).sum(), - 'FP': ((x['is_behavior_gt'] == 0) & (x['is_behavior_pred'] == 1)).sum(), - 'FN': ((x['is_behavior_gt'] == 1) & (x['is_behavior_pred'] == 0)).sum(), - }), include_groups=False) + framewise.groupby("video_name") + .apply( + lambda x: pd.Series( + { + "TP": ( + (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 1) + ).sum(), + "TN": ( + (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 0) + ).sum(), + "FP": ( + (x["is_behavior_gt"] == 0) & (x["is_behavior_pred"] == 1) + ).sum(), + "FN": ( + (x["is_behavior_gt"] == 1) & (x["is_behavior_pred"] == 0) + ).sum(), + } + ), + include_groups=False, + ) .reset_index() ) return confusion_counts + def _find_outliers(melted_df: pd.DataFrame): # Identify outliers per metric using IQR rule outliers = [] - for metric in melted_df['metric'].unique(): - values = melted_df.loc[melted_df['metric'] == metric, 'value'] + for metric in melted_df["metric"].unique(): + values = melted_df.loc[melted_df["metric"] == metric, "value"] q1 = values.quantile(0.25) q3 = values.quantile(0.75) iqr = q3 - q1 lower_bound = q1 - 1.5 * iqr upper_bound = q3 + 1.5 * iqr outliers_df = melted_df[ - (melted_df['metric'] == metric) - & ((melted_df['value'] < lower_bound) | (melted_df['value'] > upper_bound)) + (melted_df["metric"] == metric) + & ((melted_df["value"] < lower_bound) | (melted_df["value"] > upper_bound)) ] outliers.append(outliers_df) - outliers = pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns) - + outliers = ( + pd.concat(outliers) if outliers else pd.DataFrame(columns=melted_df.columns) + ) + return outliers -def generate_framewise_performance_plot( - gt_df: pd.DataFrame, pred_df:pd.DataFrame): + +def generate_framewise_performance_plot(gt_df: pd.DataFrame, pred_df: pd.DataFrame): """ Generate and save a frame-level 
performance plot comparing ground truth and predicted behavior intervals. @@ -579,44 +594,61 @@ def generate_framewise_performance_plot( 5. Adds an overall summary in the plot subtitle. Args: - gt_df (pd.DataFrame): Ground truth intervals with columns + gt_df (pd.DataFrame): Ground truth intervals with columns ['video_name', 'animal_idx', 'start', 'duration', 'is_behavior']. pred_df (pd.DataFrame): Prediction intervals with the same structure. - + Returns: plotnine.ggplot: A ggplot object containing the frame-level performance visualization. """ # Compute framewise confusion counts confusion_counts = _compute_framewise_confusion(gt_df, pred_df) - confusion_counts['frame_total'] = ( - confusion_counts['TP'] + confusion_counts['TN'] + - confusion_counts['FP'] + confusion_counts['FN'] + confusion_counts["frame_total"] = ( + confusion_counts["TP"] + + confusion_counts["TN"] + + confusion_counts["FP"] + + confusion_counts["FN"] ) - + # Compute per-video metrics - confusion_counts['precision'] = confusion_counts['TP'] / (confusion_counts['TP'] + confusion_counts['FP']) - confusion_counts['recall'] = confusion_counts['TP'] / (confusion_counts['TP'] + confusion_counts['FN']) - confusion_counts['f1_score'] = 2 * (confusion_counts['precision'] * confusion_counts['recall']) / (confusion_counts['precision'] + confusion_counts['recall']) - confusion_counts['accuracy'] = (confusion_counts['TP'] + confusion_counts['TN']) / confusion_counts['frame_total'] + confusion_counts["precision"] = confusion_counts["TP"] / ( + confusion_counts["TP"] + confusion_counts["FP"] + ) + confusion_counts["recall"] = confusion_counts["TP"] / ( + confusion_counts["TP"] + confusion_counts["FN"] + ) + confusion_counts["f1_score"] = ( + 2 + * (confusion_counts["precision"] * confusion_counts["recall"]) + / (confusion_counts["precision"] + confusion_counts["recall"]) + ) + confusion_counts["accuracy"] = ( + confusion_counts["TP"] + confusion_counts["TN"] + ) / confusion_counts["frame_total"] # Compute overall (global) metrics - totals = confusion_counts[['TP', 'TN', 'FP', 'FN']].sum() + totals = confusion_counts[["TP", "TN", "FP", "FN"]].sum() overall_metrics = { - 'precision': totals['TP'] / (totals['TP'] + totals['FP']), - 'recall': totals['TP'] / (totals['TP'] + totals['FN']), - 'accuracy': (totals['TP'] + totals['TN']) - / (totals['TP'] + totals['TN'] + totals['FP'] + totals['FN']) + "precision": totals["TP"] / (totals["TP"] + totals["FP"]), + "recall": totals["TP"] / (totals["TP"] + totals["FN"]), + "accuracy": (totals["TP"] + totals["TN"]) + / (totals["TP"] + totals["TN"] + totals["FP"] + totals["FN"]), } - overall_metrics['f1_score'] = 2 * (overall_metrics['precision'] * overall_metrics['recall']) / (overall_metrics['precision'] + overall_metrics['recall']) - + overall_metrics["f1_score"] = ( + 2 + * (overall_metrics["precision"] * overall_metrics["recall"]) + / (overall_metrics["precision"] + overall_metrics["recall"]) + ) + # Melt into long format for plotting melted_df = pd.melt( - confusion_counts, - id_vars=['video_name', 'frame_total'], - value_vars=['precision', 'recall', 'f1_score', 'accuracy'], - var_name='metric', - value_name='value') - + confusion_counts, + id_vars=["video_name", "frame_total"], + value_vars=["precision", "recall", "f1_score", "accuracy"], + var_name="metric", + value_name="value", + ) + outliers = _find_outliers(melted_df) # Generate plot subtitle_text = ( @@ -627,25 +659,23 @@ def generate_framewise_performance_plot( ) plot = ( - p9.ggplot(melted_df, p9.aes(x='metric', y='value')) - + 
p9.geom_boxplot(outlier_shape = None, fill='lightblue', alpha=0.7) - + p9.geom_jitter(p9.aes(color = 'frame_total'), width=0.05, height=0) + p9.ggplot(melted_df, p9.aes(x="metric", y="value")) + + p9.geom_boxplot(outlier_shape=None, fill="lightblue", alpha=0.7) + + p9.geom_jitter(p9.aes(color="frame_total"), width=0.05, height=0) + p9.geom_text( - p9.aes(label='video_name'), - data=outliers, - ha='left', - nudge_x=0.1 + p9.aes(label="video_name"), data=outliers, ha="left", nudge_x=0.1 ) + p9.labs( - title='Frame-level Performance Metrics', - y='Score', x='Metric', - subtitle=subtitle_text) + title="Frame-level Performance Metrics", + y="Score", + x="Metric", + subtitle=subtitle_text, + ) + p9.theme_bw() + p9.theme( - plot_title=p9.element_text(ha='center'), # Center the main title - plot_subtitle=p9.element_text(ha='center') # Center the subtitle too + plot_title=p9.element_text(ha="center"), # Center the main title + plot_subtitle=p9.element_text(ha="center"), # Center the subtitle too ) ) return plot - diff --git a/tests/test_compare_gt.py b/tests/test_compare_gt.py index 42f91cd..c51328d 100644 --- a/tests/test_compare_gt.py +++ b/tests/test_compare_gt.py @@ -35,7 +35,11 @@ import plotnine as p9 import pytest -from jabs_postprocess.compare_gt import evaluate_ground_truth, generate_iou_scan, generate_framewise_performance_plot +from jabs_postprocess.compare_gt import ( + evaluate_ground_truth, + generate_iou_scan, + generate_framewise_performance_plot, +) from jabs_postprocess.utils.project_utils import ( Bouts, ) @@ -674,25 +678,30 @@ def test_generate_iou_scan_metrics_calculation(mock_metrics, expected_result): else: assert round(row[metric], 3) == round(expected, 3) + @pytest.fixture def sample_data(): """Create small sample GT and prediction DataFrames for testing.""" - gt_df = pd.DataFrame({ - 'video_name': ['video1', 'video1', 'video2'], - 'animal_idx': [0, 1, 0], - 'start': [0, 5, 0], - 'duration': [5, 5, 10], - 'is_behavior': [1, 0, 1] - }) - - pred_df = pd.DataFrame({ - 'video_name': ['video1', 'video1', 'video2'], - 'animal_idx': [0, 1, 0], - 'start': [0, 5, 0], - 'duration': [5, 5, 10], - 'is_behavior': [1, 0, 0] - }) - + gt_df = pd.DataFrame( + { + "video_name": ["video1", "video1", "video2"], + "animal_idx": [0, 1, 0], + "start": [0, 5, 0], + "duration": [5, 5, 10], + "is_behavior": [1, 0, 1], + } + ) + + pred_df = pd.DataFrame( + { + "video_name": ["video1", "video1", "video2"], + "animal_idx": [0, 1, 0], + "start": [0, 5, 0], + "duration": [5, 5, 10], + "is_behavior": [1, 0, 0], + } + ) + return gt_df, pred_df @@ -703,37 +712,56 @@ def test_generate_plot_runs(sample_data): # Check that the returned object is a ggplot assert isinstance(plot, p9.ggplot) + def test_plot_metrics(sample_data): """Test that generate_framewise_performance_plot correctly handles NaNs.""" gt_df, pred_df = sample_data - + plot = generate_framewise_performance_plot(gt_df, pred_df) - df = plot.data.sort_values(['video_name','metric']).reset_index(drop=True) - + df = plot.data.sort_values(["video_name", "metric"]).reset_index(drop=True) + # Manually compute expected metrics expected = [] # Video 1: Perfect prediction - expected.append({'video_name': 'video1', - 'precision': 1.0, 'recall': 1.0, - 'f1_score': 1.0, 'accuracy': 1.0}) + expected.append( + { + "video_name": "video1", + "precision": 1.0, + "recall": 1.0, + "f1_score": 1.0, + "accuracy": 1.0, + } + ) # Video 2: All wrong - expected.append({'video_name':'video2', - 'precision': float('nan'), 'recall': 0.0, - 'f1_score': float('nan'), 
'accuracy':0.0}) - + expected.append( + { + "video_name": "video2", + "precision": float("nan"), + "recall": 0.0, + "f1_score": float("nan"), + "accuracy": 0.0, + } + ) + expected_df = pd.DataFrame(expected) - expected_melted = pd.melt( - expected_df, - id_vars=['video_name'], - value_vars=['precision','recall','f1_score','accuracy'], - var_name='metric', - value_name='value' - ).sort_values(['video_name','metric']).reset_index(drop=True) - + expected_melted = ( + pd.melt( + expected_df, + id_vars=["video_name"], + value_vars=["precision", "recall", "f1_score", "accuracy"], + var_name="metric", + value_name="value", + ) + .sort_values(["video_name", "metric"]) + .reset_index(drop=True) + ) + # Compare numeric values, treating NaNs as equal - for a, b in zip(df['value'], expected_melted['value']): + for a, b in zip(df["value"], expected_melted["value"]): if pd.isna(a) and pd.isna(b): continue else: assert abs(a - b) < 1e-6 + + # %% From 749d558af953186c0f8ee227e1b204f8126ef911 Mon Sep 17 00:00:00 2001 From: Alexander Berger Date: Fri, 31 Oct 2025 16:07:01 -0400 Subject: [PATCH 5/8] Fix typeo in --- src/jabs_postprocess/compare_gt.py | 36 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index ad172ea..9b724cc 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -50,7 +50,7 @@ def evaluate_ground_truth( ethogram_output: Output file to save the ethogram plot comparing GT and predictions scan_csv_output: Output file to save the scan performance data as CSV """ - ouput_paths = generate_output_paths(results_folder) + output_paths = generate_output_paths(results_folder) # Set default values if not provided stitch_scan = stitch_scan or np.arange(5, 46, 5).tolist() @@ -127,10 +127,10 @@ def evaluate_ground_truth( # Generate frame-level performance plot framewise_plot = generate_framewise_performance_plot(gt_df, pred_df) - if ouput_paths["framewise_plot"] is not None: - framewise_plot.save(ouput_paths["framewise_plot"], height=6, width=12, dpi=300) + if output_paths["framewise_plot"] is not None: + framewise_plot.save(output_paths["framewise_plot"], height=6, width=12, dpi=300) logging.info( - f"Frame-level performance plot saved to {ouput_paths['framewise_plot']}" + f"Frame-level performance plot saved to {output_paths['framewise_plot']}" ) # We only want the positive examples for performance evaluation @@ -161,9 +161,9 @@ def evaluate_ground_truth( "No performance data to analyze. Ensure that the ground truth and predictions are correctly formatted and contain valid bouts." 
) - if ouput_paths["scan_csv"] is not None: - performance_df.to_csv(ouput_paths["scan_csv"], index=False) - logging.info(f"Scan performance data saved to {ouput_paths['scan_csv']}") + if output_paths["scan_csv"] is not None: + performance_df.to_csv(output_paths["scan_csv"], index=False) + logging.info(f"Scan performance data saved to {output_paths['scan_csv']}") _melted_df = pd.melt(performance_df, id_vars=["threshold", "stitch", "filter"]) @@ -183,8 +183,8 @@ def evaluate_ground_truth( + p9.theme_bw() + p9.labs(title=f"No performance data for {middle_threshold} IoU") ) - if ouput_paths["scan_plot"]: - plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300) + if output_paths["scan_plot"]: + plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300) # Create default winning filters with first values from scan parameters winning_filters = pd.DataFrame( { @@ -242,8 +242,8 @@ def evaluate_ground_truth( + p9.scale_fill_continuous(na_value=0) ) - if ouput_paths["scan_plot"]: - plot.save(ouput_paths["scan_plot"], height=6, width=12, dpi=300) + if output_paths["scan_plot"]: + plot.save(output_paths["scan_plot"], height=6, width=12, dpi=300) # Handle case where all f1_plot values are NaN or empty if subset_df["f1_plot"].isna().all() or len(subset_df) == 0: @@ -265,9 +265,9 @@ def evaluate_ground_truth( ).T.reset_index(drop=True)[["stitch", "filter"]] winning_bout_df = pd.merge(performance_df, winning_filters, on=["stitch", "filter"]) - if ouput_paths["bout_csv"] is not None: - winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False) - logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}") + if output_paths["bout_csv"] is not None: + winning_bout_df.to_csv(output_paths["bout_csv"], index=False) + logging.info(f"Bout performance data saved to {output_paths['bout_csv']}") melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"]) @@ -279,9 +279,9 @@ def evaluate_ground_truth( + p9.geom_line() + p9.theme_bw() + p9.scale_y_continuous(limits=(0, 1)) - ).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300) + ).save(output_paths["bout_plot"], height=6, width=12, dpi=300) - if ouput_paths["ethogram"] is not None: + if output_paths["ethogram"] is not None: # Prepare data for ethogram plot # Use all_annotations to include both behavior (1) and not-behavior (0) states plot_df = all_annotations.copy() @@ -338,14 +338,14 @@ def evaluate_ground_truth( ) # Adjust height based on the number of unique animal-video combinations ethogram_plot.save( - ouput_paths["ethogram"], + output_paths["ethogram"], height=1.5 * num_unique_combos + 2, width=12, dpi=300, limitsize=False, verbose=False, ) - logging.info(f"Ethogram plot saved to {ouput_paths['ethogram']}") + logging.info(f"Ethogram plot saved to {output_paths['ethogram']}") else: logger.warning( f"No behavior instances found for behavior {behavior} after filtering for ethogram." 
From 77cd62a0f59b82677b40d98b6909727a731dd265 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Mon, 10 Nov 2025 10:58:46 -0500 Subject: [PATCH 6/8] address PR cmts --- src/jabs_postprocess/compare_gt.py | 34 ++++++++++++++++++------------ tests/test_compare_gt.py | 1 - 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index 9b724cc..fe7c999 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -494,6 +494,15 @@ def generate_output_paths(results_folder: Path): "framewise_plot": results_folder / "framewise_performance.png", } +def _expand_intervals_to_frames(df): + """Expand behavior intervals into per-frame rows.""" + expanded = df.copy() + expanded["frame"] = expanded.apply( + lambda row: range(row["start"], row["start"] + row["duration"]), axis=1 + ) + expanded = expanded.explode("frame") + expanded = expanded.sort_values(by=["animal_idx", "frame"]) + return expanded def _compute_framewise_confusion(gt_df, pred_df): """Compute frame-level confusion counts (TP, TN, FP, FN) per video. @@ -508,19 +517,9 @@ def _compute_framewise_confusion(gt_df, pred_df): ['video_name', 'TP', 'TN', 'FP', 'FN']. """ - def expand_intervals_to_frames(df): - """Expand behavior intervals into per-frame rows.""" - expanded = df.copy() - expanded["frame"] = expanded.apply( - lambda row: range(row["start"], row["start"] + row["duration"]), axis=1 - ) - expanded = expanded.explode("frame") - expanded = expanded.sort_values(by=["animal_idx", "frame"]) - return expanded - # Expand ground truth and predictions into frame-level data - gt_frames = expand_intervals_to_frames(gt_df) - pred_frames = expand_intervals_to_frames(pred_df) + gt_frames = _expand_intervals_to_frames(gt_df) + pred_frames = _expand_intervals_to_frames(pred_df) # Merge to align predictions and ground truth per frame framewise = pd.merge( @@ -560,7 +559,16 @@ def expand_intervals_to_frames(df): def _find_outliers(melted_df: pd.DataFrame): - # Identify outliers per metric using IQR rule + """ + Return rows flagged as outliers per metric using the IQR rule. + + Args: + melted_df: long-form DataFrame with at least 'metric' and 'value' columns. + + Returns: + DataFrame containing the outliers rows from the input DataFrame. + Returns an empty DataFrame with the same columns if no outliers found. 
+ """ outliers = [] for metric in melted_df["metric"].unique(): values = melted_df.loc[melted_df["metric"] == metric, "value"] diff --git a/tests/test_compare_gt.py b/tests/test_compare_gt.py index c51328d..cb324b3 100644 --- a/tests/test_compare_gt.py +++ b/tests/test_compare_gt.py @@ -764,4 +764,3 @@ def test_plot_metrics(sample_data): assert abs(a - b) < 1e-6 -# %% From ad3630694d5ce83b4249a2f6169a2293b26fb190 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Mon, 10 Nov 2025 11:00:46 -0500 Subject: [PATCH 7/8] linting --- src/jabs_postprocess/compare_gt.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/jabs_postprocess/compare_gt.py b/src/jabs_postprocess/compare_gt.py index fe7c999..1dfef70 100644 --- a/src/jabs_postprocess/compare_gt.py +++ b/src/jabs_postprocess/compare_gt.py @@ -494,15 +494,17 @@ def generate_output_paths(results_folder: Path): "framewise_plot": results_folder / "framewise_performance.png", } + def _expand_intervals_to_frames(df): - """Expand behavior intervals into per-frame rows.""" - expanded = df.copy() - expanded["frame"] = expanded.apply( - lambda row: range(row["start"], row["start"] + row["duration"]), axis=1 - ) - expanded = expanded.explode("frame") - expanded = expanded.sort_values(by=["animal_idx", "frame"]) - return expanded + """Expand behavior intervals into per-frame rows.""" + expanded = df.copy() + expanded["frame"] = expanded.apply( + lambda row: range(row["start"], row["start"] + row["duration"]), axis=1 + ) + expanded = expanded.explode("frame") + expanded = expanded.sort_values(by=["animal_idx", "frame"]) + return expanded + def _compute_framewise_confusion(gt_df, pred_df): """Compute frame-level confusion counts (TP, TN, FP, FN) per video. From b46d865a7c64e3d2c7b211e0b7815bc3a390e56a Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Mon, 10 Nov 2025 11:01:53 -0500 Subject: [PATCH 8/8] lint --- tests/test_compare_gt.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_compare_gt.py b/tests/test_compare_gt.py index cb324b3..f92a3ac 100644 --- a/tests/test_compare_gt.py +++ b/tests/test_compare_gt.py @@ -762,5 +762,3 @@ def test_plot_metrics(sample_data): continue else: assert abs(a - b) < 1e-6 - -
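
Reviewer note (not part of the patches above): a minimal sketch of how the new generate_framewise_performance_plot() added in this series can be exercised directly, reusing the toy intervals from the sample_data test fixture; the standalone-script form and the output filename are illustrative assumptions only.

# Illustrative sketch only; mirrors tests/test_compare_gt.py::sample_data.
import pandas as pd

from jabs_postprocess.compare_gt import generate_framewise_performance_plot

# Ground-truth intervals: video1 has one behavior and one not-behavior bout,
# video2 has a single 10-frame behavior bout.
gt_df = pd.DataFrame(
    {
        "video_name": ["video1", "video1", "video2"],
        "animal_idx": [0, 1, 0],
        "start": [0, 5, 0],
        "duration": [5, 5, 10],
        "is_behavior": [1, 0, 1],
    }
)
# Predictions match video1 exactly but call video2 entirely not-behavior.
pred_df = gt_df.assign(is_behavior=[1, 0, 0])

plot = generate_framewise_performance_plot(gt_df, pred_df)
# Same dimensions/dpi that evaluate_ground_truth() uses when it writes
# framewise_performance.png; the path here is hypothetical.
plot.save("framewise_performance.png", height=6, width=12, dpi=300)

The save() call above mirrors the height=6, width=12, dpi=300 settings used by evaluate_ground_truth() for the framewise_plot output path.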