1616from pathlib import Path
1717
1818from huggingface_hub import hf_hub_download
19+ from ruamel .yaml import YAML
1920from ruamel .yaml .comments import CommentedBase
2021
2122# just expand this list when adding new models:
2728 "mouse_pupil_vclose" ,
2829 "horse_sideview" ,
2930 "full_macaque" ,
30- "superanimal_topviewmouse_dlcrnet" ,
31- "superanimal_quadruped_dlcrnet" ,
32- "superanimal_topviewmouse_hrnetw32" ,
33- "superanimal_quadruped_hrnetw32" ,
34- "superanimal_topviewmouse" , # DeepLabCut 2.X backwards compatibility
35- "superanimal_quadruped" , # DeepLabCut 2.X backwards compatibility
31+ "superanimal_bird" ,
32+ "superanimal_quadruped" ,
33+ "superanimal_topviewmouse" ,
3634]
3735
3836
@@ -43,20 +41,66 @@ def _get_dlclibrary_path():
4341 return os .path .split (importlib .util .find_spec ("dlclibrary" ).origin )[0 ]
4442
4543
46- def _load_model_names ():
def _load_pytorch_models() -> dict[str, dict[str, dict[str, str]]]:
    """Load URLs and commit hashes for available PyTorch models.

    Returns:
        A nested mapping parsed from ``modelzoo_urls_pytorch.yaml``;
        presumably dataset -> model category -> model name -> URL
        (the "detectors"/"pose_models" categories are what callers
        below index into — confirm against the YAML file itself).
    """
    yaml_path = (
        Path(_get_dlclibrary_path()) / "dlcmodelzoo" / "modelzoo_urls_pytorch.yaml"
    )
    with open(yaml_path) as stream:
        return YAML(pure=True).load(stream)
51+
52+
def _load_pytorch_dataset_models(dataset: str) -> dict[str, dict[str, str]]:
    """Load URLs and commit hashes for PyTorch models trained on a dataset.

    Args:
        dataset: Name of the dataset for which to load models.

    Returns:
        The model mapping for ``dataset`` (model category -> model name -> URL).

    Raises:
        ValueError: If no models are available for ``dataset``.
    """
    models = _load_pytorch_models()
    # PEP 8 / E713: `x not in y` is the idiomatic form of `not x in y`.
    if dataset not in models:
        raise ValueError(
            f"Could not find any models for {dataset}. Models are available for "
            f"{list(models.keys())}"
        )

    return models[dataset]
63+
64+
def _load_model_names():
    """Load URLs and commit hashes for all available models.

    Reads the TensorFlow model-zoo entries from ``modelzoo_urls.yaml`` and
    merges in the PyTorch entries, flattening the nested PyTorch mapping to
    ``"<dataset>_<model>"`` keys so both live in a single flat namespace.

    Returns:
        A mapping from model name to its download URL/entry.
    """
    # Use Path for consistency with _load_pytorch_models.
    fn = Path(_get_dlclibrary_path()) / "dlcmodelzoo" / "modelzoo_urls.yaml"
    with open(fn) as file:
        model_names = YAML().load(file)

    # Add PyTorch models, flattened to "<dataset>_<model>" keys.
    # The model-category level ("detectors", "pose_models", ...) is not
    # part of the flat key, so only the per-category dicts are needed.
    for dataset, model_types in _load_pytorch_models().items():
        for models in model_types.values():
            for model, url in models.items():
                model_names[f"{dataset}_{model}"] = url

    return model_names
5378
5479
def parse_available_supermodels():
    """Return the SuperAnimal models listed in ``superanimal_models.json``."""
    json_path = os.path.join(
        _get_dlclibrary_path(), "dlcmodelzoo", "superanimal_models.json"
    )
    with open(json_path) as file:
        return json.load(file)
86+
87+
def get_available_detectors(dataset: str) -> list[str]:
    """Only for PyTorch models.

    Args:
        dataset: The dataset for which to list available detectors.

    Returns:
        The detectors available for the dataset.
    """
    dataset_models = _load_pytorch_dataset_models(dataset)
    return [*dataset_models["detectors"]]
95+
96+
def get_available_models(dataset: str) -> list[str]:
    """Only for PyTorch models.

    Args:
        dataset: The dataset for which to list available pose models.

    Returns:
        The pose models available for the dataset.
    """
    dataset_models = _load_pytorch_dataset_models(dataset)
    return [*dataset_models["pose_models"]]
60104
61105
62106def _handle_downloaded_file (
@@ -103,7 +147,9 @@ def download_huggingface_model(
103147 """
104148 net_urls = _load_model_names ()
105149 if model_name not in net_urls :
106- raise ValueError (f"`modelname` should be one of: { ', ' .join (net_urls )} ." )
150+ raise ValueError (
151+ f"`modelname={ model_name } ` should be one of: { ', ' .join (net_urls )} ."
152+ )
107153
108154 print ("Loading...." , model_name )
109155 urls = net_urls [model_name ]
0 commit comments