31 changes: 29 additions & 2 deletions src/jabs_postprocess/cli/main.py
@@ -154,7 +154,22 @@ def evaluate_ground_truth(
),
filter_ground_truth: bool = typer.Option(
False,
help="Apply filters to ground truth data (default is only to filter predictions)",
help=(
"Enable extra filtered outputs and apply stitch/filter to BOTH GT and predictions. "
"Use together with --stitch-value-filter and --filter-value-filter."
),
),
stitch_value_filter: Optional[int] = typer.Option(
None,
"--stitch-value-filter",
"--stitch_value_filter",
help="Stitch (frames) to use for filtered outputs",
),
filter_value_filter: Optional[int] = typer.Option(
None,
"--filter-value-filter",
"--filter_value_filter",
help="Minimum bout (frames) to use for filtered outputs",
),
trim_time: Optional[int] = typer.Option(
None,
@@ -177,6 +192,18 @@ def evaluate_ground_truth(
f"Prediction folder does not exist: {prediction_folder}"
)

# Convert CLI options into the dict expected by the underlying function
filter_gt_dict: Optional[dict] = None
if filter_ground_truth:
if stitch_value_filter is None or filter_value_filter is None:
raise typer.BadParameter(
"When using --filter-ground-truth, you must also provide --stitch-value-filter and --filter-value-filter."
)
filter_gt_dict = {
"stitch": int(stitch_value_filter),
"filter": int(filter_value_filter),
}

# Call the refactored function with individual parameters
compare_gt.evaluate_ground_truth(
behavior=behavior,
@@ -186,7 +213,7 @@ def evaluate_ground_truth(
stitch_scan=stitch_scan,
filter_scan=filter_scan,
iou_thresholds=iou_thresholds,
filter_ground_truth=filter_ground_truth,
filter_ground_truth=filter_gt_dict,
trim_time=trim_time,
)
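
For readers skimming the diff: a minimal standalone sketch of the option-to-dict conversion added above, using only the logic visible in this hunk. The helper name build_filter_gt_dict is illustrative, not part of this PR.

# Minimal sketch of the CLI conversion above; the function name is
# hypothetical and only mirrors the logic visible in the diff.
from typing import Optional

import typer


def build_filter_gt_dict(
    filter_ground_truth: bool,
    stitch_value_filter: Optional[int],
    filter_value_filter: Optional[int],
) -> Optional[dict]:
    """Mirror the CLI logic: both paired values are required when enabled."""
    if not filter_ground_truth:
        return None
    if stitch_value_filter is None or filter_value_filter is None:
        raise typer.BadParameter(
            "When using --filter-ground-truth, you must also provide "
            "--stitch-value-filter and --filter-value-filter."
        )
    return {"stitch": int(stitch_value_filter), "filter": int(filter_value_filter)}


print(build_filter_gt_dict(True, 30, 15))       # {'stitch': 30, 'filter': 15}
print(build_filter_gt_dict(False, None, None))  # None (filtered outputs disabled)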

304 changes: 290 additions & 14 deletions src/jabs_postprocess/compare_gt.py
@@ -27,7 +27,7 @@ def evaluate_ground_truth(
stitch_scan: List[float] = None,
filter_scan: List[float] = None,
iou_thresholds: List[float] = None,
filter_ground_truth: bool = False,
filter_ground_truth: Optional[dict] = None,
trim_time: Optional[int] = None,
):
"""Main function for evaluating ground truth annotations against classifier predictions.
@@ -40,7 +40,11 @@ def evaluate_ground_truth(
stitch_scan: List of stitching (time gaps in frames to merge bouts together) values to test
filter_scan: List of filter (minimum duration in frames to consider real) values to test
iou_thresholds: List of intersection over union thresholds to scan
filter_ground_truth: Apply filters to ground truth data (default is only to filter predictions)
filter_ground_truth: Optional dict specifying the stitch/filter to apply to BOTH GT and
    predictions for additional filtered outputs. If provided, it must contain the keys
    "stitch" (stitch gap in frames) and "filter" (minimum bout duration in frames)
scan_output: Output file to save the filter scan performance plot
bout_output: Output file to save the resulting bout performance plot
trim_time: Limit the duration in frames of videos for performance
@@ -139,7 +143,7 @@
stitch_scan,
filter_scan,
iou_thresholds,
filter_ground_truth,
False,
Contributor
@SkepticRaven, I'm curious about your opinion on permanently setting this to False. This is the filter_ground_truth argument of the generate_iou_scan call.

It's used to switch on this call to filter_by_settings.

Which, for reference, is defined here.

Contributor
SkepticRaven Sep 3, 2025
From a machine learning perspective, we should avoid modifying the ground truth data. I'd prefer it wasn't a parameter at all. This capability exists purely for practical reasons: while the user should go in and manually modify the ground truth data, that can get effort-expensive. Allowing the ground truth to be modifiable runs the risk of observing "improved" performance by removing shorter but difficult real events.

At least in this edit, the behavior appears to be generating an unfiltered and filtered version (edits below, L345).

Contributor
So does it make sense to remove the filter_ground_truth argument entirely?

Contributor
We should first look into what generate_filtered_iou_curve does differently. That new function appears to have high overlap with what that arg does.

Contributor Author
I think it'd make more sense for me to change this back to filter_ground_truth and modify my code to work with some of the older logic. In the current state, I rewrote the whole function without removing the older functionality.

)

if performance_df.empty:
@@ -258,17 +262,7 @@
winning_bout_df.to_csv(ouput_paths["bout_csv"], index=False)
logging.info(f"Bout performance data saved to {ouput_paths['bout_csv']}")

melted_winning = pd.melt(winning_bout_df, id_vars=["threshold", "stitch", "filter"])

(
p9.ggplot(
melted_winning[melted_winning["variable"].isin(["pr", "re", "f1"])],
p9.aes(x="threshold", y="value", color="variable"),
)
+ p9.geom_line()
+ p9.theme_bw()
+ p9.scale_y_continuous(limits=(0, 1))
).save(ouput_paths["bout_plot"], height=6, width=12, dpi=300)
_save_bout_curve_performance(winning_bout_df, ouput_paths["bout_plot"])

if ouput_paths["ethogram"] is not None:
# Prepare data for ethogram plot
@@ -348,6 +342,52 @@
f"No annotations found for behavior {behavior} to generate ethogram plot."
)

# New: filtered outputs if user provided filter settings to apply to both GT and Pred
if filter_ground_truth is not None:
if not isinstance(filter_ground_truth, dict) or not {
"stitch",
"filter",
}.issubset(set(filter_ground_truth.keys())):
logger.warning(
"filter_ground_truth must be a dict with keys {'stitch','filter'}. Skipping filtered outputs."
)
else:
stitch_val = int(filter_ground_truth["stitch"])
filter_val = int(filter_ground_truth["filter"])
# 1) Filtered ethogram with 4 tracks
_save_filtered_ethogram(
all_annotations,
behavior,
stitch_val,
filter_val,
ouput_paths.get("ethogram_filtered"),
)
# 2) Filtered curves CSV + plot over IoU thresholds
filtered_curve_df = generate_filtered_iou_curve(
all_annotations,
stitch_val,
filter_val,
np.round(iou_thresholds, 2),
)
if filtered_curve_df is not None and len(filtered_curve_df) > 0:
if ouput_paths.get("bout_filtered_csv") is not None:
filtered_curve_df.to_csv(
ouput_paths["bout_filtered_csv"], index=False
)
logging.info(
f"Filtered curve performance saved to {ouput_paths['bout_filtered_csv']}"
)
# Reuse the same curve plotting by adding fixed stitch/filter columns
filtered_curve_df["stitch"] = stitch_val
filtered_curve_df["filter"] = filter_val
_save_bout_curve_performance(
filtered_curve_df, ouput_paths.get("bout_filtered_plot")
)
else:
logger.warning(
"No filtered performance data available to save plots/CSV."
)
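
As a quick reference for the guard above, a minimal sketch of the dict shape that enables the filtered outputs. The helper name _valid_filter_gt is hypothetical and only mirrors the isinstance/keys check.

# Minimal sketch of the dict-shape guard above; the helper name is
# hypothetical. None is handled before this check in the real code.
from typing import Optional


def _valid_filter_gt(filter_ground_truth: Optional[dict]) -> bool:
    """True only for a dict containing both 'stitch' and 'filter' keys."""
    return isinstance(filter_ground_truth, dict) and {"stitch", "filter"}.issubset(
        filter_ground_truth.keys()
    )


print(_valid_filter_gt({"stitch": 30, "filter": 15}))  # True -> filtered outputs run
print(_valid_filter_gt({"stitch": 30}))                # False -> warning, outputs skipped
print(_valid_filter_gt(None))                          # False (feature disabled upstream)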


def generate_iou_scan(
all_annotations,
@@ -480,4 +520,240 @@ def generate_output_paths(results_folder: Path):
"ethogram": results_folder / "ethogram.png",
"scan_plot": results_folder / "scan_performance.png",
"bout_plot": results_folder / "bout_performance.png",
# New filtered outputs
"ethogram_filtered": results_folder / "ethogram_filtered.png",
"bout_filtered_plot": results_folder / "bout_performance_filtered.png",
"bout_filtered_csv": results_folder / "bout_performance_filtered.csv",
}


def _save_bout_curve_performance(curve_df: pd.DataFrame, output_path: Optional[Path]):
"""
Saves the bout performance curve plot (precision/recall/F1 across IoU thresholds).

Args:
curve_df: Contains the curve performance data
output_path: Path to save the plot to
"""
if output_path is None:
return
melted_df = pd.melt(curve_df, id_vars=["threshold", "stitch", "filter"])
(
p9.ggplot(
melted_df[melted_df["variable"].isin(["pr", "re", "f1"])],
p9.aes(x="threshold", y="value", color="variable"),
)
+ p9.geom_line()
+ p9.theme_bw()
+ p9.scale_y_continuous(limits=(0, 1))
).save(output_path, height=6, width=12, dpi=300)
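
For reference, a toy frame (values invented) in the shape this helper consumes: one row per IoU threshold, with the fixed stitch/filter columns that pd.melt keeps as identifiers.

# Toy input (values invented) matching the columns melted above.
import pandas as pd

toy = pd.DataFrame(
    {
        "threshold": [0.25, 0.50, 0.75],
        "stitch": [30, 30, 30],   # gap in frames merged between bouts
        "filter": [15, 15, 15],   # minimum bout duration in frames
        "pr": [0.90, 0.80, 0.55],
        "re": [0.70, 0.60, 0.40],
        "f1": [0.79, 0.69, 0.46],
    }
)
# _save_bout_curve_performance(toy, Path("bout_performance.png"))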


def _save_filtered_ethogram(
all_annotations: pd.DataFrame,
behavior: str,
stitch_val: int,
filter_val: int,
output_path: Optional[Path],
):
"""
Saves the filtered ethogram plot.

Args:
all_annotations: Contains the ethogram data
behavior: The behavior to plot
stitch_val: The stitch value
filter_val: The filter value
output_path: Path to save the plot to
"""

if output_path is None:
return

# Grabbing raw and filtered bouts per animal/video for gt and pred
records = []
for (cur_animal, cur_video), animal_df in all_annotations.groupby(
["animal_idx", "video_name"]
):
pr_df = animal_df[~animal_df["is_gt"]]
gt_df = animal_df[animal_df["is_gt"]]
if len(pr_df) == 0:
continue
pr_obj = Bouts(pr_df["start"], pr_df["duration"], pr_df["is_behavior"])
gt_obj = Bouts(gt_df["start"], gt_df["duration"], gt_df["is_behavior"])

full_duration = int(pr_obj.starts[-1] + pr_obj.durations[-1])
pr_obj.fill_to_size(full_duration, 0)
gt_obj.fill_to_size(full_duration, 0)

settings = ClassifierSettings(
"", interpolate=0, stitch=stitch_val, min_bout=filter_val
)
pr_fil = pr_obj.copy()
gt_fil = gt_obj.copy()
pr_fil.filter_by_settings(settings)
gt_fil.filter_by_settings(settings)

# Helper to extend records from a Bouts object for is_behavior == 1
def add_records_from_bouts(bouts_obj: Bouts, track_label: str):
starts = bouts_obj.starts
durations = bouts_obj.durations
values = bouts_obj.values
if starts is None or len(starts) == 0:
return
ends = starts + durations
for s, e, v in zip(starts, ends, values):
if v == 1:
records.append(
{
"animal_idx": cur_animal,
"video_name": cur_video,
"start": int(s),
"end": int(e),
"track": track_label,
}
)

add_records_from_bouts(gt_obj, "GT Raw")
add_records_from_bouts(gt_fil, "GT Filtered")
add_records_from_bouts(pr_obj, "Pred Raw")
add_records_from_bouts(pr_fil, "Pred Filtered")

if len(records) == 0:
logger.warning("No behavior bouts found to generate filtered ethogram.")
return

df = pd.DataFrame.from_records(records)
df["animal_video_combo"] = (
df["animal_idx"].astype(str) + " | " + df["video_name"].astype(str)
)

# Map track to vertical bands in the requested order: raw gt, filtered gt, raw pred, filtered pred
track_order = ["GT Raw", "GT Filtered", "Pred Raw", "Pred Filtered"]
track_to_idx = {
label: idx for idx, label in enumerate(track_order[::-1])
} # reverse so top is GT Raw
df["track_idx"] = df["track"].map(track_to_idx)
df["ymin"] = df["track_idx"].astype(float)
df["ymax"] = df["ymin"] + 0.9

num_unique_combos = len(df["animal_video_combo"].unique())

plot = (
p9.ggplot(df)
+ p9.geom_rect(
p9.aes(xmin="start", xmax="end", ymin="ymin", ymax="ymax", fill="track")
)
+ p9.theme_bw()
+ p9.facet_wrap("~animal_video_combo", ncol=1, scales="free_x")
+ p9.scale_y_continuous(
breaks=[track_to_idx[t] + 0.45 for t in track_order[::-1]],
labels=track_order[::-1],
name="",
)
+ p9.scale_fill_brewer(type="qual", palette="Set1")
+ p9.labs(
x="Frame",
fill="Track",
title=f"Ethogram (filtered) for behavior: {behavior}",
)
+ p9.expand_limits(x=0)
)

plot.save(
output_path,
height=2.0 * num_unique_combos + 2,
width=12,
dpi=300,
limitsize=False,
verbose=False,
)
logging.info(f"Filtered ethogram plot saved to {output_path}")


def generate_filtered_iou_curve(
all_annotations: pd.DataFrame,
stitch_val: int,
filter_val: int,
threshold_scan: np.ndarray,
) -> Optional[pd.DataFrame]:
"""Compute PR/RE/F1 across IoU thresholds after applying a fixed stitch/filter to BOTH GT and predictions.

Returns a DataFrame aggregated over animals/videos with columns: threshold, tp, fn, fp, pr, re, f1
"""
threshold_scan = np.round(threshold_scan, 2)
settings = ClassifierSettings(
"", interpolate=0, stitch=stitch_val, min_bout=filter_val
)

perf_rows = []
for (cur_animal, cur_video), animal_df in all_annotations.groupby(
["animal_idx", "video_name"]
):
pr_df = animal_df[~animal_df["is_gt"]]
if len(pr_df) == 0:
continue
gt_df = animal_df[animal_df["is_gt"]]
pr_obj = Bouts(pr_df["start"], pr_df["duration"], pr_df["is_behavior"])
gt_obj = Bouts(gt_df["start"], gt_df["duration"], gt_df["is_behavior"])

full_duration = int(pr_obj.starts[-1] + pr_obj.durations[-1])
pr_obj.fill_to_size(full_duration, 0)
gt_obj.fill_to_size(full_duration, 0)

pr_fil = pr_obj.copy()
pr_fil.filter_by_settings(settings)
gt_fil = gt_obj.copy()
gt_fil.filter_by_settings(settings)

# Handle empty-positive cases without calling compare_to (which expects non-empty arrays)
num_pr_pos = int(np.sum(pr_fil.values == 1))
num_gt_pos = int(np.sum(gt_fil.values == 1))
if num_pr_pos == 0 or num_gt_pos == 0:
for thr in threshold_scan:
if num_pr_pos == 0 and num_gt_pos == 0:
metrics = {"tp": 0, "fn": 0, "fp": 0, "pr": 0, "re": 0, "f1": 0}
elif num_pr_pos == 0 and num_gt_pos > 0:
metrics = {
"tp": 0,
"fn": num_gt_pos,
"fp": 0,
"pr": 0,
"re": 0,
"f1": 0,
}
else: # num_pr_pos > 0 and num_gt_pos == 0
metrics = {
"tp": 0,
"fn": 0,
"fp": num_pr_pos,
"pr": 0,
"re": 0,
"f1": 0,
}
perf_rows.append(
{
"animal": cur_animal,
"video": cur_video,
"threshold": thr,
**metrics,
}
)
continue

int_mat, u_mat, iou_mat = gt_fil.compare_to(pr_fil)
for thr in threshold_scan:
metrics = Bouts.calculate_iou_metrics(iou_mat, thr)
perf_rows.append(
{"animal": cur_animal, "video": cur_video, "threshold": thr, **metrics}
)

if len(perf_rows) == 0:
return None

df = pd.DataFrame(perf_rows)
df = df.groupby("threshold")[["tp", "fn", "fp"]].sum().reset_index()
df["pr"] = df["tp"] / (df["tp"] + df["fp"])
df["re"] = df["tp"] / (df["tp"] + df["fn"])
df["f1"] = 2 * (df["pr"] * df["re"]) / (df["pr"] + df["re"])
return df
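
A worked check of the aggregation with invented counts: tp/fn/fp are summed over animals and videos per threshold before the ratios are computed.

# Worked example with invented counts for a single threshold.
tp, fp, fn = 8, 2, 4
pr = tp / (tp + fp)               # 8 / 10 = 0.8
re = tp / (tp + fn)               # 8 / 12 ~= 0.667
f1 = 2 * (pr * re) / (pr + re)    # ~= 0.727
print(round(pr, 3), round(re, 3), round(f1, 3))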