1616from pathlib import Path
1717
1818from huggingface_hub import hf_hub_download
19+ from ruamel .yaml import YAML
1920from ruamel .yaml .comments import CommentedBase
2021
2122# just expand this list when adding new models:
@@ -40,20 +41,64 @@ def _get_dlclibrary_path():
4041 return os .path .split (importlib .util .find_spec ("dlclibrary" ).origin )[0 ]
4142
4243
43- def _load_model_names ():
44+ def _load_pytorch_models () -> dict [str , dict [str , dict [str , str ]]]:
45+ """Load URLs and commit hashes for available models."""
46+ urls = Path (_get_dlclibrary_path ()) / "dlcmodelzoo" / "modelzoo_urls_pytorch.yaml"
47+ with open (urls ) as file :
48+ data = YAML (pure = True ).load (file )
49+
50+ return data
51+
52+
53+ def _load_pytorch_dataset_models (dataset : str ) -> dict [str , dict [str , str ]]:
4454 """Load URLs and commit hashes for available models."""
45- from ruamel .yaml import YAML
55+ models = _load_pytorch_models ()
56+ if not dataset in models :
57+ raise ValueError (
58+ f"Could not find any models for { dataset } . Models are available for "
59+ f"{ list (models .keys ())} "
60+ )
4661
62+ return models [dataset ]
63+
64+
65+ def _load_model_names ():
66+ """Load URLs and commit hashes for available models."""
4767 fn = os .path .join (_get_dlclibrary_path (), "dlcmodelzoo" , "modelzoo_urls.yaml" )
4868 with open (fn ) as file :
49- return YAML ().load (file )
69+ model_names = YAML ().load (file )
70+
71+ # add PyTorch models
72+ for dataset , model_types in _load_pytorch_models ().items ():
73+ for model_type , models in model_types .items ():
74+ for model , url in models .items ():
75+ model_names [f"{ dataset } _{ model } " ] = url
76+
77+ return model_names
5078
5179
5280def parse_available_supermodels ():
5381 libpath = _get_dlclibrary_path ()
5482 json_path = os .path .join (libpath , "dlcmodelzoo" , "superanimal_models.json" )
5583 with open (json_path ) as file :
56- return json .load (file )
84+ super_animal_models = json .load (file )
85+ return super_animal_models
86+
87+
88+ def get_available_detectors (dataset : str ) -> list [str ]:
89+ """
90+ Returns:
91+ The detectors available for the dataset.
92+ """
93+ return list (_load_pytorch_dataset_models (dataset )["detectors" ].keys ())
94+
95+
96+ def get_available_models (dataset : str ) -> list [str ]:
97+ """
98+ Returns:
99+ The pose models available for the dataset.
100+ """
101+ return list (_load_pytorch_dataset_models (dataset )["pose_models" ].keys ())
57102
58103
59104def _handle_downloaded_file (
0 commit comments