Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/jabs/classifier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,19 @@
"""

from .classifier import Classifier
from .training_report import (
CrossValidationResult,
TrainingReportData,
generate_markdown_report,
markdown_to_html,
save_training_report,
)

__all__ = [
"Classifier",
"CrossValidationResult",
"TrainingReportData",
"generate_markdown_report",
"markdown_to_html",
"save_training_report",
]
21 changes: 16 additions & 5 deletions src/jabs/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,22 +604,33 @@ def combine_data(per_frame, window):
"""
return pd.concat([per_frame, window], axis=1)

def print_feature_importance(self, feature_list, limit=20):
"""print the most important features and their importance
def get_feature_importance(self, limit=20) -> list[tuple[str, float]]:
"""get the most important features and their importance

Args:
feature_list: list of feature names used in the classifier
limit: maximum number of features to print, defaults to 20
limit: maximum number of features to return, defaults to 20

Returns:
list of tuples of feature name and importance
"""
# Get numerical feature importance
importances = list(self._classifier.feature_importances_)
# List of tuples with variable and importance
feature_importance = [
(feature, round(importance, 2))
for feature, importance in zip(feature_list, importances, strict=True)
for feature, importance in zip(self._feature_names, importances, strict=True)
]
# Sort the feature importance by most important first
feature_importance = sorted(feature_importance, key=lambda x: x[1], reverse=True)
return feature_importance[:limit]

def print_feature_importance(self, limit=20):
"""print the most important features and their importance

Args:
limit: maximum number of features to print, defaults to 20
"""
feature_importance = self.get_feature_importance(limit=limit)
# Print out the feature and importance
print(f"{'Feature Name':100} Importance")
print("-" * 120)
Expand Down
287 changes: 287 additions & 0 deletions src/jabs/classifier/training_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""Training report generation for classifier cross-validation results."""

from dataclasses import dataclass, field
from pathlib import Path
from textwrap import dedent

import markdown2
import numpy as np
from tabulate import tabulate


@dataclass
class CrossValidationResult:
    """Results from a single cross-validation iteration.

    Equality is defined explicitly because ``confusion_matrix`` is a numpy
    array: the field-by-field ``==`` that ``dataclass`` would generate
    evaluates the truth value of an element-wise array comparison, which
    raises ``ValueError`` for arrays with more than one element.

    Attributes:
        iteration: The iteration number (1-indexed)
        test_video: Name of the video used for testing
        test_identity: Identity label used for testing
        accuracy: Classification accuracy (0.0 to 1.0)
        precision_behavior: Precision for behavior class
        precision_not_behavior: Precision for not-behavior class
        recall_behavior: Recall for behavior class
        recall_not_behavior: Recall for not-behavior class
        f1_behavior: F1 score for behavior class
        support_behavior: Number of behavior frames in test set
        support_not_behavior: Number of not-behavior frames in test set
        confusion_matrix: 2x2 confusion matrix
        top_features: List of (feature_name, importance) tuples for this iteration
    """

    iteration: int
    test_video: str
    test_identity: str
    accuracy: float
    precision_behavior: float
    precision_not_behavior: float
    recall_behavior: float
    recall_not_behavior: float
    f1_behavior: float
    support_behavior: int
    support_not_behavior: int
    confusion_matrix: np.ndarray
    top_features: list[tuple[str, float]] = field(default_factory=list)

    def __eq__(self, other: object) -> bool:
        """Compare all fields, using ``np.array_equal`` for the confusion matrix.

        Returns:
            True when ``other`` is of the same class and every field matches;
            ``NotImplemented`` when ``other`` is of a different class.
        """
        if other.__class__ is not self.__class__:
            return NotImplemented
        for name in self.__dataclass_fields__:
            mine = getattr(self, name)
            theirs = getattr(other, name)
            if name == "confusion_matrix":
                # array == array is element-wise; array_equal reduces it to one bool
                if not np.array_equal(mine, theirs):
                    return False
            elif mine != theirs:
                return False
        return True


@dataclass
class TrainingReportData:
    """Everything needed to render a report for one classifier training run.

    Attributes:
        behavior_name: Name of the behavior the classifier was trained for
        classifier_type: Type/name of the classifier (e.g., "Random Forest")
        distance_unit: Unit used for distance features ("cm" or "pixel")
        cv_results: One CrossValidationResult per cross-validation iteration;
            an empty list means no cross-validation was run
        final_top_features: (feature_name, importance) tuples from the final
            model (trained on all data)
        frames_behavior: Total number of frames labeled as behavior
        frames_not_behavior: Total number of frames labeled as not behavior
        bouts_behavior: Total number of behavior bouts labeled
        bouts_not_behavior: Total number of not-behavior bouts labeled
        training_time_ms: Total training time in milliseconds
    """

    # What was trained
    behavior_name: str
    classifier_type: str
    distance_unit: str

    # How it performed
    cv_results: list[CrossValidationResult]
    final_top_features: list[tuple[str, float]]

    # How much labeled data went in
    frames_behavior: int
    frames_not_behavior: int
    bouts_behavior: int
    bouts_not_behavior: int

    # How long it took
    training_time_ms: int


def _escape_markdown(text: str) -> str:
"""Escape markdown special characters in text.

Args:
text: Text that may contain markdown special characters

Returns:
Text with markdown special characters escaped
"""
# Escape common markdown characters that might appear in filenames
# Most important: _ (underscore) which creates italics
# Also escape: * (asterisk), [ ] (brackets), ( ) (parentheses)
chars_to_escape = ["_", "*", "[", "]", "(", ")", "`", "#"]
for char in chars_to_escape:
text = text.replace(char, f"\\{char}")
return text


def generate_markdown_report(data: TrainingReportData) -> str:
    """Generate a markdown-formatted training report.

    The report contains, in order: the title and training summary with label
    counts, the cross-validation results (or a note that none were run), and
    the feature-importance table of the final model.

    Args:
        data: TrainingReportData object containing all training information

    Returns:
        Markdown-formatted string
    """
    lines = [
        *_summary_lines(data),
        *_cross_validation_lines(data.cv_results),
        *_feature_importance_lines(data.final_top_features),
    ]
    return "\n".join(lines)


def _summary_lines(data: TrainingReportData) -> list[str]:
    """Build the title, training summary and label-count lines."""
    # The behavior name is free text; escape it like video and feature names
    behavior = _escape_markdown(data.behavior_name)
    return [
        f"# Training Report: {behavior}",
        "",
        "## Training Summary",
        "",
        f"- **Behavior:** {behavior}",
        f"- **Classifier:** {data.classifier_type}",
        f"- **Distance Unit:** {data.distance_unit}",
        f"- **Training Time:** {data.training_time_ms / 1000:.2f} seconds",
        "",
        "### Label Counts",
        "",
        f"- **Behavior frames:** {data.frames_behavior:,}",
        f"- **Not-behavior frames:** {data.frames_not_behavior:,}",
        f"- **Behavior bouts:** {data.bouts_behavior:,}",
        f"- **Not-behavior bouts:** {data.bouts_not_behavior:,}",
        "",
    ]


def _cross_validation_lines(cv_results: list[CrossValidationResult]) -> list[str]:
    """Build the cross-validation section, or a placeholder when there are no results."""
    if not cv_results:
        # No cross-validation was performed
        return [
            "## Cross-Validation",
            "",
            "*No cross-validation was performed for this training.*",
            "",
        ]

    accuracies = [result.accuracy for result in cv_results]
    f1_behavior = [result.f1_behavior for result in cv_results]

    headers = [
        "Iter",
        "Accuracy",
        "Precision (Not Behavior)",
        "Precision (Behavior)",
        "Recall (Not Behavior)",
        "Recall (Behavior)",
        "F1 Score",
        "Test Group (Video [Identity])",
    ]
    rows = [
        [
            result.iteration,
            f"{result.accuracy:.4f}",
            f"{result.precision_not_behavior:.4f}",
            f"{result.precision_behavior:.4f}",
            f"{result.recall_not_behavior:.4f}",
            f"{result.recall_behavior:.4f}",
            f"{result.f1_behavior:.4f}",
            # str(): tolerate a non-string identity, as plain f-string formatting would
            f"{_escape_markdown(result.test_video)} "
            f"[{_escape_markdown(str(result.test_identity))}]",
        ]
        for result in cv_results
    ]
    # disable_numparse: the cells are already formatted; without it tabulate
    # re-parses numeric-looking strings and drops trailing zeros ("0.9000" -> "0.9")
    table_markdown = tabulate(rows, headers=headers, tablefmt="github", disable_numparse=True)

    return [
        "## Cross-Validation Results",
        "",
        "### Performance Summary",
        "",
        f"- **Mean Accuracy:** {np.mean(accuracies):.4f} (± {np.std(accuracies):.4f})",
        f"- **Mean F1 Score (Behavior):** {np.mean(f1_behavior):.4f} (± {np.std(f1_behavior):.4f})",
        "",
        "### Iteration Details",
        "",
        table_markdown,
        "",
    ]


def _feature_importance_lines(top_features: list[tuple[str, float]]) -> list[str]:
    """Build the feature-importance section for the final model."""
    rows = [
        [rank, _escape_markdown(feature_name), f"{importance:.2f}"]
        for rank, (feature_name, importance) in enumerate(top_features, start=1)
    ]
    # disable_numparse keeps the two-decimal formatting ("0.10" stays "0.10")
    feature_markdown = tabulate(
        rows,
        headers=["Rank", "Feature Name", "Importance"],
        tablefmt="github",
        disable_numparse=True,
    )
    return [
        "## Feature Importance",
        "",
        # Report the number of features actually listed rather than assuming 20
        f"Top {len(top_features)} features from final model (trained on all labeled data):",
        "",
        feature_markdown,
        "",
    ]


# Page skeleton for rendered reports. It is dedented once, before any content
# is substituted: dedenting after substitution would make the common margin
# depend on the inserted HTML (either nothing is stripped, or whitespace is
# stripped from the content itself). Literal CSS braces are doubled because
# the template is filled in with str.format.
_HTML_PAGE_TEMPLATE = dedent(
    """\
    <!DOCTYPE html>
    <html>
    <head>
    <meta charset="utf-8">
    <style>
    body {{
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif;
        line-height: 1.6;
        max-width: 1200px;
        margin: 20px auto;
        padding: 0 20px;
        color: #333;
    }}
    h1 {{
        border-bottom: 2px solid #333;
        padding-bottom: 10px;
    }}
    h2 {{
        border-bottom: 1px solid #ccc;
        padding-bottom: 8px;
        margin-top: 30px;
    }}
    h3 {{
        margin-top: 20px;
    }}
    table {{
        border-collapse: collapse;
        width: 100%;
        margin: 20px 0;
    }}
    th, td {{
        border: 1px solid #ddd;
        padding: 8px;
        text-align: left;
    }}
    th {{
        background-color: #f2f2f2;
        font-weight: bold;
    }}
    tr:nth-child(even) {{
        background-color: #f9f9f9;
    }}
    code {{
        background-color: #f4f4f4;
        padding: 2px 4px;
        border-radius: 3px;
    }}
    ul {{
        line-height: 1.8;
    }}
    </style>
    </head>
    <body>
    {html_content}
    </body>
    </html>
    """
).strip()


def markdown_to_html(markdown_text: str) -> str:
    """Convert markdown text to HTML.

    The markdown is rendered with markdown2 (tables and fenced code blocks
    enabled) and wrapped in a complete HTML document with basic styling.

    Args:
        markdown_text: Markdown-formatted string

    Returns:
        HTML string with basic styling
    """
    html_content = markdown2.markdown(markdown_text, extras=["tables", "fenced-code-blocks"])
    # The rendered HTML is inserted verbatim; str.format does not re-parse it
    return _HTML_PAGE_TEMPLATE.format(html_content=html_content)


def save_training_report(data: TrainingReportData, output_path: Path) -> None:
    """Render the training report for ``data`` and write it out as markdown.

    Args:
        data: TrainingReportData object containing all training information
        output_path: Path of the markdown file to write; the file is written
            as UTF-8 and any existing content is replaced
    """
    report_text = generate_markdown_report(data)
    Path(output_path).write_text(report_text, encoding="utf-8")
7 changes: 7 additions & 0 deletions src/jabs/project/project_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(self, base_path: Path, use_cache: bool = True):
self._archive_dir = self._jabs_dir / "archive"
self._session_dir = self._jabs_dir / "session"
self._cache_dir = self._jabs_dir / "cache" if use_cache else None
self._training_log_dir = self._jabs_dir / "training_logs"

self._project_file = self._jabs_dir / self.__PROJECT_FILE

Expand Down Expand Up @@ -71,6 +72,11 @@ def session_dir(self) -> Path:
"""Get the path to the session directory."""
return self._session_dir

@property
def training_log_dir(self) -> Path:
    """Path of the directory where training logs are stored."""
    log_dir = self._training_log_dir
    return log_dir

def create_directories(self, validate: bool = True) -> None:
"""Create all necessary directories for the project.

Expand Down Expand Up @@ -99,6 +105,7 @@ def create_directories(self, validate: bool = True) -> None:
self._classifier_dir.mkdir(parents=True, exist_ok=True)
self._archive_dir.mkdir(parents=True, exist_ok=True)
self._session_dir.mkdir(parents=True, exist_ok=True)
self._training_log_dir.mkdir(parents=True, exist_ok=True)

if self._cache_dir:
self._cache_dir.mkdir(parents=True, exist_ok=True)
3 changes: 0 additions & 3 deletions src/jabs/project/session_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ def classifier_trained(
k: int,
accuracy: float | None = None,
fbeta_behavior: float | None = None,
fbeta_notbehavior: float | None = None,
):
"""Log the training of a classifier."""
if not self._tracking_enabled or not self._session:
Expand All @@ -299,8 +298,6 @@ def classifier_trained(
activity["mean accuracy"] = f"{accuracy:.3}"
if fbeta_behavior is not None:
activity["mean fbeta (behavior)"] = f"{fbeta_behavior:.3}"
if fbeta_notbehavior is not None:
activity["mean fbeta (not behavior)"] = f"{fbeta_notbehavior:.3}"

self._session["activity_log"].append(activity)
self._flush_session()
Expand Down
Loading