-
Notifications
You must be signed in to change notification settings - Fork 4
implement training report #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 1 commit
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
d2984a0
implement training report
gbeane d63f738
remove unused variable
gbeane 79d05e1
escape characters in filenames and feature names in the training repo…
gbeane a97f9c5
allow using Cmd-q/Ctrl-q to quit JABS while training report dialog is…
gbeane 807041a
Update src/jabs/resources/docs/user_guide/gui.md
gbeane e23c034
add balance labels and symmetric behavior settings to training report
gbeane 0dd3814
make sure progress dialog is closed before training report is displayed
gbeane c73da0e
add date, window size, and other settings to training report. make di…
gbeane 388f315
move markdown to html rendering into TrainingReportDialog; add copy t…
gbeane b9956ce
improve how JABS handles multiple windows
gbeane 3012c6d
Update src/jabs/ui/user_guide_dialog.py
gbeane 8274180
make Random Forest default classifier for new projects
gbeane File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,266 @@ | ||
| """Training report generation for classifier cross-validation results.""" | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from pathlib import Path | ||
| from textwrap import dedent | ||
|
|
||
| import markdown2 | ||
| import numpy as np | ||
| from tabulate import tabulate | ||
|
|
||
|
|
||
| @dataclass | ||
| class CrossValidationResult: | ||
| """Results from a single cross-validation iteration. | ||
|
|
||
| Attributes: | ||
| iteration: The iteration number (1-indexed) | ||
| test_video: Name of the video used for testing | ||
| test_identity: Identity label used for testing | ||
| accuracy: Classification accuracy (0.0 to 1.0) | ||
| precision_behavior: Precision for behavior class | ||
| precision_not_behavior: Precision for not-behavior class | ||
| recall_behavior: Recall for behavior class | ||
| recall_not_behavior: Recall for not-behavior class | ||
| f1_behavior: F1 score for behavior class | ||
| support_behavior: Number of behavior frames in test set | ||
| support_not_behavior: Number of not-behavior frames in test set | ||
| confusion_matrix: 2x2 confusion matrix | ||
| top_features: List of (feature_name, importance) tuples for this iteration | ||
| """ | ||
|
|
||
| iteration: int | ||
| test_video: str | ||
| test_identity: str | ||
| accuracy: float | ||
| precision_behavior: float | ||
| precision_not_behavior: float | ||
| recall_behavior: float | ||
| recall_not_behavior: float | ||
| f1_behavior: float | ||
| support_behavior: int | ||
| support_not_behavior: int | ||
| confusion_matrix: np.ndarray | ||
| top_features: list[tuple[str, float]] = field(default_factory=list) | ||
|
|
||
|
|
||
| @dataclass | ||
| class TrainingReportData: | ||
| """Complete training information for generating a report. | ||
|
|
||
| Attributes: | ||
| behavior_name: Name of the behavior being trained | ||
| classifier_type: Type/name of the classifier (e.g., "Random Forest") | ||
| distance_unit: Unit used for distance features ("cm" or "pixel") | ||
| cv_results: List of CrossValidationResult objects, one per iteration | ||
| final_top_features: Top features from final model (trained on all data) | ||
| frames_behavior: Total number of frames labeled as behavior | ||
| frames_not_behavior: Total number of frames labeled as not behavior | ||
| bouts_behavior: Total number of behavior bouts labeled | ||
| bouts_not_behavior: Total number of not-behavior bouts labeled | ||
| training_time_ms: Total training time in milliseconds | ||
| """ | ||
|
|
||
| behavior_name: str | ||
| classifier_type: str | ||
| distance_unit: str | ||
| cv_results: list[CrossValidationResult] | ||
| final_top_features: list[tuple[str, float]] | ||
| frames_behavior: int | ||
| frames_not_behavior: int | ||
| bouts_behavior: int | ||
| bouts_not_behavior: int | ||
| training_time_ms: int | ||
|
|
||
|
|
||
| def generate_markdown_report(data: TrainingReportData) -> str: | ||
| """Generate a markdown-formatted training report. | ||
|
|
||
| Args: | ||
| data: TrainingData object containing all training information | ||
|
|
||
| Returns: | ||
| Markdown-formatted string | ||
| """ | ||
| lines = [] | ||
|
|
||
| lines.append(f"# Training Report: {data.behavior_name}") | ||
| lines.append("") | ||
|
|
||
| lines.append("## Training Summary") | ||
| lines.append("") | ||
| lines.append(f"- **Behavior:** {data.behavior_name}") | ||
| lines.append(f"- **Classifier:** {data.classifier_type}") | ||
| lines.append(f"- **Distance Unit:** {data.distance_unit}") | ||
| lines.append(f"- **Training Time:** {data.training_time_ms / 1000:.2f} seconds") | ||
| lines.append("") | ||
|
|
||
| lines.append("### Label Counts") | ||
| lines.append("") | ||
| lines.append(f"- **Behavior frames:** {data.frames_behavior:,}") | ||
| lines.append(f"- **Not-behavior frames:** {data.frames_not_behavior:,}") | ||
| lines.append(f"- **Behavior bouts:** {data.bouts_behavior:,}") | ||
| lines.append(f"- **Not-behavior bouts:** {data.bouts_not_behavior:,}") | ||
| lines.append("") | ||
|
|
||
| # Cross-validation results | ||
| if data.cv_results: | ||
| lines.append("## Cross-Validation Results") | ||
| lines.append("") | ||
|
|
||
| # Summary statistics | ||
| accuracies = [r.accuracy for r in data.cv_results] | ||
| f1_behavior = [r.f1_behavior for r in data.cv_results] | ||
|
|
||
| lines.append("### Performance Summary") | ||
| lines.append("") | ||
| lines.append( | ||
| f"- **Mean Accuracy:** {np.mean(accuracies):.4f} (± {np.std(accuracies):.4f})" | ||
| ) | ||
| lines.append( | ||
| f"- **Mean F1 Score (Behavior):** {np.mean(f1_behavior):.4f} (± {np.std(f1_behavior):.4f})" | ||
| ) | ||
| lines.append("") | ||
|
|
||
| # Detailed results table | ||
| lines.append("### Iteration Details") | ||
| lines.append("") | ||
|
|
||
| table_data = [] | ||
| for result in data.cv_results: | ||
| table_data.append( | ||
| [ | ||
| result.iteration, | ||
| f"{result.accuracy:.4f}", | ||
| f"{result.precision_not_behavior:.4f}", | ||
| f"{result.precision_behavior:.4f}", | ||
| f"{result.recall_not_behavior:.4f}", | ||
| f"{result.recall_behavior:.4f}", | ||
| f"{result.f1_behavior:.4f}", | ||
| f"{result.test_video} [{result.test_identity}]", | ||
| ] | ||
| ) | ||
|
|
||
| headers = [ | ||
| "Iter", | ||
| "Accuracy", | ||
| "Precision (Not Behavior)", | ||
| "Precision (Behavior)", | ||
| "Recall (Not Behavior)", | ||
| "Recall (Behavior)", | ||
| "F1 Score", | ||
| "Test Group (Video [Identity])", | ||
| ] | ||
|
|
||
| table_markdown = tabulate(table_data, headers=headers, tablefmt="github") | ||
| lines.append(table_markdown) | ||
| lines.append("") | ||
| else: | ||
| # No cross-validation was performed | ||
| lines.append("## Cross-Validation") | ||
| lines.append("") | ||
| lines.append("*No cross-validation was performed for this training.*") | ||
| lines.append("") | ||
|
|
||
| # Final model feature importance | ||
| lines.append("## Feature Importance") | ||
| lines.append("") | ||
| lines.append("Top 20 features from final model (trained on all labeled data):") | ||
| lines.append("") | ||
|
|
||
| feature_table = [] | ||
| for rank, (feature_name, importance) in enumerate(data.final_top_features, start=1): | ||
| feature_table.append([rank, feature_name, f"{importance:.2f}"]) | ||
|
|
||
| feature_markdown = tabulate( | ||
| feature_table, headers=["Rank", "Feature Name", "Importance"], tablefmt="github" | ||
| ) | ||
| lines.append(feature_markdown) | ||
| lines.append("") | ||
|
|
||
| return "\n".join(lines) | ||
|
|
||
|
|
||
| def markdown_to_html(markdown_text: str) -> str: | ||
| """Convert markdown text to HTML. | ||
|
|
||
| Args: | ||
| markdown_text: Markdown-formatted string | ||
|
|
||
| Returns: | ||
| HTML string with basic styling | ||
| """ | ||
| html_content = markdown2.markdown(markdown_text, extras=["tables", "fenced-code-blocks"]) | ||
|
|
||
| # Wrap in basic HTML document with styling | ||
| html = dedent(f""" | ||
| <!DOCTYPE html> | ||
| <html> | ||
| <head> | ||
| <meta charset="utf-8"> | ||
| <style> | ||
| body {{ | ||
| font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif; | ||
| line-height: 1.6; | ||
| max-width: 1200px; | ||
| margin: 20px auto; | ||
| padding: 0 20px; | ||
| color: #333; | ||
| }} | ||
| h1 {{ | ||
| border-bottom: 2px solid #333; | ||
| padding-bottom: 10px; | ||
| }} | ||
| h2 {{ | ||
| border-bottom: 1px solid #ccc; | ||
| padding-bottom: 8px; | ||
| margin-top: 30px; | ||
| }} | ||
| h3 {{ | ||
| margin-top: 20px; | ||
| }} | ||
| table {{ | ||
| border-collapse: collapse; | ||
| width: 100%; | ||
| margin: 20px 0; | ||
| }} | ||
| th, td {{ | ||
| border: 1px solid #ddd; | ||
| padding: 8px; | ||
| text-align: left; | ||
| }} | ||
| th {{ | ||
| background-color: #f2f2f2; | ||
| font-weight: bold; | ||
| }} | ||
| tr:nth-child(even) {{ | ||
| background-color: #f9f9f9; | ||
| }} | ||
| code {{ | ||
| background-color: #f4f4f4; | ||
| padding: 2px 4px; | ||
| border-radius: 3px; | ||
| }} | ||
| ul {{ | ||
| line-height: 1.8; | ||
| }} | ||
| </style> | ||
| </head> | ||
| <body> | ||
| {html_content} | ||
| </body> | ||
| </html> | ||
| """).strip() | ||
| return html | ||
|
|
||
|
|
||
| def save_training_report(data: TrainingReportData, output_path: Path) -> None: | ||
| """Generate and save a training report as markdown. | ||
|
|
||
| Args: | ||
| data: TrainingData object containing all training information | ||
| output_path: Path where the markdown file should be saved | ||
| """ | ||
| markdown_content = generate_markdown_report(data) | ||
| with open(output_path, "w", encoding="utf-8") as f: | ||
| f.write(markdown_content) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.