Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 41 additions & 41 deletions src/jabs/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
confusion_matrix,
precision_recall_fscore_support,
)
from sklearn.model_selection import LeaveOneGroupOut, train_test_split
from sklearn.model_selection import LeaveOneGroupOut

from jabs.project import Project, TrackLabels, load_training_data
from jabs.types import ClassifierType
from jabs.types import (
DEFAULT_CV_GROUPING_STRATEGY,
ClassifierType,
CrossValidationGroupingStrategy,
)
from jabs.utils import hash_file

_VERSION = 10
Expand Down Expand Up @@ -200,38 +204,6 @@ def feature_names(self) -> list[str] | None:
"""returns the list of feature names used when training this classifier"""
return self._feature_names

@staticmethod
def train_test_split(per_frame_features, window_features, label_data):
    """Divide the labeled features and their labels into training and test sets.

    The per-frame and window features are joined column-wise into one table,
    which is then split together with the labels by the ``train_test_split``
    function imported at module level.

    Args:
        per_frame_features: per frame features as returned from IdentityFeatures
            object, filtered to only include labeled frames
        window_features: window features as returned from IdentityFeatures
            object, filtered to only include labeled frames
        label_data: labels that correspond to the features

    Returns:
        dictionary holding both halves of the split plus the column names of
        the combined feature table:

        {
            'training_data': training portion of the combined features,
            'training_labels': labels for the training portion,
            'test_data': test portion of the combined features,
            'test_labels': labels for the test portion,
            'feature_names': list of feature names
        }
    """
    # one table holding every feature column, per-frame columns first
    combined = pd.concat([per_frame_features, window_features], axis=1)
    train_x, test_x, train_y, test_y = train_test_split(combined, label_data)

    split = {
        "training_data": train_x,
        "training_labels": train_y,
        "test_data": test_x,
        "test_labels": test_y,
    }
    split["feature_names"] = combined.columns.to_list()
    return split

@staticmethod
def get_leave_one_group_out_max(labels, groups):
"""counts the number of possible leave one out groups for k-fold cross validation
Expand Down Expand Up @@ -674,7 +646,10 @@ def print_feature_importance(self, limit=20):
print(f"{feature:100} {importance:0.2f}")

@staticmethod
def count_label_threshold(all_counts: dict):
def count_label_threshold(
all_counts: dict,
cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY,
) -> int:
"""counts the number of groups that meet label threshold criteria

Args:
Expand Down Expand Up @@ -705,23 +680,45 @@ def count_label_threshold(all_counts: dict):
}
}

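cv_grouping_strategy: how groups are formed. INDIVIDUAL treats each identity in each
video as a group; VIDEO treats each video as a group, summing the label counts
of all identities in that video before applying the threshold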

Returns:
number of groups that meet label criteria

Note: uses "fragmented" label counts, since these reflect the counts of labels that are usable for training
"""
group_count = 0
for video in all_counts:
for identity_count in all_counts[video].values():
if cv_grouping_strategy == CrossValidationGroupingStrategy.INDIVIDUAL:
for video in all_counts:
for identity_count in all_counts[video].values():
if (
identity_count["fragmented_frame_counts"][0] >= Classifier.LABEL_THRESHOLD
and identity_count["fragmented_frame_counts"][1]
>= Classifier.LABEL_THRESHOLD
):
group_count += 1
elif cv_grouping_strategy == CrossValidationGroupingStrategy.VIDEO:
for video in all_counts:
behavior_sum = 0
not_behavior_sum = 0
for identity_count in all_counts[video].values():
behavior_sum += identity_count["fragmented_frame_counts"][0]
not_behavior_sum += identity_count["fragmented_frame_counts"][1]
if (
identity_count["fragmented_frame_counts"][0] >= Classifier.LABEL_THRESHOLD
and identity_count["fragmented_frame_counts"][1] >= Classifier.LABEL_THRESHOLD
behavior_sum >= Classifier.LABEL_THRESHOLD
and not_behavior_sum >= Classifier.LABEL_THRESHOLD
):
group_count += 1
else:
raise ValueError(f"Unknown cv_grouping_strategy: {cv_grouping_strategy}")
return group_count

@staticmethod
def label_threshold_met(all_counts: dict, min_groups: int):
def label_threshold_met(
all_counts: dict,
min_groups: int,
cv_grouping_strategy: CrossValidationGroupingStrategy = DEFAULT_CV_GROUPING_STRATEGY,
) -> bool:
"""determine if the labeling threshold is met

Args:
Expand All @@ -730,11 +727,14 @@ def label_threshold_met(all_counts: dict, min_groups: int):
min_groups: minimum number of groups required (more than one
group is always required for the "leave one group out" train/test split,
but may be more than 2 for k-fold cross validation if k > 2)
cv_grouping_strategy: cross-validation grouping strategy

Returns:
True if more than one group meets the label threshold and the number of such
groups is at least min_groups, otherwise False
"""
group_count = Classifier.count_label_threshold(all_counts)
group_count = Classifier.count_label_threshold(
all_counts, cv_grouping_strategy=cv_grouping_strategy
)
return 1 < group_count >= min_groups

@staticmethod
Expand Down
17 changes: 10 additions & 7 deletions src/jabs/classifier/training_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
import numpy as np
from tabulate import tabulate

from jabs.types import CrossValidationGroupingStrategy


@dataclass
class CrossValidationResult:
"""Results from a single cross-validation iteration.

Attributes:
iteration: The iteration number (1-indexed)
test_video: Name of the video used for testing
test_identity: Identity label used for testing
test_label: Label of the test grouping (e.g., video filename and possibly identity index)
accuracy: Classification accuracy (0.0 to 1.0)
precision_behavior: Precision for behavior class
precision_not_behavior: Precision for not-behavior class
Expand All @@ -29,8 +30,7 @@ class CrossValidationResult:
"""

iteration: int
test_video: str
test_identity: str
test_label: str
accuracy: float
precision_behavior: float
precision_not_behavior: float
Expand Down Expand Up @@ -62,6 +62,7 @@ class TrainingReportData:
bouts_not_behavior: Total number of not-behavior bouts labeled
training_time_ms: Total training time in milliseconds
timestamp: Datetime when training was completed
cv_grouping_strategy: Strategy used for cross-validation grouping
"""

behavior_name: str
Expand All @@ -78,6 +79,7 @@ class TrainingReportData:
bouts_not_behavior: int
training_time_ms: int
timestamp: datetime
cv_grouping_strategy: CrossValidationGroupingStrategy


def _escape_markdown(text: str) -> str:
Expand Down Expand Up @@ -156,12 +158,13 @@ def generate_markdown_report(data: TrainingReportData) -> str:

# Detailed results table
lines.append("### Iteration Details")
lines.append(f"CV Grouping Strategy: {data.cv_grouping_strategy.value}")
lines.append("")

table_data = []
for result in data.cv_results:
# Escape markdown special characters in the test group label
escaped_video = _escape_markdown(result.test_video)
escaped_video = _escape_markdown(result.test_label)

table_data.append(
[
Expand All @@ -172,7 +175,7 @@ def generate_markdown_report(data: TrainingReportData) -> str:
f"{result.recall_not_behavior:.4f}",
f"{result.recall_behavior:.4f}",
f"{result.f1_behavior:.4f}",
f"{escaped_video} [{result.test_identity}]",
f"{escaped_video}",
]
)

Expand All @@ -184,7 +187,7 @@ def generate_markdown_report(data: TrainingReportData) -> str:
"Recall (Not Behavior)",
"Recall (Behavior)",
"F1 Score",
"Test Group (Video [Identity])",
"Test Group",
]

table_markdown = tabulate(table_data, headers=headers, tablefmt="github")
Expand Down
3 changes: 3 additions & 0 deletions src/jabs/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@
# some defaults for compressing hdf5 output
COMPRESSION = "gzip"
COMPRESSION_OPTS_DEFAULT = 6

# settings keys
SETTINGS_CV_GROUPING = "cv_grouping"
50 changes: 36 additions & 14 deletions src/jabs/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import jabs.feature_extraction as fe
from jabs.pose_estimation import PoseEstimation, get_pose_path, open_pose_file
from jabs.types import ProjectDistanceUnit
from jabs.types import CrossValidationGroupingStrategy, ProjectDistanceUnit

from .feature_manager import FeatureManager
from .parallel_workers import FeatureLoadJobSpec, collect_labeled_features
Expand Down Expand Up @@ -590,6 +590,9 @@ def get_labeled_features(
behavior_settings = self._settings_manager.get_behavior(behavior)
videos = list(self._video_manager.videos)

# get the cross validation grouping strategy from project settings
grouping_strategy = self.settings_manager.cv_grouping_strategy

# Early exit if no videos
if not videos:
return {
Expand Down Expand Up @@ -680,19 +683,32 @@ def get_labeled_features(
"groups": np.array([], dtype=np.int32),
}, {}

# Build stable group ids: original video order, then identity order as observed
# Build stable group ids based on grouping strategy
key_to_gid: dict[tuple[str, int], int] = {}
video_to_gid: dict[str, int] = {}
gid = 0
for v in videos:
seen: list[int] = []
for video_name, ident in all_group_keys:
if video_name == v and ident not in seen:
seen.append(ident)
for ident in seen:
key = (v, ident)
if key not in key_to_gid:
key_to_gid[key] = gid
if grouping_strategy == CrossValidationGroupingStrategy.INDIVIDUAL:
for v in videos:
seen: list[int] = []
for video_name, ident in all_group_keys:
if video_name == v and ident not in seen:
seen.append(ident)
for ident in seen:
key = (v, ident)
if key not in key_to_gid:
key_to_gid[key] = gid
gid += 1
elif grouping_strategy == CrossValidationGroupingStrategy.VIDEO:
for v in videos:
if v not in video_to_gid:
video_to_gid[v] = gid
gid += 1
for video_name, ident in all_group_keys:
if video_name == v:
key = (v, ident)
key_to_gid[key] = video_to_gid[v]
else:
raise ValueError(f"Unknown grouping strategy: {grouping_strategy}")

# groups vector aligned with all_per_frame entries
groups_list: list[np.ndarray] = [
Expand All @@ -701,9 +717,15 @@ def get_labeled_features(
]
groups = np.concatenate(groups_list) if groups_list else np.array([], dtype=np.int32)

group_mapping: dict[int, dict[str, int | str]] = {
gid: {"video": v, "identity": ident} for (v, ident), gid in key_to_gid.items()
}
# group_mapping: for INDIVIDUAL, maps gid to (video, identity); for VIDEO, maps gid to video only
if grouping_strategy == CrossValidationGroupingStrategy.INDIVIDUAL:
group_mapping: dict[int, dict[str, int | str]] = {
gid: {"video": v, "identity": ident} for (v, ident), gid in key_to_gid.items()
}
else:
group_mapping: dict[int, dict[str, str | None]] = {
gid: {"video": v, "identity": None} for v, gid in video_to_gid.items()
}

window_df = pd.concat(all_window, join="inner")
per_frame_df = pd.concat(all_per_frame, join="inner")
Expand Down
17 changes: 17 additions & 0 deletions src/jabs/project/settings_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import typing

import jabs.feature_extraction as feature_extraction
from jabs.constants import SETTINGS_CV_GROUPING
from jabs.types.cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy
from jabs.version import version_str

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -83,6 +85,21 @@ def project_metadata(self) -> dict:
"""
return self._project_info.get("metadata", {})

@property
def cv_grouping_strategy(self) -> CrossValidationGroupingStrategy:
    """Cross-validation grouping strategy configured for the project.

    Reads the ``SETTINGS_CV_GROUPING`` entry from the "settings" section of
    the project info. When the entry is absent, or its stored value does not
    match any ``CrossValidationGroupingStrategy`` member, the project default
    is used instead.

    Returns:
        CrossValidationGroupingStrategy: The CV grouping strategy.
    """
    project_settings = self._project_info.get("settings", {})
    stored_value = project_settings.get(
        SETTINGS_CV_GROUPING, DEFAULT_CV_GROUPING_STRATEGY.value
    )

    try:
        strategy = CrossValidationGroupingStrategy(stored_value)
    except ValueError:
        # unrecognized value in the project file: fall back rather than fail
        strategy = DEFAULT_CV_GROUPING_STRATEGY
    return strategy

def video_metadata(self, video: str) -> dict:
"""Get metadata for a specific video.

Expand Down
18 changes: 18 additions & 0 deletions src/jabs/resources/docs/user_guide/gui.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,14 @@ Clicking the Brightness or Contrast controls will reset the brightness or contra
## Menu

- **JABS→About:** Display About Dialog
- **JABS→Project Settings:** Display Project Settings Dialog
- **JABS→User Guide:** Display User Guide
- **JABS→Check for Updates:** Check PyPI for JABS updates
- **JABS→View License Agreement:** Display License Agreement
- **JABS→Enable Session Tracking:** Enable labeling session tracking
- **JABS→Quit JABS:** Quit Program
- **File→Open Project:** Select a project directory to open. If a project is already opened, it will be closed and the newly selected project will be opened.
- **File→Open Recent:** Submenu to open recently opened projects.
- **File→Export Training Data:** Create a file with the information needed to share a classifier. This exported file is written to the project directory and has the form `<Behavior_Name>_training_<YYYYMMDD_hhmmss>.h5`. This file is used as one input for the `jabs-classify` script.
- **File→Archive Behavior:** Remove behavior and its labels from project. Labels are archived in the `jabs/archive` directory.
- **File→Prune Project:** Remove videos and pose files that are not labeled.
Expand All @@ -120,6 +125,19 @@ Clicking the Brightness or Contrast controls will reset the brightness or contra
- **Window→Bring All to Front:** Bring all JABS windows to the front
- The Window menu also displays a list of all open JABS windows (main window, user guide, training reports, etc.) with a checkmark (✓) next to the currently active window. Click any window in the list to activate and bring it to the front.

## Project Settings Dialog

The **Project Settings Dialog**, available from the JABS menu, allows you to configure project-wide settings. Settings are designed to be easily discoverable; each settings group includes built-in documentation, making it easy to understand the purpose and effect of each option directly within the dialog. Users should explore this dialog and customize project settings as needed.

### Settings Overview

| Setting | Description |
|-------------------------------|------------------------------------------------------------------|
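| Cross Validation Grouping     | Determines how labeled data is split into groups for cross-validation. With "Individual Animal" (default), each labeled animal in each video is its own group; with "Video", all animals in a video form a single group. |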

As new settings are added, they will appear in this dialog with inline documentation.


## Overlays

### Track Overlay Example
Expand Down
8 changes: 8 additions & 0 deletions src/jabs/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
"""Module for defining enums used in JABS"""

from .classifier_types import ClassifierType
from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy
from .units import ProjectDistanceUnit

# Names re-exported as the public API of the jabs.types package.
__all__ = [
"DEFAULT_CV_GROUPING_STRATEGY",
"ClassifierType",
"CrossValidationGroupingStrategy",
"ProjectDistanceUnit",
]
15 changes: 15 additions & 0 deletions src/jabs/types/cv_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from enum import Enum


class CrossValidationGroupingStrategy(str, Enum):
    """How labeled data is grouped for cross-validation in a project.

    The ``str`` mixin makes every member a real string equal to its value, so
    a member serializes to JSON as that value without a custom encoder, and a
    stored value can be turned back into a member by calling the class on it.
    """

    # each individual animal forms its own group
    INDIVIDUAL = "Individual Animal"
    # all animals in the same video share one group
    VIDEO = "Video"


# Default cross-validation grouping strategy: group by individual animal.
DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL
Loading