Skip to content

Commit 2970ba5

Browse files
authored
Merge pull request #260 from KumarLabJax/add-training-report-dialog
implement training report
2 parents 0a40b0f + 8274180 commit 2970ba5

File tree

17 files changed

+1347
-100
lines changed

17 files changed

+1347
-100
lines changed

src/jabs/classifier/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,17 @@
66
"""
77

88
from .classifier import Classifier
9+
from .training_report import (
10+
CrossValidationResult,
11+
TrainingReportData,
12+
generate_markdown_report,
13+
save_training_report,
14+
)
915

1016
__all__ = [
1117
"Classifier",
18+
"CrossValidationResult",
19+
"TrainingReportData",
20+
"generate_markdown_report",
21+
"save_training_report",
1222
]

src/jabs/classifier/classifier.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -604,22 +604,33 @@ def combine_data(per_frame, window):
604604
"""
605605
return pd.concat([per_frame, window], axis=1)
606606

607-
def print_feature_importance(self, feature_list, limit=20):
608-
"""print the most important features and their importance
607+
def get_feature_importance(self, limit=20) -> list[tuple[str, float]]:
608+
"""get the most important features and their importance
609609
610610
Args:
611-
feature_list: list of feature names used in the classifier
612-
limit: maximum number of features to print, defaults to 20
611+
limit: maximum number of features to return, defaults to 20
612+
613+
Returns:
614+
list of tuples of feature name and importance
613615
"""
614616
# Get numerical feature importance
615617
importances = list(self._classifier.feature_importances_)
616618
# List of tuples with variable and importance
617619
feature_importance = [
618620
(feature, round(importance, 2))
619-
for feature, importance in zip(feature_list, importances, strict=True)
621+
for feature, importance in zip(self._feature_names, importances, strict=True)
620622
]
621623
# Sort the feature importance by most important first
622624
feature_importance = sorted(feature_importance, key=lambda x: x[1], reverse=True)
625+
return feature_importance[:limit]
626+
627+
def print_feature_importance(self, limit=20):
628+
"""print the most important features and their importance
629+
630+
Args:
631+
limit: maximum number of features to print, defaults to 20
632+
"""
633+
feature_importance = self.get_feature_importance(limit=limit)
623634
# Print out the feature and importance
624635
print(f"{'Feature Name':100} Importance")
625636
print("-" * 120)
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""Training report generation for classifier cross-validation results."""
2+
3+
from dataclasses import dataclass, field
4+
from datetime import datetime
5+
from pathlib import Path
6+
7+
import numpy as np
8+
from tabulate import tabulate
9+
10+
11+
@dataclass
12+
class CrossValidationResult:
13+
"""Results from a single cross-validation iteration.
14+
15+
Attributes:
16+
iteration: The iteration number (1-indexed)
17+
test_video: Name of the video used for testing
18+
test_identity: Identity label used for testing
19+
accuracy: Classification accuracy (0.0 to 1.0)
20+
precision_behavior: Precision for behavior class
21+
precision_not_behavior: Precision for not-behavior class
22+
recall_behavior: Recall for behavior class
23+
recall_not_behavior: Recall for not-behavior class
24+
f1_behavior: F1 score for behavior class
25+
support_behavior: Number of behavior frames in test set
26+
support_not_behavior: Number of not-behavior frames in test set
27+
confusion_matrix: 2x2 confusion matrix
28+
top_features: List of (feature_name, importance) tuples for this iteration
29+
"""
30+
31+
iteration: int
32+
test_video: str
33+
test_identity: str
34+
accuracy: float
35+
precision_behavior: float
36+
precision_not_behavior: float
37+
recall_behavior: float
38+
recall_not_behavior: float
39+
f1_behavior: float
40+
support_behavior: int
41+
support_not_behavior: int
42+
confusion_matrix: np.ndarray
43+
top_features: list[tuple[str, float]] = field(default_factory=list)
44+
45+
46+
@dataclass
47+
class TrainingReportData:
48+
"""Complete training information for generating a report.
49+
50+
Attributes:
51+
behavior_name: Name of the behavior being trained
52+
classifier_type: Type/name of the classifier (e.g., "Random Forest")
53+
window_size: Window size used for feature extraction
54+
balance_training_labels: Whether training labels were balanced
55+
symmetric_behavior: Whether the behavior is symmetric
56+
distance_unit: Unit used for distance features ("cm" or "pixel")
57+
cv_results: List of CrossValidationResult objects, one per iteration
58+
final_top_features: Top features from final model (trained on all data)
59+
frames_behavior: Total number of frames labeled as behavior
60+
frames_not_behavior: Total number of frames labeled as not behavior
61+
bouts_behavior: Total number of behavior bouts labeled
62+
bouts_not_behavior: Total number of not-behavior bouts labeled
63+
training_time_ms: Total training time in milliseconds
64+
timestamp: Datetime when training was completed
65+
"""
66+
67+
behavior_name: str
68+
classifier_type: str
69+
window_size: int
70+
balance_training_labels: bool
71+
symmetric_behavior: bool
72+
distance_unit: str
73+
cv_results: list[CrossValidationResult]
74+
final_top_features: list[tuple[str, float]]
75+
frames_behavior: int
76+
frames_not_behavior: int
77+
bouts_behavior: int
78+
bouts_not_behavior: int
79+
training_time_ms: int
80+
timestamp: datetime
81+
82+
83+
def _escape_markdown(text: str) -> str:
84+
"""Escape markdown special characters in text.
85+
86+
Args:
87+
text: Text that may contain markdown special characters
88+
89+
Returns:
90+
Text with markdown special characters escaped
91+
"""
92+
# Escape common markdown characters that might appear in filenames
93+
# Most important: _ (underscore) which creates italics
94+
# Also escape: * (asterisk), [ ] (brackets), ( ) (parentheses)
95+
chars_to_escape = ["_", "*", "[", "]", "(", ")", "`", "#"]
96+
for char in chars_to_escape:
97+
text = text.replace(char, f"\\{char}")
98+
return text
99+
100+
101+
def generate_markdown_report(data: TrainingReportData) -> str:
102+
"""Generate a markdown-formatted training report.
103+
104+
Args:
105+
data: TrainingData object containing all training information
106+
107+
Returns:
108+
Markdown-formatted string
109+
"""
110+
lines = []
111+
112+
lines.append(f"# Training Report: {data.behavior_name}")
113+
lines.append("")
114+
lines.append(f"**Date:** {data.timestamp.strftime('%B %d, %Y at %I:%M:%S %p')}")
115+
lines.append("")
116+
117+
lines.append("## Training Summary")
118+
lines.append("")
119+
lines.append(f"- **Behavior:** {data.behavior_name}")
120+
lines.append(f"- **Classifier:** {data.classifier_type}")
121+
lines.append(f"- **Window Size:** {data.window_size}")
122+
lines.append(
123+
f"- **Balanced Training Labels:** {'Yes' if data.balance_training_labels else 'No'}"
124+
)
125+
lines.append(f"- **Symmetric Behavior:** {'Yes' if data.symmetric_behavior else 'No'}")
126+
lines.append(f"- **Distance Unit:** {data.distance_unit}")
127+
lines.append(f"- **Training Time:** {data.training_time_ms / 1000:.2f} seconds")
128+
lines.append("")
129+
130+
lines.append("### Label Counts")
131+
lines.append("")
132+
lines.append(f"- **Behavior frames:** {data.frames_behavior:,}")
133+
lines.append(f"- **Not-behavior frames:** {data.frames_not_behavior:,}")
134+
lines.append(f"- **Behavior bouts:** {data.bouts_behavior:,}")
135+
lines.append(f"- **Not-behavior bouts:** {data.bouts_not_behavior:,}")
136+
lines.append("")
137+
138+
# Cross-validation results
139+
if data.cv_results:
140+
lines.append("## Cross-Validation Results")
141+
lines.append("")
142+
143+
# Summary statistics
144+
accuracies = [r.accuracy for r in data.cv_results]
145+
f1_behavior = [r.f1_behavior for r in data.cv_results]
146+
147+
lines.append("### Performance Summary")
148+
lines.append("")
149+
lines.append(
150+
f"- **Mean Accuracy:** {np.mean(accuracies):.4f}{np.std(accuracies):.4f})"
151+
)
152+
lines.append(
153+
f"- **Mean F1 Score (Behavior):** {np.mean(f1_behavior):.4f}{np.std(f1_behavior):.4f})"
154+
)
155+
lines.append("")
156+
157+
# Detailed results table
158+
lines.append("### Iteration Details")
159+
lines.append("")
160+
161+
table_data = []
162+
for result in data.cv_results:
163+
# Escape markdown special characters in video
164+
escaped_video = _escape_markdown(result.test_video)
165+
166+
table_data.append(
167+
[
168+
result.iteration,
169+
f"{result.accuracy:.4f}",
170+
f"{result.precision_not_behavior:.4f}",
171+
f"{result.precision_behavior:.4f}",
172+
f"{result.recall_not_behavior:.4f}",
173+
f"{result.recall_behavior:.4f}",
174+
f"{result.f1_behavior:.4f}",
175+
f"{escaped_video} [{result.test_identity}]",
176+
]
177+
)
178+
179+
headers = [
180+
"Iter",
181+
"Accuracy",
182+
"Precision (Not Behavior)",
183+
"Precision (Behavior)",
184+
"Recall (Not Behavior)",
185+
"Recall (Behavior)",
186+
"F1 Score",
187+
"Test Group (Video [Identity])",
188+
]
189+
190+
table_markdown = tabulate(table_data, headers=headers, tablefmt="github")
191+
lines.append(table_markdown)
192+
lines.append("")
193+
else:
194+
# No cross-validation was performed
195+
lines.append("## Cross-Validation")
196+
lines.append("")
197+
lines.append("*No cross-validation was performed for this training.*")
198+
lines.append("")
199+
200+
# Final model feature importance
201+
lines.append("## Feature Importance")
202+
lines.append("")
203+
lines.append("Top 20 features from final model (trained on all labeled data):")
204+
lines.append("")
205+
206+
feature_table = []
207+
for rank, (feature_name, importance) in enumerate(data.final_top_features, start=1):
208+
feature_table.append([rank, _escape_markdown(feature_name), f"{importance:.2f}"])
209+
210+
feature_markdown = tabulate(
211+
feature_table, headers=["Rank", "Feature Name", "Importance"], tablefmt="github"
212+
)
213+
lines.append(feature_markdown)
214+
lines.append("")
215+
216+
return "\n".join(lines)
217+
218+
219+
def save_training_report(data: TrainingReportData, output_path: Path) -> None:
220+
"""Generate and save a training report as markdown.
221+
222+
Args:
223+
data: TrainingData object containing all training information
224+
output_path: Path where the markdown file should be saved
225+
"""
226+
markdown_content = generate_markdown_report(data)
227+
with open(output_path, "w", encoding="utf-8") as f:
228+
f.write(markdown_content)

src/jabs/project/project_paths.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, base_path: Path, use_cache: bool = True):
1818
self._archive_dir = self._jabs_dir / "archive"
1919
self._session_dir = self._jabs_dir / "session"
2020
self._cache_dir = self._jabs_dir / "cache" if use_cache else None
21+
self._training_log_dir = self._jabs_dir / "training_logs"
2122

2223
self._project_file = self._jabs_dir / self.__PROJECT_FILE
2324

@@ -71,6 +72,11 @@ def session_dir(self) -> Path:
7172
"""Get the path to the session directory."""
7273
return self._session_dir
7374

75+
@property
76+
def training_log_dir(self) -> Path:
77+
"""Get the path to the training logs directory."""
78+
return self._training_log_dir
79+
7480
def create_directories(self, validate: bool = True) -> None:
7581
"""Create all necessary directories for the project.
7682
@@ -99,6 +105,7 @@ def create_directories(self, validate: bool = True) -> None:
99105
self._classifier_dir.mkdir(parents=True, exist_ok=True)
100106
self._archive_dir.mkdir(parents=True, exist_ok=True)
101107
self._session_dir.mkdir(parents=True, exist_ok=True)
108+
self._training_log_dir.mkdir(parents=True, exist_ok=True)
102109

103110
if self._cache_dir:
104111
self._cache_dir.mkdir(parents=True, exist_ok=True)

src/jabs/project/session_tracker.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ def classifier_trained(
281281
k: int,
282282
accuracy: float | None = None,
283283
fbeta_behavior: float | None = None,
284-
fbeta_notbehavior: float | None = None,
285284
):
286285
"""Log the training of a classifier."""
287286
if not self._tracking_enabled or not self._session:
@@ -299,8 +298,6 @@ def classifier_trained(
299298
activity["mean accuracy"] = f"{accuracy:.3}"
300299
if fbeta_behavior is not None:
301300
activity["mean fbeta (behavior)"] = f"{fbeta_behavior:.3}"
302-
if fbeta_notbehavior is not None:
303-
activity["mean fbeta (not behavior)"] = f"{fbeta_notbehavior:.3}"
304301

305302
self._session["activity_log"].append(activity)
306303
self._flush_session()

src/jabs/resources/docs/user_guide/gui.md

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
<img src="imgs/classifier_controls.png" alt="JABS Classifier Controls" width=900 />
2424

25-
- **Train Button:** Train the classifier with the current parameters. This button is disabled until minimum number of frames have been labeled for a minimum number of mice (increasing the cross validation k parameter increases the minimum number of labeled mice)
25+
- **Train Button:** Train the classifier with the current parameters. This button is disabled until minimum number of frames have been labeled for a minimum number of mice (increasing the cross validation k parameter increases the minimum number of labeled mice). When training completes, a training report dialog will display performance metrics including cross-validation results and feature importance rankings.
2626
- **Classify Button:** Infer class of unlabeled frames. Disabled until classifier is trained. Changing classifier parameters may require retraining before the Classify button becomes active again.
2727
- **Classifier Type Selection:** Users can select from a list of supported classifiers.
2828
- **Window Size Selection:** Number of frames on each side of the current frame to include in window feature calculations for that frame. A "window size" of 5 means that 11 frames are included into the window feature calculations for each frame (5 previous frames, current frame, 5 following frames).
@@ -113,6 +113,22 @@ XGBoost is another gradient boosting algorithm known for winning machine learnin
113113

114114
**Note:** The actual performance difference between classifiers varies by behavior type and dataset. We recommend testing multiple classifiers on your specific data to find the best option for your use case.
115115

116+
### Training Reports
117+
118+
When training completes, JABS displays a training report in a modal dialog. The report includes:
119+
120+
- **Training summary** - behavior name, classifier type, distance unit, and training time
121+
- **Label counts** - Number of labeled frames and bouts for both behavior and not-behavior classes
122+
- **Cross-validation results** - Performance metrics (accuracy, precision, recall, F1 score) for each leave-one-out iteration, along with which video/identity was held out as the test set
123+
- **Feature importance** - Top 20 most important features from the final trained classifier
124+
125+
The training report is also saved as a Markdown file in the `jabs/training_logs` directory within your project. The filename includes the behavior name and timestamp (e.g., `Grooming_20260102_143022_training_report.md`). These reports provide a permanent record of your training sessions and can be useful for:
126+
127+
- Comparing different classifier configurations
128+
- Identifying problematic videos or identities with poor cross-validation performance
129+
- Understanding which features contribute most to behavior classification
130+
- Documenting your analysis workflow
131+
116132
## Timeline Visualizations
117133

118134
<img src="imgs/label_viz.png" alt="JABS Label Visualizations" width=900 />
@@ -174,6 +190,11 @@ Clicking the Brightness or Contrast controls will reset the brightness or contra
174190
- **Features→Enable Lixit Features:** toggle using lixit features (v5+ projects with lixit static object)
175191
- **Features→Enable Food_hopper Features:** toggle using food hopper features (v5+ projects with food hopper static object)
176192
- **Features→Enable Segmentation Features:** toggle using segmentation features (v6+ projects)
193+
- **Window:** Menu for managing JABS windows.
194+
- **Window→Minimize:** Minimize the main window (⌘M on macOS, Ctrl+M on other platforms)
195+
- **Window→Zoom:** Toggle between normal and maximized window state
196+
- **Window→Bring All to Front:** Bring all JABS windows to the front
197+
- The Window menu also displays a list of all open JABS windows (main window, user guide, training reports, etc.) with a checkmark (✓) next to the currently active window. Click any window in the list to activate and bring it to the front.
177198

178199
## Overlays
179200

0 commit comments

Comments
 (0)