Skip to content

Commit be4227b

Browse files
committed
better separation of models
1 parent f4b5d5a commit be4227b

File tree

5 files changed

+116
-12
lines changed

5 files changed

+116
-12
lines changed

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pathlib import Path
1717

1818
from huggingface_hub import hf_hub_download
19+
from ruamel.yaml import YAML
1920
from ruamel.yaml.comments import CommentedBase
2021

2122
# just expand this list when adding new models:
@@ -40,20 +41,64 @@ def _get_dlclibrary_path():
4041
return os.path.split(importlib.util.find_spec("dlclibrary").origin)[0]
4142

4243

43-
def _load_model_names():
44+
def _load_pytorch_models() -> dict[str, dict[str, dict[str, str]]]:
45+
"""Load URLs and commit hashes for available models."""
46+
urls = Path(_get_dlclibrary_path()) / "dlcmodelzoo" / "modelzoo_urls_pytorch.yaml"
47+
with open(urls) as file:
48+
data = YAML(pure=True).load(file)
49+
50+
return data
51+
52+
53+
def _load_pytorch_dataset_models(dataset: str) -> dict[str, dict[str, str]]:
4454
"""Load URLs and commit hashes for available models."""
45-
from ruamel.yaml import YAML
55+
models = _load_pytorch_models()
56+
if not dataset in models:
57+
raise ValueError(
58+
f"Could not find any models for {dataset}. Models are available for "
59+
f"{list(models.keys())}"
60+
)
4661

62+
return models[dataset]
63+
64+
65+
def _load_model_names():
66+
"""Load URLs and commit hashes for available models."""
4767
fn = os.path.join(_get_dlclibrary_path(), "dlcmodelzoo", "modelzoo_urls.yaml")
4868
with open(fn) as file:
49-
return YAML().load(file)
69+
model_names = YAML().load(file)
70+
71+
# add PyTorch models
72+
for dataset, model_types in _load_pytorch_models().items():
73+
for model_type, models in model_types.items():
74+
for model, url in models.items():
75+
model_names[f"{dataset}_{model}"] = url
76+
77+
return model_names
5078

5179

5280
def parse_available_supermodels():
5381
libpath = _get_dlclibrary_path()
5482
json_path = os.path.join(libpath, "dlcmodelzoo", "superanimal_models.json")
5583
with open(json_path) as file:
56-
return json.load(file)
84+
super_animal_models = json.load(file)
85+
return super_animal_models
86+
87+
88+
def get_available_detectors(dataset: str) -> list[str]:
89+
"""
90+
Returns:
91+
The detectors available for the dataset.
92+
"""
93+
return list(_load_pytorch_dataset_models(dataset)["detectors"].keys())
94+
95+
96+
def get_available_models(dataset: str) -> list[str]:
97+
"""
98+
Returns:
99+
The pose models available for the dataset.
100+
"""
101+
return list(_load_pytorch_dataset_models(dataset)["pose_models"].keys())
57102

58103

59104
def _handle_downloaded_file(

dlclibrary/dlcmodelzoo/modelzoo_urls.yaml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,6 @@ superanimal_quadruped_hrnetw32:
3636
- mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/pose_model.pth
3737
- mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/detector.pt
3838

39-
# Updated format - split detectors and pose models
40-
superanimal_bird_resnet_50: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_resnet_50.pt
41-
superanimal_bird_ssdlite: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_ssdlite.pt
42-
superanimal_topviewmouse_hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/pose_model.pth
43-
superanimal_topviewmouse_fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/detector.pt
44-
superanimal_quadruped_hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/pose_model.pth
45-
superanimal_quadruped_fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/detector.pt
46-
4739
# DeepLabCut 2.X backwards compatibility
4840
superanimal_topviewmouse: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/DLC_ma_supertopview5k_resnet_50_iteration-0_shuffle-1.tar.gz
4941

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# DeepLabCut 3.0: SuperAnimal detectors and pose model URLS
2+
3+
superanimal_bird:
4+
detectors:
5+
ssdlite: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_ssdlite.pt
6+
pose_models:
7+
resnet_50: DeepLabCut/DeepLabCutModelZoo-SuperAnimal-Bird/superanimal_bird_resnet_50.pt
8+
9+
superanimal_topviewmouse:
10+
detectors:
11+
fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/detector.pt
12+
pose_models:
13+
hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse/pose_model.pth
14+
15+
superanimal_quadruped:
16+
detectors:
17+
fasterrcnn_resnet50_fpn_v2: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/detector.pt
18+
pose_models:
19+
hrnet_w32: mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped/pose_model.pth

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"dlclibrary",
3333
[
3434
"dlclibrary/dlcmodelzoo/modelzoo_urls.yaml",
35+
"dlclibrary/dlcmodelzoo/modelzoo_urls_pytorch.yaml",
3536
"dlclibrary/dlcmodelzoo/superanimal_models.json",
3637
"dlclibrary/dlcmodelzoo/superanimal_configs/superquadruped.yaml",
3738
"dlclibrary/dlcmodelzoo/superanimal_configs/supertopview.yaml",

tests/test_pytorch_models.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#
2+
# DeepLabCut Toolbox (deeplabcut.org)
3+
# © A. & M.W. Mathis Labs
4+
# https://github.com/DeepLabCut/DeepLabCut
5+
#
6+
# Please see AUTHORS for contributors.
7+
# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
8+
#
9+
# Licensed under GNU Lesser General Public License v3.0
10+
#
11+
import os
12+
import pytest
13+
14+
import dlclibrary
15+
import dlclibrary.dlcmodelzoo.modelzoo_download as modelzoo
16+
17+
18+
@pytest.mark.parametrize(
19+
"data",
20+
[
21+
("superanimal_bird", ["ssdlite"]),
22+
("superanimal_topviewmouse", ["fasterrcnn_resnet50_fpn_v2"]),
23+
("superanimal_quadruped", ["fasterrcnn_resnet50_fpn_v2"]),
24+
]
25+
)
26+
def test_get_super_animal_detectors(data: tuple[str, list[str]]):
27+
dataset, expected_detectors = data
28+
detectors = modelzoo.get_available_detectors(dataset)
29+
assert len(detectors) >= len(expected_detectors)
30+
for det in expected_detectors:
31+
assert det in detectors
32+
33+
34+
@pytest.mark.parametrize(
35+
"data",
36+
[
37+
("superanimal_bird", ["resnet_50"]),
38+
("superanimal_topviewmouse", ["hrnet_w32"]),
39+
("superanimal_quadruped", ["hrnet_w32"]),
40+
]
41+
)
42+
def test_get_super_animal_pose_models(data: tuple[str, list[str]]):
43+
dataset, expected_pose_models = data
44+
pose_models = modelzoo.get_available_models(dataset)
45+
assert len(pose_models) >= len(expected_pose_models)
46+
for pose_model in expected_pose_models:
47+
assert pose_model in pose_models

0 commit comments

Comments
 (0)