Skip to content

Commit 0209d0a

Browse files
authored
Merge pull request #285 from KumarLabJax/task/add-mobilenet-unet-model
Add implementation for mobilnet unet model to jabs-vision
2 parents f972660 + a0fb093 commit 0209d0a

File tree

30 files changed

+1307
-18
lines changed

30 files changed

+1307
-18
lines changed

.github/workflows/_run-tests-action.yml

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,18 @@ jobs:
3737
# Run tests - discovers all test directories automatically
3838
- name: Test with pytest
3939
run: |
40-
# Collect all test directories
41-
TEST_DIRS=""
42-
43-
# Root tests directory
40+
set -euo pipefail
41+
42+
# Root tests
4443
if [ -d "tests" ]; then
45-
TEST_DIRS="tests"
44+
echo "Running root tests"
45+
uv run pytest tests
4646
fi
47-
48-
# Package-level test directories
49-
for pkg_tests in packages/*/tests; do
50-
if [ -d "$pkg_tests" ]; then
51-
TEST_DIRS="$TEST_DIRS $pkg_tests"
47+
48+
# Package tests: run each with that package as a root on sys.path
49+
for pkg in packages/*; do
50+
if [ -d "$pkg/tests" ]; then
51+
echo "Running tests in $pkg/tests"
52+
uv run pytest "$pkg/tests"
5253
fi
5354
done
54-
55-
if [ -n "$TEST_DIRS" ]; then
56-
echo "Running tests in: $TEST_DIRS"
57-
uv run pytest $TEST_DIRS
58-
else
59-
echo "No test directories found"
60-
exit 1
61-
fi
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
title: JABS Vision Backbones
3+
---
4+
5+
::: jabs.vision.backbones
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# JABS Vision (`jabs-vision`)
2+
3+
This package handles raw video processing and deep learning training and inference.
4+
5+
## Overview
6+
7+
`jabs-vision` is responsible for converting raw video frames into pose estimation data.
8+
It houses the heavy machine learning frameworks and GPU-accelerated code.
9+
10+
## Responsibilities
11+
12+
- **Pose Estimation Inference**: Running deep learning models (PyTorch) on video frames
13+
to detect keypoints.
14+
- **Static Object Detection**
15+
- **Segmentation Masking**
16+
- **Identity Matching**: Tracking individual animals across frames using vectorized
17+
features.
18+
- **Video Processing**: Handling frame extraction and normalization for input to vision
19+
models.
20+
21+
## Key Components
22+
23+
- `jabs.vision.inference`: Wrappers for pose estimation model execution.
24+
- `jabs.vision.tracking`: Identity matching and re-identification logic.
25+
26+
## Dependencies
27+
28+
- `torch` / `torchvision`
29+
- `opencv-python-headless`
30+
- `jabs-io`, `jabs-core`
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
title: JABS Vision Modules
3+
---
4+
5+
::: jabs.vision.modules
6+
options:
7+
show_submodules: true
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""Backbone modules for jabs-vision models."""
2+
3+
from .timm import TimmBackbone, TimmBackboneConfig
4+
5+
__all__ = [
6+
"TimmBackbone",
7+
"TimmBackboneConfig",
8+
]
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Timm backbone wrapper for feature extraction."""
2+
3+
from dataclasses import dataclass, field
4+
5+
import torch.nn as nn
6+
from torch import Tensor
7+
8+
try:
9+
import timm
10+
except ImportError:
11+
timm = None # type: ignore[assignment]
12+
13+
14+
@dataclass
class TimmBackboneConfig:
    """Configuration for TimmBackbone.

    Attributes:
        name: Name of the timm model to use.
        pretrained: Whether to load pretrained weights.
        out_indices: Which feature stages to return (0=stem, 1-4=stages).
    """

    name: str = "mobilenetv3_large_100"
    pretrained: bool = True
    # A tuple is immutable, so it can be used directly as the default value;
    # a default_factory is only required for mutable defaults such as lists.
    out_indices: tuple[int, ...] = (0, 1, 2, 3, 4)
27+
28+
29+
class TimmBackbone(nn.Module):
    """Wrapper around timm models for multi-scale feature extraction.

    This module creates a feature extractor from any timm model that supports
    `features_only=True`, returning feature maps at multiple scales.

    Example:
        ``` py
        import torch

        cfg = TimmBackboneConfig(name="mobilenetv3_small_100", pretrained=True)
        backbone = TimmBackbone(cfg)
        x = torch.randn(1, 3, 256, 256)
        features = backbone(x)  # List of feature tensors
        print([f.shape for f in features])
        print(backbone.channels, backbone.strides)
        ```
    """

    def __init__(self, cfg: TimmBackboneConfig) -> None:
        """Build the timm feature extractor described by ``cfg``.

        Args:
            cfg: Backbone configuration (model name, pretrained flag and the
                feature stages to return).

        Raises:
            ImportError: If the optional ``timm`` dependency is not installed.
        """
        super().__init__()

        if timm is None:
            raise ImportError(
                "timm is required for TimmBackbone. "
                "Install it with: pip install 'jabs-vision[timm]' or pip install timm"
            )

        self.cfg = cfg

        # Create feature extractor
        self.model = timm.create_model(
            cfg.name,
            pretrained=cfg.pretrained,
            features_only=True,
            out_indices=cfg.out_indices,
        )

        # Use timm's FeatureInfo accessors rather than iterating the object:
        # channels()/reduction() report one entry per requested out_index, so
        # the metadata stays aligned with the tensors returned by forward()
        # even when cfg.out_indices selects only a subset of the stages.
        feature_info = self.model.feature_info
        self._channels: list[int] = list(feature_info.channels())
        self._strides: list[int] = list(feature_info.reduction())

    @property
    def channels(self) -> list[int]:
        """Number of channels at each feature level."""
        return self._channels

    @property
    def strides(self) -> list[int]:
        """Spatial reduction (stride) at each feature level."""
        return self._strides

    def forward(self, x: Tensor) -> list[Tensor]:
        """Extract multi-scale features.

        Args:
            x: Input tensor of shape (B, C, H, W).

        Returns:
            List of feature tensors, one per output index.
        """
        return self.model(x)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Core components for jabs-vision."""
2+
3+
from .interfaces import (
4+
BaseVisionModel,
5+
OutputKeys,
6+
)
7+
from .registry import MODEL_REGISTRY, ModelRegistry
8+
9+
__all__ = [
10+
"MODEL_REGISTRY",
11+
"BaseVisionModel",
12+
"ModelRegistry",
13+
"OutputKeys",
14+
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Core interfaces and protocols for jabs-vision models."""
2+
3+
from abc import ABC, abstractmethod
4+
5+
import torch
6+
from torch import Tensor
7+
8+
# =============================================================================
9+
# Abstract Base Classes (for implementation guidance)
10+
# =============================================================================
11+
12+
13+
class BaseVisionModel(ABC, torch.nn.Module):
    """Abstract ``torch.nn.Module`` that every jabs-vision model derives from.

    Subclasses must override :meth:`forward`; the class cannot be
    instantiated until they do.
    """

    @abstractmethod
    def forward(self, x: Tensor) -> dict[str, Tensor]:
        """Run the model on ``x`` and return its outputs keyed by name."""
23+
24+
25+
# =============================================================================
26+
# Output Key Conventions
27+
# =============================================================================
28+
29+
30+
class OutputKeys:
    """Canonical dictionary keys used in model output dicts.

    Referencing these attributes instead of raw string literals keeps the
    key names in one place and guards against misspellings.
    """

    # --- Keypoint outputs ---
    HEATMAPS = "heatmaps"
    COORDS = "coords"
    KEYPOINT_CONFIDENCE = "keypoint_confidence"
    # Learned confidence (not derived from heatmaps)
    CONFIDENCE_MAPS = "confidence_maps"
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Model registry for jabs-vision.
2+
3+
Provides a centralized registry for model classes with decorator-based registration.
4+
"""
5+
6+
from __future__ import annotations
7+
8+
from collections.abc import Callable
9+
from typing import TypeVar
10+
11+
T = TypeVar("T")
12+
13+
14+
class ModelRegistry:
    """Name-to-class lookup table that is filled in through a decorator.

    Usage:
        @MODEL_REGISTRY.register("my_model")
        class MyModel:
            ...

        # Later:
        model_cls = MODEL_REGISTRY.get("my_model")
    """

    def __init__(self) -> None:
        # Maps each registered name to the class stored under it.
        self._registry: dict[str, type] = {}

    def register(self, name: str) -> Callable[[T], T]:
        """Build a class decorator that stores the class under ``name``.

        Args:
            name: Unique name for the model.

        Returns:
            A decorator that records the class and hands it back unchanged.

        Raises:
            ValueError: Raised by the returned decorator when ``name`` is
                already present in the registry.
        """

        def _register(cls: T) -> T:
            known = self._registry
            if name in known:
                raise ValueError(f"Model '{name}' is already registered.")
            known[name] = cls  # type: ignore[assignment]
            return cls

        return _register

    def get(self, name: str) -> type:
        """Look up the class registered under ``name``.

        Args:
            name: Name of the registered model.

        Returns:
            The registered model class.

        Raises:
            KeyError: If the model name is not found.
        """
        if name in self._registry:
            return self._registry[name]

        raise KeyError(
            f"Model '{name}' not found in registry. "
            f"Registered models: {list(self._registry)}"
        )

    def list_models(self) -> list[str]:
        """Return every registered model name in sorted order.

        Returns:
            Sorted list of registered model names.
        """
        return sorted(self._registry)

    def __contains__(self, name: str) -> bool:
        """Report whether ``name`` has been registered."""
        return name in self._registry


MODEL_REGISTRY = ModelRegistry()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Model definitions for JabsVision."""

0 commit comments

Comments
 (0)