diff --git a/src/jabs/classifier/classifier.py b/src/jabs/classifier/classifier.py index 43090c5f..2cc1ab53 100644 --- a/src/jabs/classifier/classifier.py +++ b/src/jabs/classifier/classifier.py @@ -17,10 +17,14 @@ confusion_matrix, precision_recall_fscore_support, ) -from sklearn.model_selection import LeaveOneGroupOut, train_test_split +from sklearn.model_selection import LeaveOneGroupOut from jabs.project import Project, TrackLabels, load_training_data -from jabs.types import ClassifierType +from jabs.types import ( + DEFAULT_CV_GROUPING_STRATEGY, + ClassifierType, + CrossValidationGroupingStrategy, +) from jabs.utils import hash_file _VERSION = 10 @@ -200,38 +204,6 @@ def feature_names(self) -> list[str] | None: """returns the list of feature names used when training this classifier""" return self._feature_names - @staticmethod - def train_test_split(per_frame_features, window_features, label_data): - """split features and labels into training and test datasets - - Args: - per_frame_features: per frame features as returned from IdentityFeatures object, filtered to only include labeled frames - window_features: window features as returned from IdentityFeatures object, filtered to only include labeled frames - label_data: labels that correspond to the features - - Returns: - dictionary of training and test data and labels: - - { - 'training_data': list of numpy arrays, - 'test_data': list of numpy arrays, - 'training_labels': numpy array, - 'test_labels': numpy_array, - 'feature_names': list of feature names - } - """ - # split labeled data and labels - all_features = pd.concat([per_frame_features, window_features], axis=1) - x_train, x_test, y_train, y_test = train_test_split(all_features, label_data) - - return { - "training_data": x_train, - "training_labels": y_train, - "test_data": x_test, - "test_labels": y_test, - "feature_names": all_features.columns.to_list(), - } - @staticmethod def get_leave_one_group_out_max(labels, groups): """counts the number of possible leave one out groups for k-fold cross validation @@ -674,7 +646,10 @@ def print_feature_importance(self, limit=20): print(f"{feature:100} {importance:0.2f}") @staticmethod - def count_label_threshold(all_counts: dict): + def count_label_threshold( + all_counts: dict, + cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + ) -> int: """counts the number of groups that meet label threshold criteria Args: @@ -705,23 +680,45 @@ def count_label_threshold(all_counts: dict): } } + cv_grouping_strategy: cross-validation grouping strategy + Returns: number of groups that meet label criteria Note: uses "fragmented" label counts, since these reflect the counts of labels that are usable for training """ group_count = 0 - for video in all_counts: - for identity_count in all_counts[video].values(): + if cv_grouping_strategy == CrossValidationGroupingStrategy.INDIVIDUAL: + for video in all_counts: + for identity_count in all_counts[video].values(): + if ( + identity_count["fragmented_frame_counts"][0] >= Classifier.LABEL_THRESHOLD + and identity_count["fragmented_frame_counts"][1] + >= Classifier.LABEL_THRESHOLD + ): + group_count += 1 + elif cv_grouping_strategy == CrossValidationGroupingStrategy.VIDEO: + for video in all_counts: + behavior_sum = 0 + not_behavior_sum = 0 + for identity_count in all_counts[video].values(): + behavior_sum += identity_count["fragmented_frame_counts"][0] + not_behavior_sum += identity_count["fragmented_frame_counts"][1] if ( - identity_count["fragmented_frame_counts"][0] >= Classifier.LABEL_THRESHOLD - and identity_count["fragmented_frame_counts"][1] >= Classifier.LABEL_THRESHOLD + behavior_sum >= Classifier.LABEL_THRESHOLD + and not_behavior_sum >= Classifier.LABEL_THRESHOLD ): group_count += 1 + else: + raise ValueError(f"Unknown cv_grouping_strategy: {cv_grouping_strategy}") return group_count @staticmethod - def label_threshold_met(all_counts: dict, min_groups: int): + def label_threshold_met( + all_counts: dict, + min_groups: int, + cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY, + ) -> bool: """determine if the labeling threshold is met Args: @@ -730,11 +727,14 @@ def label_threshold_met(all_counts: dict, min_groups: int): min_groups: minimum number of groups required (more than one group is always required for the "leave one group out" train/test split, but may be more than 2 for k-fold cross validation if k > 2) + cv_grouping_strategy: cross-validation grouping strategy Returns: bool if requested valid groups is > valid group """ - group_count = Classifier.count_label_threshold(all_counts) + group_count = Classifier.count_label_threshold( + all_counts, cv_grouping_strategy=cv_grouping_strategy + ) return 1 < group_count >= min_groups @staticmethod diff --git a/src/jabs/classifier/training_report.py b/src/jabs/classifier/training_report.py index d92ee48e..dafb7dbf 100644 --- a/src/jabs/classifier/training_report.py +++ b/src/jabs/classifier/training_report.py @@ -7,6 +7,8 @@ import numpy as np from tabulate import tabulate +from jabs.types import CrossValidationGroupingStrategy + @dataclass class CrossValidationResult: @@ -14,8 +16,7 @@ class CrossValidationResult: Attributes: iteration: The iteration number (1-indexed) - test_video: Name of the video used for testing - test_identity: Identity label used for testing + test_label: Label of the test grouping (e.g., video filename and possibly identity index) accuracy: Classification accuracy (0.0 to 1.0) precision_behavior: Precision for behavior class precision_not_behavior: Precision for not-behavior class @@ -29,8 +30,7 @@ class CrossValidationResult: """ iteration: int - test_video: str - test_identity: str + test_label: str accuracy: float precision_behavior: float precision_not_behavior: float @@ -62,6 +62,7 @@ class TrainingReportData: bouts_not_behavior: Total number of not-behavior bouts labeled training_time_ms: Total training time in milliseconds timestamp: Datetime when training was completed + cv_grouping_strategy: Strategy used for cross-validation grouping """ behavior_name: str @@ -78,6 +79,7 @@ class TrainingReportData: bouts_not_behavior: int training_time_ms: int timestamp: datetime + cv_grouping_strategy: CrossValidationGroupingStrategy def _escape_markdown(text: str) -> str: @@ -156,12 +158,13 @@ def generate_markdown_report(data: TrainingReportData) -> str: # Detailed results table lines.append("### Iteration Details") + lines.append(f"CV Grouping Strategy: {data.cv_grouping_strategy.value}") lines.append("") table_data = [] for result in data.cv_results: # Escape markdown special characters in video - escaped_video = _escape_markdown(result.test_video) + escaped_video = _escape_markdown(result.test_label) table_data.append( [ @@ -172,7 +175,7 @@ def generate_markdown_report(data: TrainingReportData) -> str: f"{result.recall_not_behavior:.4f}", f"{result.recall_behavior:.4f}", f"{result.f1_behavior:.4f}", - f"{escaped_video} [{result.test_identity}]", + f"{escaped_video}", ] ) @@ -184,7 +187,7 @@ def generate_markdown_report(data: TrainingReportData) -> str: "Recall (Not Behavior)", "Recall (Behavior)", "F1 Score", - "Test Group (Video [Identity])", + "Test Group", ] table_markdown = tabulate(table_data, headers=headers, tablefmt="github") diff --git a/src/jabs/constants.py b/src/jabs/constants.py index a4a585d0..d7e6c7a4 100644 --- a/src/jabs/constants.py +++ b/src/jabs/constants.py @@ -8,3 +8,6 @@ # some defaults for compressing hdf5 output COMPRESSION = "gzip" COMPRESSION_OPTS_DEFAULT = 6 + +# settings keys +SETTINGS_CV_GROUPING = "cv_grouping" diff --git a/src/jabs/project/project.py b/src/jabs/project/project.py index 164e3f37..379b019a 100644 --- a/src/jabs/project/project.py +++ b/src/jabs/project/project.py @@ -16,7 +16,7 @@ import jabs.feature_extraction as fe from jabs.pose_estimation import PoseEstimation, get_pose_path, open_pose_file -from jabs.types import ProjectDistanceUnit +from jabs.types import CrossValidationGroupingStrategy, ProjectDistanceUnit from .feature_manager import FeatureManager from .parallel_workers import FeatureLoadJobSpec, collect_labeled_features @@ -590,6 +590,9 @@ def get_labeled_features( behavior_settings = self._settings_manager.get_behavior(behavior) videos = list(self._video_manager.videos) + # get the cross validation grouping strategy from project settings + grouping_strategy = self.settings_manager.cv_grouping_strategy + # Early exit if no videos if not videos: return { @@ -680,19 +683,32 @@ def get_labeled_features( "groups": np.array([], dtype=np.int32), }, {} - # Build stable group ids: original video order, then identity order as observed + # Build stable group ids based on grouping strategy key_to_gid: dict[tuple[str, int], int] = {} + video_to_gid: dict[str, int] = {} gid = 0 - for v in videos: - seen: list[int] = [] - for video_name, ident in all_group_keys: - if video_name == v and ident not in seen: - seen.append(ident) - for ident in seen: - key = (v, ident) - if key not in key_to_gid: - key_to_gid[key] = gid + if grouping_strategy == CrossValidationGroupingStrategy.INDIVIDUAL: + for v in videos: + seen: list[int] = [] + for video_name, ident in all_group_keys: + if video_name == v and ident not in seen: + seen.append(ident) + for ident in seen: + key = (v, ident) + if key not in key_to_gid: + key_to_gid[key] = gid + gid += 1 + elif grouping_strategy == CrossValidationGroupingStrategy.VIDEO: + for v in videos: + if v not in video_to_gid: + video_to_gid[v] = gid gid += 1 + for video_name, ident in all_group_keys: + if video_name == v: + key = (v, ident) + key_to_gid[key] = video_to_gid[v] + else: + raise ValueError(f"Unknown grouping strategy: {grouping_strategy}") # groups vector aligned with all_per_frame entries groups_list: list[np.ndarray] = [ @@ -701,9 +717,15 @@ def get_labeled_features( ] groups = np.concatenate(groups_list) if groups_list else np.array([], dtype=np.int32) - group_mapping: dict[int, dict[str, int | str]] = { - gid: {"video": v, "identity": ident} for (v, ident), gid in key_to_gid.items() - } + # group_mapping: for INDIVIDUAL, maps gid to (video, identity); for VIDEO, maps gid to video only + if grouping_strategy == CrossValidationGroupingStrategy.INDIVIDUAL: + group_mapping: dict[int, dict[str, int | str]] = { + gid: {"video": v, "identity": ident} for (v, ident), gid in key_to_gid.items() + } + else: + group_mapping: dict[int, dict[str, str | None]] = { + gid: {"video": v, "identity": None} for v, gid in video_to_gid.items() + } window_df = pd.concat(all_window, join="inner") per_frame_df = pd.concat(all_per_frame, join="inner") diff --git a/src/jabs/project/settings_manager.py b/src/jabs/project/settings_manager.py index 213d0639..2e35b7a9 100644 --- a/src/jabs/project/settings_manager.py +++ b/src/jabs/project/settings_manager.py @@ -2,6 +2,8 @@ import typing import jabs.feature_extraction as feature_extraction +from jabs.constants import SETTINGS_CV_GROUPING +from jabs.types.cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy from jabs.version import version_str if typing.TYPE_CHECKING: @@ -83,6 +85,21 @@ def project_metadata(self) -> dict: """ return self._project_info.get("metadata", {}) + @property + def cv_grouping_strategy(self) -> CrossValidationGroupingStrategy: + """Get the cross-validation grouping strategy for the project. + + Returns: + CrossValidationGroupingStrategy: The CV grouping strategy. + """ + grouping_str = self._project_info.get("settings", {}).get( + SETTINGS_CV_GROUPING, DEFAULT_CV_GROUPING_STRATEGY.value + ) + try: + return CrossValidationGroupingStrategy(grouping_str) + except ValueError: + return DEFAULT_CV_GROUPING_STRATEGY + def video_metadata(self, video: str) -> dict: """Get metadata for a specific video. diff --git a/src/jabs/resources/docs/user_guide/gui.md b/src/jabs/resources/docs/user_guide/gui.md index ef5c0f65..22390a1b 100644 --- a/src/jabs/resources/docs/user_guide/gui.md +++ b/src/jabs/resources/docs/user_guide/gui.md @@ -91,9 +91,14 @@ Clicking the Brightness or Contrast controls will reset the brightness or contra ## Menu - **JABS→About:** Display About Dialog +- **JABS→Project Settings:** Display Project Settings Dialog - **JABS→User Guide:** Display User Guide +- **JABS→Check for Updates:** Check PyPI for JABS updates +- **JABS→View License Agreement:** Display License Agreement +- **JABS→Enable Session Tracking:** Enable labeling session tracking - **JABS→Quit JABS:** Quit Program - **File→Open Project:** Select a project directory to open. If a project is already opened, it will be closed and the newly selected project will be opened. +- **File→Open Recent:** Submenu to open recently opened projects. - **File→Export Training Data:** Create a file with the information needed to share a classifier. This exported file is written to the project directory and has the form `_training_.h5`. This file is used as one input for the `jabs-classify` script. - **File→Archive Behavior:** Remove behavior and its labels from project. Labels are archived in the `jabs/archive` directory. - **File→Prune Project:** Remove videos and pose files that are not labeled. @@ -120,6 +125,19 @@ Clicking the Brightness or Contrast controls will reset the brightness or contra - **Window→Bring All to Front:** Bring all JABS windows to the front - The Window menu also displays a list of all open JABS windows (main window, user guide, training reports, etc.) with a checkmark (✓) next to the currently active window. Click any window in the list to activate and bring it to the front. +## Project Settings Dialog + +The **Project Settings Dialog**, available from the JABS menu, allows you to configure project-wide settings. Settings are designed to be easily discoverable; each settings group includes built-in documentation, making it easy to understand the purpose and effect of each option directly within the dialog. Users should explore this dialog and customize project settings as needed. + +### Settings Overview + +| Setting | Description | +|-------------------------------|------------------------------------------------------------------| +| Cross Validation Grouping | Determines how cross-validation groups are defined. Options are "Individual Animal" (default) or "Video". | + +As new settings are added, they will appear in this dialog with inline documentation. + + ## Overlays ### Track Overlay Example diff --git a/src/jabs/types/__init__.py b/src/jabs/types/__init__.py index 37490e91..0e1fa972 100644 --- a/src/jabs/types/__init__.py +++ b/src/jabs/types/__init__.py @@ -1,4 +1,12 @@ """Module for defining enums used in JABS""" from .classifier_types import ClassifierType +from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy from .units import ProjectDistanceUnit + +__all__ = [ + "DEFAULT_CV_GROUPING_STRATEGY", + "ClassifierType", + "CrossValidationGroupingStrategy", + "ProjectDistanceUnit", +] diff --git a/src/jabs/types/cv_grouping.py b/src/jabs/types/cv_grouping.py new file mode 100644 index 00000000..c3ce9cb0 --- /dev/null +++ b/src/jabs/types/cv_grouping.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class CrossValidationGroupingStrategy(str, Enum): + """Cross-validation grouping type for the project. + + Inheriting from str allows for easy serialization to/from JSON (the enum will + automatically be serialized using the enum value). + """ + + INDIVIDUAL = "Individual Animal" + VIDEO = "Video" + + +DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL diff --git a/src/jabs/ui/central_widget.py b/src/jabs/ui/central_widget.py index c2712dbf..5c555d9c 100644 --- a/src/jabs/ui/central_widget.py +++ b/src/jabs/ui/central_widget.py @@ -108,7 +108,7 @@ def __init__(self, *args, **kwargs) -> None: self._controls.classify_clicked.connect(self._classify_button_clicked) self._controls.classifier_changed.connect(self._classifier_changed) self._controls.behavior_changed.connect(self._on_behavior_changed) - self._controls.kfold_changed.connect(self._set_train_button_enabled_state) + self._controls.kfold_changed.connect(self.set_train_button_enabled_state) self._controls.window_size_changed.connect(self._on_window_size_changed) self._controls.new_window_sizes.connect(self._save_window_sizes) self._controls.use_balance_labels_changed.connect(self._on_use_balance_labels_changed) @@ -509,7 +509,7 @@ def _on_behavior_changed(self) -> None: # display labels and predictions for new behavior self._set_label_track() - self._set_train_button_enabled_state() + self.set_train_button_enabled_state() self._project.settings_manager.save_project_file({"selected_behavior": self.behavior}) @@ -605,7 +605,7 @@ def _label_button_common(self) -> None: self._controls.disable_label_buttons() self._stacked_timeline.clear_selection() self._update_label_counts() - self._set_train_button_enabled_state() + self.set_train_button_enabled_state() self._player_widget.reload_frame() def _set_identities(self, identities: list[str]) -> None: @@ -982,7 +982,7 @@ def _get_prediction_list(self) -> tuple[list[np.ndarray], list[np.ndarray]]: probability_list.append(prediction_prob) return prediction_list, probability_list - def _set_train_button_enabled_state(self) -> None: + def set_train_button_enabled_state(self) -> None: """set the enabled property of the train button Sets enabled state of the train button to True or False depending on @@ -999,7 +999,11 @@ def _set_train_button_enabled_state(self) -> None: if self._project is None: return - if Classifier.label_threshold_met(self._counts, self._controls.kfold_value): + if Classifier.label_threshold_met( + self._counts, + self._controls.kfold_value, + self._project.settings_manager.cv_grouping_strategy, + ): self._controls.train_button_enabled = True self.export_training_status_change.emit(True) else: diff --git a/src/jabs/ui/main_window/main_window.py b/src/jabs/ui/main_window/main_window.py index d347ed09..84c3c2d7 100644 --- a/src/jabs/ui/main_window/main_window.py +++ b/src/jabs/ui/main_window/main_window.py @@ -132,8 +132,8 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None: self.menu_handlers = MenuHandlers(self) # Build all menus using MenuBuilder - menu_builder = MenuBuilder(self, app_name, app_name_long) - menu_refs = menu_builder.build_menus() + self.menu_builder = MenuBuilder(self, app_name, app_name_long) + menu_refs = self.menu_builder.build_menus() # Store references to menus and actions for later use self._window_menu = menu_refs.window_menu @@ -166,6 +166,7 @@ def __init__(self, app_name: str, app_name_long: str, *args, **kwargs) -> None: self.enable_social_features = menu_refs.enable_social_features self.enable_landmark_features = menu_refs.enable_landmark_features self.enable_segmentation_features = menu_refs.enable_segmentation_features + self._settings_action = menu_refs.settings_action # Update recent projects menu self._update_recent_projects() @@ -379,6 +380,7 @@ def _project_loaded_callback(self) -> None: # Update which controls should be available self._archive_behavior.setEnabled(True) self._prune_action.setEnabled(True) + self._settings_action.setEnabled(True) self.enable_cm_units.setEnabled(self._project.feature_manager.is_cm_unit) self.enable_social_features.setEnabled( self._project.feature_manager.can_use_social_features @@ -517,3 +519,11 @@ def closeEvent(self, event: QtGui.QCloseEvent) -> None: # This allows for quick exit without hanging self._process_pool.shutdown(wait=False, cancel_futures=True) super().closeEvent(event) + + def on_settings_changed(self): + """Slot called when project settings are changed via SettingsDialog. + + Called when settings are changed, in case any UI updates are needed. + """ + # changing the settings can affect training thresholds, so the train button state needs to be updated + self._central_widget.set_train_button_enabled_state() diff --git a/src/jabs/ui/main_window/menu_builder.py b/src/jabs/ui/main_window/menu_builder.py index 6f7b818e..a3fe0b35 100644 --- a/src/jabs/ui/main_window/menu_builder.py +++ b/src/jabs/ui/main_window/menu_builder.py @@ -33,11 +33,14 @@ class MenuReferences: window_menu: QtWidgets.QMenu open_recent_menu: QtWidgets.QMenu + # App menu actions + settings_action: QtGui.QAction + clear_cache: QtGui.QAction + # File menu actions export_training: QtGui.QAction archive_behavior: QtGui.QAction prune_action: QtGui.QAction - clear_cache: QtGui.QAction # View menu actions view_playlist: QtGui.QAction @@ -107,7 +110,7 @@ def build_menus(self) -> MenuReferences: window_menu = menu_bar.addMenu("Window") # Build each menu - self._build_app_menu(app_menu) + app_actions = self._build_app_menu(app_menu) file_actions = self._build_file_menu(file_menu) view_actions = self._build_view_menu(view_menu) feature_actions = self._build_feature_menu(feature_menu) @@ -121,13 +124,13 @@ def build_menus(self) -> MenuReferences: view_menu=view_menu, feature_menu=feature_menu, window_menu=window_menu, - clear_cache=self._clear_cache, + **app_actions, **file_actions, **view_actions, **feature_actions, ) - def _build_app_menu(self, menu: QtWidgets.QMenu) -> None: + def _build_app_menu(self, menu: QtWidgets.QMenu) -> dict: """Build the application menu (About, User Guide, Quit, etc.). Args: @@ -139,6 +142,13 @@ def _build_app_menu(self, menu: QtWidgets.QMenu) -> None: about_action.triggered.connect(self.handlers.show_about_dialog) menu.addAction(about_action) + # Settings action + settings_action = QtGui.QAction(" &Project Settings", self.main_window) + settings_action.setStatusTip("Open Project Settings") + settings_action.triggered.connect(self.handlers.open_project_settings_dialog) + settings_action.setEnabled(False) # Disabled by default, enabled when project is loaded + menu.addAction(settings_action) + # User guide action user_guide_action = QtGui.QAction(" &User Guide", self.main_window) user_guide_action.setStatusTip("Open User Guide") @@ -169,11 +179,11 @@ def _build_app_menu(self, menu: QtWidgets.QMenu) -> None: menu.addAction(session_tracking_action) # Clear cache action (store as instance variable for later reference) - self._clear_cache = QtGui.QAction("Clear Project Cache", self.main_window) - self._clear_cache.setStatusTip("Clear Project Cache") - self._clear_cache.setEnabled(False) - self._clear_cache.triggered.connect(self.handlers.clear_cache) - menu.addAction(self._clear_cache) + clear_cache = QtGui.QAction("Clear Project Cache", self.main_window) + clear_cache.setStatusTip("Clear Project Cache") + clear_cache.setEnabled(False) + clear_cache.triggered.connect(self.handlers.clear_cache) + menu.addAction(clear_cache) # Quit action exit_action = QtGui.QAction(f" &Quit {self.app_name}", self.main_window) @@ -182,6 +192,11 @@ def _build_app_menu(self, menu: QtWidgets.QMenu) -> None: exit_action.triggered.connect(QtCore.QCoreApplication.quit) menu.addAction(exit_action) + return { + "clear_cache": clear_cache, + "settings_action": settings_action, + } + def _build_file_menu(self, menu: QtWidgets.QMenu) -> dict: """Build the File menu. diff --git a/src/jabs/ui/main_window/menu_handlers.py b/src/jabs/ui/main_window/menu_handlers.py index 45690414..a91b6309 100644 --- a/src/jabs/ui/main_window/menu_handlers.py +++ b/src/jabs/ui/main_window/menu_handlers.py @@ -22,6 +22,7 @@ from ..license_dialog import LicenseAgreementDialog from ..player_widget import PlayerWidget from ..project_pruning_dialog import ProjectPruningDialog +from ..settings_dialog import SettingsDialog from ..stacked_timeline_widget import StackedTimelineWidget from ..update_check_dialog import UpdateCheckDialog from ..user_guide_dialog import UserGuideDialog @@ -202,6 +203,12 @@ def show_about_dialog(self) -> None: ) about_dialog.exec() + def open_project_settings_dialog(self) -> None: + """Open the project settings dialog.""" + settings_dialog = SettingsDialog(self.window._project.settings_manager, self.window) + settings_dialog.settings_changed.connect(self.window.on_settings_changed) + settings_dialog.exec() + def open_user_guide(self) -> None: """Show the user guide document in a separate window.""" if self.window._user_guide_window is None: diff --git a/src/jabs/ui/settings_dialog/README.md b/src/jabs/ui/settings_dialog/README.md new file mode 100644 index 00000000..30401233 --- /dev/null +++ b/src/jabs/ui/settings_dialog/README.md @@ -0,0 +1,164 @@ +# Settings Dialog Package + +This package contains the JABS project settings dialog and related components. + +## Overview + +The settings dialog is built using a modular architecture where different groups of related settings can be added as separate `SettingsGroup` subclasses. Each group has: + +- A form-style grid layout for controls +- An optional collapsible documentation section +- Methods to get/set values from the project settings + +## Architecture + +### Key Components + +- **`SettingsDialog`** - Main dialog that hosts all settings groups in a scrollable area +- **`SettingsGroup`** - Base class for creating settings groups with controls and documentation +- **`CollapsibleSection`** - Reusable widget for collapsible help/documentation sections + +### How It Works + +1. The `SettingsDialog` creates a scrollable page that hosts multiple `SettingsGroup` instances +2. Each `SettingsGroup` manages its own controls and documentation +3. When the dialog opens, it loads current values from the project's `SettingsManager` +4. When the user clicks "Save", all groups' values are collected and saved to the project + +## Creating a New Settings Group + +To add a new group of settings: + +### 1. Create a new class inheriting from `SettingsGroup` + +```python +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QCheckBox, QComboBox, QLabel, QSpinBox, QSizePolicy +from .settings_group import SettingsGroup + + +class MySettingsGroup(SettingsGroup): + """Settings group for my feature.""" + + def __init__(self, parent=None): + super().__init__("My Feature Settings", parent) +``` + +### 2. Override `_create_controls()` to add your widgets + +Use `add_control_row()` to add labeled controls: + +```python +def _create_controls(self) -> None: + """Create the settings controls.""" + # Add a checkbox + self._enable_feature = QCheckBox("Enable this feature") + self.add_control_row("Enable feature:", self._enable_feature) + + # Add a combo box + self._method_selection = QComboBox() + self._method_selection.addItems(["Method A", "Method B", "Method C"]) + self.add_control_row("Method:", self._method_selection) + + # Add a spin box + self._iterations = QSpinBox() + self._iterations.setRange(1, 100) + self.add_control_row("Iterations:", self._iterations) +``` + +Or use `add_widget_row()` for widgets that don't fit the label/control pattern: + +```python +def _create_controls(self) -> None: + # Add a full-width checkbox (no label) + self._advanced_mode = QCheckBox("Enable advanced mode") + self.add_widget_row(self._advanced_mode) +``` + +### 3. Override `_create_documentation()` to add help text (optional) + +```python +def _create_documentation(self): + """Create help documentation for these settings.""" + help_label = QLabel(self) + help_label.setTextFormat(Qt.TextFormat.RichText) + help_label.setWordWrap(True) + help_label.setText( + """ +

What do these settings do?

+

Detailed explanation of what these settings control...

+ + + """ + ) + help_label.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + return help_label +``` + +### 4. Implement `get_values()` and `set_values()` + +```python +def get_values(self) -> dict: + """Get current settings values.""" + return { + "my_enable_feature": self._enable_feature.isChecked(), + "my_method": self._method_selection.currentText(), + "my_iterations": self._iterations.value(), + } + +def set_values(self, values: dict) -> None: + """Set settings values from a dictionary.""" + self._enable_feature.setChecked(values.get("my_enable_feature", False)) + + method = values.get("my_method", "Method A") + index = self._method_selection.findText(method) + if index >= 0: + self._method_selection.setCurrentIndex(index) + + self._iterations.setValue(values.get("my_iterations", 10)) +``` + +### 5. Add your group to the SettingsDialog + +Edit `settings_dialog.py`: + +```python +from .my_settings_group import MySettingsGroup + +class SettingsDialog(QDialog): + def __init__(self, settings_manager: SettingsManager, parent: QWidget | None = None): + # ... existing code ... + + # Add your settings group + my_group = MySettingsGroup(page) + self._settings_groups.append(my_group) + page_layout.addWidget(my_group) + page_layout.setAlignment(my_group, Qt.AlignmentFlag.AlignTop) +``` + +## Tips + +- **Setting names**: Use descriptive names and consider prefixing with the group name to avoid conflicts (e.g., `"calibration_method"` instead of just `"method"`) +- **Default values**: Always provide sensible defaults in `set_values()` using `.get(key, default)` +- **Widget sizing**: Controls are kept compact by default. The third column of the grid expands to fill extra space, keeping controls left-aligned +- **Documentation**: Rich text is supported in help sections. Use `

`, `

`, `

    `, `
  • `, ``, etc. +- **Validation**: Add validation in `get_values()` or connect to widget signals if needed +- **Persistence**: Setting values are persisted via the `SettingsManager` when the dialog is saved. They will appear in the jabs/project.json file under the top-level "settings" key. + +## File Organization + +``` +settings_dialog/ +├── __init__.py # Exports SettingsDialog and SettingsGroup +├── settings_dialog.py # Main dialog +├── settings_group.py # Base class for settings groups +├── collapsible_section.py # Collapsible widget for documentation +└── README.md # This file +``` + +Add new settings groups as separate files in this directory, then import and instantiate them in `settings_dialog.py`. + diff --git a/src/jabs/ui/settings_dialog/__init__.py b/src/jabs/ui/settings_dialog/__init__.py new file mode 100644 index 00000000..3c72969c --- /dev/null +++ b/src/jabs/ui/settings_dialog/__init__.py @@ -0,0 +1,5 @@ +"""Settings dialog UI module.""" + +from .settings_dialog import SettingsDialog + +__all__ = ["SettingsDialog"] diff --git a/src/jabs/ui/settings_dialog/collapsible_section.py b/src/jabs/ui/settings_dialog/collapsible_section.py new file mode 100644 index 00000000..4f4921ab --- /dev/null +++ b/src/jabs/ui/settings_dialog/collapsible_section.py @@ -0,0 +1,81 @@ +from PySide6.QtCore import Signal +from PySide6.QtGui import Qt +from PySide6.QtWidgets import QFrame, QSizePolicy, QToolButton, QVBoxLayout, QWidget + + +class CollapsibleSection(QWidget): + """A collapsible section with a header ToolButton and a content area. + + This widget is used by SettingsGroup to provide inline documentation that is collapsed (hidden) by default. + """ + + sizeChanged = Signal() + toggled = Signal(bool) # Emitted when the section is expanded/collapsed + + def __init__(self, title: str, content: QWidget, parent: QWidget | None = None) -> None: + super().__init__(parent) + self._content = content + self._toggle_btn = QToolButton(self) + self._toggle_btn.setStyleSheet("QToolButton { border: none; }") + self._toggle_btn.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonTextBesideIcon) + self._toggle_btn.setArrowType(Qt.ArrowType.RightArrow) + self._toggle_btn.setText(title) + self._toggle_btn.setCheckable(True) + self._toggle_btn.setChecked(False) + self._toggle_btn.toggled.connect(self._on_toggled) + + line = QFrame(self) + line.setFrameShape(QFrame.Shape.HLine) + line.setFrameShadow(QFrame.Shadow.Sunken) + + self._content.setVisible(False) + # Ensure the collapsible widget and its content expand to fit content + self.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + self._content.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + + lay = QVBoxLayout(self) + lay.setContentsMargins(0, 0, 0, 0) + lay.addWidget(self._toggle_btn) + lay.addWidget(line) + lay.addWidget(self._content) + + def _on_toggled(self, checked: bool) -> None: + """Handle toggling the collapsible section.""" + self._toggle_btn.setArrowType( + Qt.ArrowType.DownArrow if checked else Qt.ArrowType.RightArrow + ) + self._content.setVisible(checked) + self._content.updateGeometry() + + # Ask ancestors to recompute layout so the page grows inside the scroll area + parent = self.parentWidget() + if parent is not None and parent.layout() is not None: + parent.layout().activate() + + if self.layout() is not None: + self.layout().activate() + + # Let ancestors recompute size hints and notify listeners + if parent is not None: + parent.updateGeometry() + self.updateGeometry() + self.sizeChanged.emit() + self.toggled.emit(checked) + + def is_expanded(self) -> bool: + """ + Check if the section is currently expanded. + + Returns: + True if the section is expanded, False otherwise. + """ + return self._toggle_btn.isChecked() + + def set_expanded(self, expanded: bool) -> None: + """ + Set the expanded state of the section. + + Args: + expanded: True to expand the section, False to collapse it. + """ + self._toggle_btn.setChecked(expanded) diff --git a/src/jabs/ui/settings_dialog/cross_validation_settings_group.py b/src/jabs/ui/settings_dialog/cross_validation_settings_group.py new file mode 100644 index 00000000..ca71c305 --- /dev/null +++ b/src/jabs/ui/settings_dialog/cross_validation_settings_group.py @@ -0,0 +1,104 @@ +"""Cross-validation settings group for configuring model training and validation.""" + +from PySide6.QtCore import Qt +from PySide6.QtWidgets import QComboBox, QLabel, QSizePolicy + +from jabs.constants import SETTINGS_CV_GROUPING +from jabs.types import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy + +from .settings_group import SettingsGroup + + +class CrossValidationSettingsGroup(SettingsGroup): + """ + Settings group for cross-validation configuration. + + This group controls how data is split during model training and validation. + """ + + def __init__(self, parent=None): + """Initialize the cross-validation settings group.""" + super().__init__("Cross-Validation", parent) + + def _create_controls(self) -> None: + """Create the cross-validation settings controls.""" + # Cross-validation grouping combo box + self._cv_grouping_combo = QComboBox() + # Add enum values as items, storing the enum as userData + for enum_val in CrossValidationGroupingStrategy: + self._cv_grouping_combo.addItem(enum_val.value, enum_val) + self._cv_grouping_combo.setCurrentIndex( + self._cv_grouping_combo.findData(DEFAULT_CV_GROUPING_STRATEGY) + ) + self._cv_grouping_combo.setSizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContents) + + self.add_control_row("CV Grouping:", self._cv_grouping_combo) + + def _create_documentation(self): + """Create help documentation for cross-validation settings.""" + help_label = QLabel(self) + help_label.setTextFormat(Qt.TextFormat.RichText) + help_label.setWordWrap(True) + help_label.setText( + """ +

    What is Cross-Validation Grouping?

    +

    Cross-validation grouping determines how training data is split when + evaluating model performance using leave-one-group-out cross-validation.

    + +
      +
    • Individual Animal: Each group represents a single animal identity + within a single video. During cross-validation, all labeled data for one + animal from one video is held out for validation while the remaining animals' + data is used for training.
    • + +
    • Video: Each group represents a single video recording. During + cross-validation, all labeled data from one video is held out for validation + while data from other videos is used for training.
    • +
    + +

    Note: For cross-validation to work properly, you need labeled data + from multiple groups (multiple animals or multiple videos, depending on the + grouping method selected). For rare behaviors, it may be easier to meet the + minimum label requirements per group at the video level rather than at the + individual animal level within a single video.

    + """ + ) + help_label.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + return help_label + + def get_values(self) -> dict: + """ + Get current cross-validation settings values. + + Returns: + Dictionary with setting names and their current values. + """ + # Return the enum, not just the string + return { + SETTINGS_CV_GROUPING: self._cv_grouping_combo.currentData(), + } + + def set_values(self, values: dict) -> None: + """ + Set cross-validation settings values from a dictionary. + + Args: + values: Dictionary with setting names and values to apply. + """ + cv_grouping = values.get(SETTINGS_CV_GROUPING, CrossValidationGroupingStrategy.INDIVIDUAL) + + # The grouping setting is saved as the string value, try to convert string from settings dict back to enum + try: + enum_val = CrossValidationGroupingStrategy(cv_grouping) + index = self._cv_grouping_combo.findData(enum_val) + except ValueError: + # Invalid setting value, we'll treat it as not found + index = -1 + + if index >= 0: + self._cv_grouping_combo.setCurrentIndex(index) + else: + # Fall back to default if invalid value or not found + self._cv_grouping_combo.setCurrentIndex( + self._cv_grouping_combo.findData(CrossValidationGroupingStrategy.INDIVIDUAL) + ) diff --git a/src/jabs/ui/settings_dialog/settings_dialog.py b/src/jabs/ui/settings_dialog/settings_dialog.py new file mode 100644 index 00000000..7e63c70d --- /dev/null +++ b/src/jabs/ui/settings_dialog/settings_dialog.py @@ -0,0 +1,128 @@ +from PySide6.QtCore import Qt, Signal +from PySide6.QtGui import QResizeEvent, QShowEvent +from PySide6.QtWidgets import ( + QDialog, + QDialogButtonBox, + QFrame, + QScrollArea, + QSizePolicy, + QVBoxLayout, + QWidget, +) + +from jabs.project.settings_manager import SettingsManager + +from .cross_validation_settings_group import CrossValidationSettingsGroup + + +class SettingsDialog(QDialog): + """ + Dialog for changing project settings. + + Args: + settings_manager (SettingsManager): Project settings manager used to load and save settings. + parent (QWidget | None, optional): Parent widget for this dialog. Defaults to None. + """ + + settings_changed = Signal() + + def __init__(self, settings_manager: SettingsManager, parent: QWidget | None = None) -> None: + super().__init__(parent) + self.setWindowTitle("Project Settings") + self._settings_manager = settings_manager + + # Allow resizing and show scrollbars if content overflows + self.setSizeGripEnabled(True) + + # Scrollable page to host settings sections + page = QWidget(self) + page.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + page_layout = QVBoxLayout(page) + page_layout.setContentsMargins(0, 0, 0, 0) + page_layout.setSpacing(10) + + # Track all settings groups + self._settings_groups: list = [] + + # Add settings groups here + cv_group = CrossValidationSettingsGroup(page) + self._settings_groups.append(cv_group) + page_layout.addWidget(cv_group) + page_layout.setAlignment(cv_group, Qt.AlignmentFlag.AlignTop) + + # Load current settings into groups + self._load_settings() + + page_layout.addStretch(1) + + scroll = QScrollArea(self) + scroll.setWidget(page) + scroll.setWidgetResizable(True) + scroll.setFrameShape(QFrame.Shape.NoFrame) + scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + scroll.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + + # Keep references for width syncing + self._scroll = scroll + self._page = page + + # Buttons + btn_box = QDialogButtonBox(self) + btn_save = btn_box.addButton("Save", QDialogButtonBox.ButtonRole.AcceptRole) + btn_close = btn_box.addButton("Close", QDialogButtonBox.ButtonRole.RejectRole) + btn_save.clicked.connect(self._on_save) + btn_close.clicked.connect(self.reject) + + # Main layout + main = QVBoxLayout(self) + main.addWidget(scroll, 1) + main.addWidget(btn_box) + + self.setLayout(main) + + # Size to content initially + self.adjustSize() + self.resize(max(self.width(), 600), max(self.height(), 500)) + + def showEvent(self, e: QShowEvent) -> None: + """Handle the show event. + + Ensures the settings page width is synchronized with the viewport when the dialog is first shown. + + Args: + e (QShowEvent): The Qt show event. + """ + super().showEvent(e) + + def resizeEvent(self, e: QResizeEvent) -> None: + """Handle the resize event. + + Ensures the settings page width matches the viewport width when the dialog is resized. + + Args: + e (QResizeEvent): The Qt resize event. + """ + super().resizeEvent(e) + + def _load_settings(self) -> None: + """Load current settings from the project into all settings groups.""" + all_project_data = self._settings_manager.project_settings + current_settings = all_project_data.get("settings", {}) + for group in self._settings_groups: + group.set_values(current_settings) + + def _on_save(self) -> None: + """Save settings from all groups to project and close dialog.""" + # Collect settings from all groups + all_settings = {} + for group in self._settings_groups: + all_settings.update(group.get_values()) + + # Save to project if there are any settings + if all_settings: + settings = {"settings": all_settings} + self._settings_manager.save_project_file(settings) + self.settings_changed.emit() + + self.accept() diff --git a/src/jabs/ui/settings_dialog/settings_group.py b/src/jabs/ui/settings_dialog/settings_group.py new file mode 100644 index 00000000..648ee744 --- /dev/null +++ b/src/jabs/ui/settings_dialog/settings_group.py @@ -0,0 +1,248 @@ +from PySide6.QtCore import Qt, QTimer +from PySide6.QtWidgets import ( + QGridLayout, + QGroupBox, + QLabel, + QScrollArea, + QSizePolicy, + QSpacerItem, + QVBoxLayout, + QWidget, +) + +from .collapsible_section import CollapsibleSection + + +class SettingsGroup(QGroupBox): + """ + A reusable settings group with a grid layout for controls and optional collapsible documentation. + + This class provides a structured layout for adding related settings controls in a grid format, + with an optional collapsible help/documentation section below the controls. + + The grid layout has three columns: + - Column 0: Labels (natural size, right-aligned) + - Column 1: Input widgets (natural size) + - Column 2: Spacer (expands to fill remaining width, keeping controls left-aligned) + + Subclasses should override `_create_controls()` to add their specific settings widgets + and optionally override `_create_documentation()` to provide help text. + + Example: + class MySettingsGroup(SettingsGroup): + def __init__(self, parent=None): + super().__init__("My Settings", parent) + + def _create_controls(self): + self._my_checkbox = QCheckBox() + self.add_control_row("Enable feature:", self._my_checkbox) + + def _create_documentation(self): + help_label = QLabel("This is help text...") + help_label.setWordWrap(True) + return help_label + + def get_values(self): + return {"my_setting": self._my_checkbox.isChecked()} + + def set_values(self, values): + self._my_checkbox.setChecked(values.get("my_setting", False)) + """ + + def __init__(self, title: str, parent: QWidget | None = None) -> None: + """ + Initialize the settings group. + + Args: + title: The title displayed in the group box header. + parent: Parent widget for this settings group. + """ + super().__init__(title, parent) + self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Preferred) + + # Main vertical layout for the group + self._main_layout = QVBoxLayout(self) + self._main_layout.setContentsMargins(12, 12, 12, 12) + self._main_layout.setSpacing(8) + + # Grid layout for form controls + self._form_widget = QWidget(self) + self._grid_layout = QGridLayout(self._form_widget) + self._grid_layout.setContentsMargins(0, 0, 0, 0) + self._grid_layout.setHorizontalSpacing(12) + self._grid_layout.setVerticalSpacing(8) + self._grid_layout.setColumnStretch(0, 0) # labels column: natural size + self._grid_layout.setColumnStretch(1, 0) # inputs column: natural size + self._grid_layout.setColumnStretch(2, 1) # consume extra width on the right + + self._main_layout.addWidget(self._form_widget) + + # Track current row for adding controls + self._current_row = 0 + + # Create controls (subclasses override this) + self._create_controls() + + # Add horizontal spacer in column 2 for all rows + if self._current_row > 0: + self._grid_layout.addItem( + QSpacerItem(0, 0, QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Minimum), + 0, + 2, + self._current_row, + 1, + ) + + # Create optional documentation section + doc_widget = self._create_documentation() + if doc_widget is not None: + self._help_section = CollapsibleSection("What do these do?", doc_widget, self) + self._help_section.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred + ) + doc_widget.setSizePolicy(QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred) + + # Connect size change signal for reflow + self._help_section.sizeChanged.connect(self._on_help_section_resized) + + # Connect toggle signal to scroll to section when expanded + self._help_section.toggled.connect(self._on_help_toggled) + + self._main_layout.addWidget(self._help_section) + else: + self._help_section = None + + self._main_layout.addStretch(0) + + def add_control_row( + self, + label_text: str, + widget: QWidget, + alignment: Qt.AlignmentFlag = Qt.AlignmentFlag.AlignLeft, + ) -> int: + """ + Add a control row to the grid layout. + + Args: + label_text: Text for the label in column 0. + widget: The input widget to add in column 1. + alignment: Alignment for the widget in column 1 (default: AlignLeft). + + Returns: + The row index where the control was added. + """ + label = QLabel(label_text, self) + self._grid_layout.addWidget(label, self._current_row, 0, Qt.AlignmentFlag.AlignRight) + self._grid_layout.addWidget(widget, self._current_row, 1, alignment) + row = self._current_row + self._current_row += 1 + return row + + def add_widget_row(self, widget: QWidget, column_span: int = 3) -> int: + """ + Add a widget that spans multiple columns. + + Useful for widgets that don't fit the label/control pattern. + + Args: + widget: The widget to add. + column_span: Number of columns to span (default: 3 for full width). + + Returns: + The row index where the widget was added. + """ + self._grid_layout.addWidget(widget, self._current_row, 0, 1, column_span) + row = self._current_row + self._current_row += 1 + return row + + def _create_controls(self) -> None: + """ + Create and add control widgets to the settings group. + + Subclasses should override this method to add their specific controls using + `add_control_row()` or `add_widget_row()`. + """ + pass + + def _create_documentation(self) -> QWidget | None: + """ + Create documentation/help content for this settings group. + + Subclasses should override this method to return a widget containing help text + or documentation. If None is returned, no collapsible documentation section is shown. + + Returns: + A widget containing documentation, or None if no documentation is needed. + """ + return None + + def _on_help_section_resized(self) -> None: + """Handle help section size changes to trigger parent layout updates. + + With the scroll area's setWidgetResizable(True), the layout system + automatically handles size changes without needing explicit adjustSize calls. + We just need to activate the layout to reflow content. + """ + # Just activate layouts without calling adjustSize to avoid shrinking + parent = self.parentWidget() + if parent is not None: + parent_layout = parent.layout() + if parent_layout is not None: + parent_layout.activate() + + def _on_help_toggled(self, checked: bool) -> None: + """ + Handle help section toggle. + + Args: + checked: True if the help section is expanded, False if collapsed. + """ + if checked: + self.scroll_to_help_section() + + def get_scroll_area(self): + """ + Find the parent QScrollArea if one exists. + + Returns: + The parent QScrollArea, or None if not found. + """ + parent = self.parentWidget() + while parent is not None: + if isinstance(parent, QScrollArea): + return parent + parent = parent.parentWidget() + return None + + def scroll_to_help_section(self) -> None: + """Scroll to make the help section visible in the parent scroll area.""" + if self._help_section is None: + return + + scroll = self.get_scroll_area() + if scroll is not None: + # Defer one tick so QScrollArea can recompute its scroll range correctly + QTimer.singleShot(0, lambda: scroll.ensureWidgetVisible(self._help_section)) + + def get_values(self) -> dict: + """ + Get the current values from this settings group. + + Subclasses should override this to return a dictionary of setting names to values. + + Returns: + Dictionary mapping setting names to their current values. + """ + return {} + + def set_values(self, values: dict) -> None: + """ + Set values in this settings group. + + Subclasses should override this to update their controls from the provided values. + + Args: + values: Dictionary mapping setting names to values. + """ + pass diff --git a/src/jabs/ui/training_thread.py b/src/jabs/ui/training_thread.py index ad48290b..d756d53d 100644 --- a/src/jabs/ui/training_thread.py +++ b/src/jabs/ui/training_thread.py @@ -168,8 +168,9 @@ def id_processed() -> None: cv_results.append( CrossValidationResult( iteration=i + 1, - test_video=test_info["video"], - test_identity=test_info["identity"], + test_label=f"{test_info['video']} [{test_info['identity']}]" + if test_info["identity"] is not None + else test_info["video"], accuracy=accuracy, precision_behavior=pr[0][1], precision_not_behavior=pr[0][0], @@ -243,6 +244,7 @@ def id_processed() -> None: training_time_ms=elapsed_ms, timestamp=report_timestamp, window_size=behavior_settings["window_size"], + cv_grouping_strategy=self._project.settings_manager.cv_grouping_strategy, ) # Save markdown report diff --git a/tests/classifier/test_classifier.py b/tests/classifier/test_classifier.py index 28603f60..ba817b43 100644 --- a/tests/classifier/test_classifier.py +++ b/tests/classifier/test_classifier.py @@ -8,7 +8,7 @@ from jabs.classifier.classifier import Classifier from jabs.project import TrackLabels -from jabs.types import ClassifierType +from jabs.types import ClassifierType, CrossValidationGroupingStrategy @pytest.fixture @@ -168,26 +168,6 @@ def test_feature_names_property(self): class TestDataSplitting: """Test data splitting methods.""" - def test_train_test_split(self, sample_features, sample_labels): - """Test train_test_split creates proper splits.""" - per_frame = sample_features[["feature_1", "feature_2"]] - window = sample_features[["feature_3"]] - - result = Classifier.train_test_split(per_frame, window, sample_labels) - - assert "training_data" in result - assert "test_data" in result - assert "training_labels" in result - assert "test_labels" in result - assert "feature_names" in result - - # Check that splits sum to original size - assert len(result["training_data"]) + len(result["test_data"]) == len(sample_features) - assert len(result["training_labels"]) + len(result["test_labels"]) == len(sample_labels) - - # Check feature names - assert result["feature_names"] == ["feature_1", "feature_2", "feature_3"] - def test_leave_one_group_out(self, sample_features, sample_labels, sample_groups): """Test leave_one_group_out splitting.""" per_frame = sample_features[["feature_1", "feature_2"]] @@ -706,11 +686,93 @@ def test_label_threshold_met(self): } } - # Two groups meet threshold, need at least 2 - assert Classifier.label_threshold_met(all_counts, min_groups=2) + # INDIVIDUAL: Two groups meet threshold, need at least 2 + assert Classifier.label_threshold_met( + all_counts, + min_groups=2, + cv_grouping_strategy=CrossValidationGroupingStrategy.INDIVIDUAL, + ) + # INDIVIDUAL: Two groups meet threshold, but need at least 3 + assert not Classifier.label_threshold_met( + all_counts, + min_groups=3, + cv_grouping_strategy=CrossValidationGroupingStrategy.INDIVIDUAL, + ) + + # VIDEO: Only one video, should fail for min_groups=1 because we can't split into a train and test set + assert not Classifier.label_threshold_met( + all_counts, min_groups=1, cv_grouping_strategy=CrossValidationGroupingStrategy.VIDEO + ) - # Two groups meet threshold, but need at least 3 - assert not Classifier.label_threshold_met(all_counts, min_groups=3) + # Add a second video for VIDEO grouping strategy + multi_video_counts = { + "video1.avi": { + 0: { + "fragmented_frame_counts": (25, 25), + "fragmented_bout_counts": (5, 5), + "unfragmented_frame_counts": (25, 25), + "unfragmented_bout_counts": (5, 5), + } + }, + "video2.avi": { + 0: { + "fragmented_frame_counts": (30, 30), + "fragmented_bout_counts": (6, 6), + "unfragmented_frame_counts": (30, 30), + "unfragmented_bout_counts": (6, 6), + } + }, + } + # VIDEO: Two videos, both meet threshold, min_groups=2 should pass + assert Classifier.label_threshold_met( + multi_video_counts, + min_groups=2, + cv_grouping_strategy=CrossValidationGroupingStrategy.VIDEO, + ) + # VIDEO: Two videos, min_groups=3 should fail + assert not Classifier.label_threshold_met( + multi_video_counts, + min_groups=3, + cv_grouping_strategy=CrossValidationGroupingStrategy.VIDEO, + ) + # VIDEO: One video below threshold, one above + multi_video_counts_below = { + "video1.avi": { + 0: { + "fragmented_frame_counts": (10, 10), + "fragmented_bout_counts": (2, 2), + "unfragmented_frame_counts": (10, 10), + "unfragmented_bout_counts": (2, 2), + } + }, + "video2.avi": { + 0: { + "fragmented_frame_counts": (30, 30), + "fragmented_bout_counts": (6, 6), + "unfragmented_frame_counts": (30, 30), + "unfragmented_bout_counts": (6, 6), + } + }, + "video3.avi": { + 0: { + "fragmented_frame_counts": (30, 30), + "fragmented_bout_counts": (6, 6), + "unfragmented_frame_counts": (30, 30), + "unfragmented_bout_counts": (6, 6), + } + }, + } + # video2.avi and video3 meet threshold, min_groups=2 should pass + assert Classifier.label_threshold_met( + multi_video_counts_below, + min_groups=2, + cv_grouping_strategy=CrossValidationGroupingStrategy.VIDEO, + ) + assert not Classifier.label_threshold_met( + multi_video_counts_below, + min_groups=3, + cv_grouping_strategy=CrossValidationGroupingStrategy.VIDEO, + ) class TestFromTrainingFile: diff --git a/tests/classifier/test_training_report.py b/tests/classifier/test_training_report.py index 4bd01b30..16e39f78 100644 --- a/tests/classifier/test_training_report.py +++ b/tests/classifier/test_training_report.py @@ -11,6 +11,7 @@ generate_markdown_report, save_training_report, ) +from jabs.types import CrossValidationGroupingStrategy @pytest.fixture @@ -23,31 +24,31 @@ def sample_cv_results(): return [ CrossValidationResult( iteration=1, + test_label="video_1.mp4 [0]", accuracy=0.9234, precision_not_behavior=0.9145, precision_behavior=0.9323, recall_not_behavior=0.9456, recall_behavior=0.9012, f1_behavior=0.9163, - test_video="video_1.mp4", - test_identity="0", support_behavior=150, support_not_behavior=200, confusion_matrix=np.array([[180, 20], [15, 135]]), + top_features=[("nose_speed", 0.16), ("ear_angle", 0.14)], ), CrossValidationResult( iteration=2, + test_label="video_2.mp4 [1]", accuracy=0.8912, precision_not_behavior=0.8823, precision_behavior=0.9001, recall_not_behavior=0.9134, recall_behavior=0.8690, f1_behavior=0.8842, - test_video="video_2.mp4", - test_identity="1", support_behavior=140, support_not_behavior=210, confusion_matrix=np.array([[192, 18], [18, 122]]), + top_features=[("nose_speed", 0.15), ("ear_angle", 0.13)], ), ] @@ -80,6 +81,7 @@ def sample_training_data(sample_cv_results): bouts_not_behavior=156, training_time_ms=12345, timestamp=datetime(2026, 1, 3, 14, 30, 45), + cv_grouping_strategy=CrossValidationGroupingStrategy.INDIVIDUAL, ) @@ -90,31 +92,31 @@ def test_create_cv_result(self): """Test creating a CrossValidationResult instance.""" result = CrossValidationResult( iteration=1, + test_label="test.mp4 [0]", accuracy=0.95, precision_not_behavior=0.94, precision_behavior=0.96, recall_not_behavior=0.97, recall_behavior=0.93, f1_behavior=0.945, - test_video="test.mp4", - test_identity="0", support_behavior=100, support_not_behavior=150, confusion_matrix=np.array([[140, 10], [7, 93]]), + top_features=[("feature1", 0.5), ("feature2", 0.3)], ) assert result.iteration == 1 + assert result.test_label == "test.mp4 [0]" assert result.accuracy == 0.95 assert result.precision_not_behavior == 0.94 assert result.precision_behavior == 0.96 assert result.recall_not_behavior == 0.97 assert result.recall_behavior == 0.93 assert result.f1_behavior == 0.945 - assert result.test_video == "test.mp4" - assert result.test_identity == "0" assert result.support_behavior == 100 assert result.support_not_behavior == 150 assert result.confusion_matrix.shape == (2, 2) + assert result.top_features == [("feature1", 0.5), ("feature2", 0.3)] class TestTrainingReportData: @@ -138,6 +140,7 @@ def test_create_training_data(self, sample_cv_results): bouts_not_behavior=80, training_time_ms=5000, timestamp=timestamp, + cv_grouping_strategy=CrossValidationGroupingStrategy.VIDEO, ) assert data.behavior_name == "Rearing" @@ -154,6 +157,7 @@ def test_create_training_data(self, sample_cv_results): assert data.bouts_not_behavior == 80 assert data.training_time_ms == 5000 assert data.timestamp == timestamp + assert data.cv_grouping_strategy == CrossValidationGroupingStrategy.VIDEO class TestGenerateMarkdownReport: @@ -217,11 +221,10 @@ def test_report_contains_cv_table(self, sample_training_data): assert "Recall (Not Behavior)" in report assert "Recall (Behavior)" in report assert "F1 Score" in report - assert "Test Group (Video [Identity])" in report + assert "Test Group" in report - # Check for data rows (note: underscores are escaped in markdown) - assert "video\\_1.mp4 [0]" in report - assert "video\\_2.mp4 [1]" in report + assert "video\\_1.mp4 \\[0\\]" in report + assert "video\\_2.mp4 \\[1\\]" in report assert "0.9234" in report # accuracy from iteration 1 def test_report_contains_feature_importance(self, sample_training_data): @@ -253,6 +256,7 @@ def test_report_without_cv_results(self, sample_training_data): bouts_not_behavior=20, training_time_ms=1000, timestamp=datetime.now(), + cv_grouping_strategy=CrossValidationGroupingStrategy.INDIVIDUAL, ) report = generate_markdown_report(data_no_cv) @@ -265,13 +269,10 @@ def test_report_without_cv_results(self, sample_training_data): def test_markdown_escaping_in_video_names(self, sample_training_data): """Test that special characters in video names are escaped.""" - # Modify CV results to have video name with underscores - sample_training_data.cv_results[0].test_video = "test_video_with_underscores.mp4" - + sample_training_data.cv_results[0].test_label = "test_video_with_underscores.mp4 [0]" report = generate_markdown_report(sample_training_data) - - # Underscores should be escaped - assert "test\\_video\\_with\\_underscores.mp4" in report + # Tabulate does not preserve markdown escapes, so check for escaped string + assert "test\\_video\\_with\\_underscores.mp4 \\[0\\]" in report class TestSaveTrainingReport: