Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/jabs/classifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@
"""

from .classifier import Classifier
from .training_report import (
CrossValidationResult,
TrainingReportData,
generate_markdown_report,
save_training_report,
)

__all__ = [
"Classifier",
"CrossValidationResult",
"TrainingReportData",
"generate_markdown_report",
"save_training_report",
]
21 changes: 16 additions & 5 deletions src/jabs/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,22 +604,33 @@ def combine_data(per_frame, window):
"""
return pd.concat([per_frame, window], axis=1)

def print_feature_importance(self, feature_list, limit=20):
"""print the most important features and their importance
def get_feature_importance(self, limit=20) -> list[tuple[str, float]]:
"""get the most important features and their importance

Args:
feature_list: list of feature names used in the classifier
limit: maximum number of features to print, defaults to 20
limit: maximum number of features to return, defaults to 20

Returns:
list of tuples of feature name and importance
"""
# Get numerical feature importance
importances = list(self._classifier.feature_importances_)
# List of tuples with variable and importance
feature_importance = [
(feature, round(importance, 2))
for feature, importance in zip(feature_list, importances, strict=True)
for feature, importance in zip(self._feature_names, importances, strict=True)
]
# Sort the feature importance by most important first
feature_importance = sorted(feature_importance, key=lambda x: x[1], reverse=True)
return feature_importance[:limit]

def print_feature_importance(self, limit=20):
"""print the most important features and their importance

Args:
limit: maximum number of features to print, defaults to 20
"""
feature_importance = self.get_feature_importance(limit=limit)
# Print out the feature and importance
print(f"{'Feature Name':100} Importance")
print("-" * 120)
Expand Down
228 changes: 228 additions & 0 deletions src/jabs/classifier/training_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
"""Training report generation for classifier cross-validation results."""

from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import numpy as np
from tabulate import tabulate


@dataclass
class CrossValidationResult:
    """Results from a single cross-validation iteration.

    Equality is defined explicitly: every field except ``confusion_matrix`` is
    compared as a tuple, and ``confusion_matrix`` is compared with
    ``np.array_equal``. The dataclass-generated ``__eq__`` would compare the
    arrays with ``==``, which produces an array and raises ``ValueError`` when
    its truth value is taken.

    Attributes:
        iteration: The iteration number (1-indexed)
        test_video: Name of the video used for testing
        test_identity: Identity label used for testing
        accuracy: Classification accuracy (0.0 to 1.0)
        precision_behavior: Precision for behavior class
        precision_not_behavior: Precision for not-behavior class
        recall_behavior: Recall for behavior class
        recall_not_behavior: Recall for not-behavior class
        f1_behavior: F1 score for behavior class
        support_behavior: Number of behavior frames in test set
        support_not_behavior: Number of not-behavior frames in test set
        confusion_matrix: 2x2 confusion matrix
        top_features: List of (feature_name, importance) tuples for this iteration
    """

    iteration: int
    test_video: str
    test_identity: str
    accuracy: float
    precision_behavior: float
    precision_not_behavior: float
    recall_behavior: float
    recall_not_behavior: float
    f1_behavior: float
    support_behavior: int
    support_not_behavior: int
    confusion_matrix: np.ndarray
    top_features: list[tuple[str, float]] = field(default_factory=list)

    # @dataclass does not overwrite an __eq__ defined in the class body, and
    # defining __eq__ keeps the class unhashable, exactly as the generated
    # method would.
    def __eq__(self, other: object) -> bool:
        """Compare two results field by field, array-aware.

        Args:
            other: object to compare against

        Returns:
            True if ``other`` is a CrossValidationResult with equal fields,
            False if any field differs, NotImplemented for other types
        """
        if other.__class__ is not self.__class__:
            return NotImplemented

        # Local import: the module-level import only brings in dataclass/field.
        from dataclasses import fields

        plain = [spec.name for spec in fields(self) if spec.name != "confusion_matrix"]
        mine = tuple(getattr(self, name) for name in plain)
        theirs = tuple(getattr(other, name) for name in plain)
        if mine != theirs:
            return False
        # array_equal returns a single bool and is False on a shape mismatch.
        return bool(np.array_equal(self.confusion_matrix, other.confusion_matrix))


@dataclass
class TrainingReportData:
    """Everything a training report needs to describe one training run.

    Attributes:
        behavior_name: Name of the behavior the classifier was trained for
        classifier_type: Human-readable classifier name (e.g., "Random Forest")
        window_size: Window size used when extracting window features
        balance_training_labels: True if the training labels were balanced
        symmetric_behavior: True if the behavior is treated as symmetric
        distance_unit: Unit of the distance features ("cm" or "pixel")
        cv_results: One CrossValidationResult per cross-validation iteration
        final_top_features: (feature_name, importance) pairs from the final
            model, which is trained on all data
        frames_behavior: Total count of frames labeled as behavior
        frames_not_behavior: Total count of frames labeled as not behavior
        bouts_behavior: Total count of labeled behavior bouts
        bouts_not_behavior: Total count of labeled not-behavior bouts
        training_time_ms: Total training time, in milliseconds
        timestamp: When training completed
    """

    # What was trained and how
    behavior_name: str
    classifier_type: str
    window_size: int
    balance_training_labels: bool
    symmetric_behavior: bool
    distance_unit: str

    # Evaluation output
    cv_results: list[CrossValidationResult]
    final_top_features: list[tuple[str, float]]

    # Label totals
    frames_behavior: int
    frames_not_behavior: int
    bouts_behavior: int
    bouts_not_behavior: int

    # Run bookkeeping
    training_time_ms: int
    timestamp: datetime


def _escape_markdown(text: str) -> str:
"""Escape markdown special characters in text.

Args:
text: Text that may contain markdown special characters

Returns:
Text with markdown special characters escaped
"""
# Escape common markdown characters that might appear in filenames
# Most important: _ (underscore) which creates italics
# Also escape: * (asterisk), [ ] (brackets), ( ) (parentheses)
chars_to_escape = ["_", "*", "[", "]", "(", ")", "`", "#"]
for char in chars_to_escape:
text = text.replace(char, f"\\{char}")
return text


def generate_markdown_report(data: TrainingReportData) -> str:
    """Render a training run as a markdown document.

    Sections appear in this order: title and date, training summary, label
    counts, cross-validation results (or a note that none were run), and the
    feature-importance table of the final model.

    Args:
        data: TrainingReportData object containing all training information

    Returns:
        Markdown-formatted string, with lines joined by newlines
    """

    def yes_no(flag) -> str:
        return "Yes" if flag else "No"

    # Header, summary and label counts are fixed-shape, so build them in one go.
    report = [
        f"# Training Report: {data.behavior_name}",
        "",
        f"**Date:** {data.timestamp.strftime('%B %d, %Y at %I:%M:%S %p')}",
        "",
        "## Training Summary",
        "",
        f"- **Behavior:** {data.behavior_name}",
        f"- **Classifier:** {data.classifier_type}",
        f"- **Window Size:** {data.window_size}",
        f"- **Balanced Training Labels:** {yes_no(data.balance_training_labels)}",
        f"- **Symmetric Behavior:** {yes_no(data.symmetric_behavior)}",
        f"- **Distance Unit:** {data.distance_unit}",
        f"- **Training Time:** {data.training_time_ms / 1000:.2f} seconds",
        "",
        "### Label Counts",
        "",
        f"- **Behavior frames:** {data.frames_behavior:,}",
        f"- **Not-behavior frames:** {data.frames_not_behavior:,}",
        f"- **Behavior bouts:** {data.bouts_behavior:,}",
        f"- **Not-behavior bouts:** {data.bouts_not_behavior:,}",
        "",
    ]

    folds = data.cv_results
    if not folds:
        # Training ran without any cross-validation iterations
        report += [
            "## Cross-Validation",
            "",
            "*No cross-validation was performed for this training.*",
            "",
        ]
    else:
        # Mean and standard deviation across iterations
        acc = [fold.accuracy for fold in folds]
        f1 = [fold.f1_behavior for fold in folds]
        report += [
            "## Cross-Validation Results",
            "",
            "### Performance Summary",
            "",
            f"- **Mean Accuracy:** {np.mean(acc):.4f} (± {np.std(acc):.4f})",
            f"- **Mean F1 Score (Behavior):** {np.mean(f1):.4f} (± {np.std(f1):.4f})",
            "",
            "### Iteration Details",
            "",
        ]

        # One table row per iteration; only the video name is markdown-escaped
        fold_rows = [
            [
                fold.iteration,
                f"{fold.accuracy:.4f}",
                f"{fold.precision_not_behavior:.4f}",
                f"{fold.precision_behavior:.4f}",
                f"{fold.recall_not_behavior:.4f}",
                f"{fold.recall_behavior:.4f}",
                f"{fold.f1_behavior:.4f}",
                f"{_escape_markdown(fold.test_video)} [{fold.test_identity}]",
            ]
            for fold in folds
        ]
        fold_headers = [
            "Iter",
            "Accuracy",
            "Precision (Not Behavior)",
            "Precision (Behavior)",
            "Recall (Not Behavior)",
            "Recall (Behavior)",
            "F1 Score",
            "Test Group (Video [Identity])",
        ]
        report += [tabulate(fold_rows, headers=fold_headers, tablefmt="github"), ""]

    # Feature importance of the final model, ranked from 1 in the given order
    ranked = [
        [position, _escape_markdown(name), f"{score:.2f}"]
        for position, (name, score) in enumerate(data.final_top_features, start=1)
    ]
    report += [
        "## Feature Importance",
        "",
        "Top 20 features from final model (trained on all labeled data):",
        "",
        tabulate(ranked, headers=["Rank", "Feature Name", "Importance"], tablefmt="github"),
        "",
    ]

    return "\n".join(report)


def save_training_report(data: TrainingReportData, output_path: Path) -> None:
    """Render the training report and write it to disk as UTF-8 markdown.

    The report text is generated before the file is opened for writing.

    Args:
        data: TrainingReportData object containing all training information
        output_path: Path where the markdown file should be saved; an existing
            file at that path is overwritten
    """
    report_text = generate_markdown_report(data)
    Path(output_path).write_text(report_text, encoding="utf-8")
7 changes: 7 additions & 0 deletions src/jabs/project/project_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, base_path: Path, use_cache: bool = True):
self._archive_dir = self._jabs_dir / "archive"
self._session_dir = self._jabs_dir / "session"
self._cache_dir = self._jabs_dir / "cache" if use_cache else None
self._training_log_dir = self._jabs_dir / "training_logs"

self._project_file = self._jabs_dir / self.__PROJECT_FILE

Expand Down Expand Up @@ -71,6 +72,11 @@ def session_dir(self) -> Path:
"""Get the path to the session directory."""
return self._session_dir

@property
def training_log_dir(self) -> Path:
"""Get the path to the training logs directory."""
return self._training_log_dir

def create_directories(self, validate: bool = True) -> None:
"""Create all necessary directories for the project.

Expand Down Expand Up @@ -99,6 +105,7 @@ def create_directories(self, validate: bool = True) -> None:
self._classifier_dir.mkdir(parents=True, exist_ok=True)
self._archive_dir.mkdir(parents=True, exist_ok=True)
self._session_dir.mkdir(parents=True, exist_ok=True)
self._training_log_dir.mkdir(parents=True, exist_ok=True)

if self._cache_dir:
self._cache_dir.mkdir(parents=True, exist_ok=True)
3 changes: 0 additions & 3 deletions src/jabs/project/session_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def classifier_trained(
k: int,
accuracy: float | None = None,
fbeta_behavior: float | None = None,
fbeta_notbehavior: float | None = None,
):
"""Log the training of a classifier."""
if not self._tracking_enabled or not self._session:
Expand All @@ -299,8 +298,6 @@ def classifier_trained(
activity["mean accuracy"] = f"{accuracy:.3}"
if fbeta_behavior is not None:
activity["mean fbeta (behavior)"] = f"{fbeta_behavior:.3}"
if fbeta_notbehavior is not None:
activity["mean fbeta (not behavior)"] = f"{fbeta_notbehavior:.3}"

self._session["activity_log"].append(activity)
self._flush_session()
Expand Down
23 changes: 22 additions & 1 deletion src/jabs/resources/docs/user_guide/gui.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

<img src="imgs/classifier_controls.png" alt="JABS Classifier Controls" width=900 />

- **Train Button:** Train the classifier with the current parameters. This button is disabled until minimum number of frames have been labeled for a minimum number of mice (increasing the cross validation k parameter increases the minimum number of labeled mice)
- **Train Button:** Train the classifier with the current parameters. This button is disabled until a minimum number of frames have been labeled for a minimum number of mice (increasing the cross-validation k parameter increases the minimum number of labeled mice). When training completes, a training report dialog displays performance metrics, including cross-validation results and feature importance rankings.
- **Classify Button:** Infer class of unlabeled frames. Disabled until classifier is trained. Changing classifier parameters may require retraining before the Classify button becomes active again.
- **Classifier Type Selection:** Users can select from a list of supported classifiers.
- **Window Size Selection:** Number of frames on each side of the current frame to include in window feature calculations for that frame. A "window size" of 5 means that 11 frames are included into the window feature calculations for each frame (5 previous frames, current frame, 5 following frames).
Expand Down Expand Up @@ -113,6 +113,22 @@ XGBoost is another gradient boosting algorithm known for winning machine learnin

**Note:** The actual performance difference between classifiers varies by behavior type and dataset. We recommend testing multiple classifiers on your specific data to find the best option for your use case.

### Training Reports

When training completes, JABS displays a training report in a modal dialog. The report includes:

- **Training summary** - Behavior name, classifier type, window size, label-balancing and symmetric-behavior settings, distance unit, and training time
- **Label counts** - Number of labeled frames and bouts for both behavior and not-behavior classes
- **Cross-validation results** - Performance metrics (accuracy, precision, recall, F1 score) for each leave-one-out iteration, along with which video/identity was held out as the test set
- **Feature importance** - Top 20 most important features from the final trained classifier

The training report is also saved as a Markdown file in the `jabs/training_logs` directory within your project. The filename includes the behavior name and timestamp (e.g., `Grooming_20260102_143022_training_report.md`). These reports provide a permanent record of your training sessions and can be useful for:

- Comparing different classifier configurations
- Identifying problematic videos or identities with poor cross-validation performance
- Understanding which features contribute most to behavior classification
- Documenting your analysis workflow

## Timeline Visualizations

<img src="imgs/label_viz.png" alt="JABS Label Visualizations" width=900 />
Expand Down Expand Up @@ -174,6 +190,11 @@ Clicking the Brightness or Contrast controls will reset the brightness or contra
- **Features→Enable Lixit Features:** toggle using lixit features (v5+ projects with lixit static object)
- **Features→Enable Food_hopper Features:** toggle using food hopper features (v5+ projects with food hopper static object)
- **Features→Enable Segmentation Features:** toggle using segmentation features (v6+ projects)
- **Window:** Menu for managing JABS windows.
- **Window→Minimize:** Minimize the main window (⌘M on macOS, Ctrl+M on other platforms)
- **Window→Zoom:** Toggle between normal and maximized window state
- **Window→Bring All to Front:** Bring all JABS windows to the front
- The Window menu also displays a list of all open JABS windows (main window, user guide, training reports, etc.) with a checkmark (✓) next to the currently active window. Click any window in the list to activate and bring it to the front.

## Overlays

Expand Down
Loading