88#
99# Licensed under GNU Lesser General Public License v3.0
1010#
11+ from __future__ import annotations
12+
1113import json
1214import os
1315
2022 "mouse_pupil_vclose" ,
2123 "horse_sideview" ,
2224 "full_macaque" ,
23- "superanimal_topviewmouse " ,
24- "superanimal_quadruped " ,
25- "superanimal_quadruped_HRNetw32 " ,
26- "superanimal_quadruped_ATP " ,
25+ "superanimal_topviewmouse_dlcrnet " ,
26+ "superanimal_quadruped_dlcrnet " ,
27+ "superanimal_topviewmouse_hrnetw32 " ,
28+ "superanimal_quadruped_hrnetw32 " ,
2729]
2830
2931
@@ -50,7 +52,9 @@ def parse_available_supermodels():
5052 return json .load (file )
5153
5254
53- def download_huggingface_model (modelname , target_dir = "." , remove_hf_folder = True ):
55+ def download_huggingface_model (
56+ modelname , target_dir = "." , remove_hf_folder = True , rename_mapping : dict | None = None
57+ ):
5458 """
5559 Download a DeepLabCut Model Zoo Project from Hugging Face
5660
@@ -62,41 +66,58 @@ def download_huggingface_model(modelname, target_dir=".", remove_hf_folder=True)
6266 Directory where to store the model weights and pose_cfg.yaml file
6367 remove_hf_folder : bool, default True
6468 Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
69+ rename_mapping : dict, default None
70+ Dictionary to rename the downloaded file. If None, the original filename is used.
6571 """
6672 from huggingface_hub import hf_hub_download
6773 import tarfile
6874 from pathlib import Path
75+ from ruamel .yaml .comments import CommentedBase
6976
7077 neturls = _load_model_names ()
7178 if modelname not in neturls :
7279 raise ValueError (f"`modelname` should be one of: { ', ' .join (modelname )} ." )
7380
7481 print ("Loading...." , modelname )
75- url = neturls [modelname ].split ("/" )
76- repo_id , targzfn = url [0 ] + "/" + url [1 ], str (url [- 1 ])
77-
78- hf_hub_download (repo_id , targzfn , cache_dir = str (target_dir ))
79-
80- # Create a new subfolder as indicated below, unzipping from there and deleting this folder
81- hf_folder = f"models--{ url [0 ]} --{ url [1 ]} "
82- hf_path = os .path .join (
83- hf_folder ,
84- "snapshots" ,
85- str (neturls [modelname + "_commit" ]),
86- targzfn ,
87- )
88-
89- filename = os .path .join (target_dir , hf_path )
90- try :
91- with tarfile .open (filename , mode = "r:gz" ) as tar :
92- for member in tar :
93- if not member .isdir ():
94- fname = Path (member .name ).name
95- tar .makefile (member , os .path .join (target_dir , fname ))
96- except tarfile .ReadError : # The model is a .pt file
97- os .rename (filename , os .path .join (target_dir , targzfn ))
82+ urls = neturls [modelname ]
83+ if isinstance (urls , CommentedBase ):
84+ urls = list (urls )
85+ else :
86+ urls = [urls ]
87+
88+ if not os .path .isabs (target_dir ):
89+ target_dir = os .path .abspath (target_dir )
90+
91+ for url in urls :
92+ url = url .split ("/" )
93+ repo_id , targzfn = url [0 ] + "/" + url [1 ], str (url [- 1 ])
94+
95+ hf_hub_download (repo_id , targzfn , cache_dir = str (target_dir ))
96+
97+ # Create a new subfolder as indicated below, unzipping from there and deleting this folder
98+ hf_folder = f"models--{ url [0 ]} --{ url [1 ]} "
99+ path_ = os .path .join (target_dir , hf_folder , "snapshots" )
100+ commit = os .listdir (path_ )[0 ]
101+ filename = os .path .join (path_ , commit , targzfn )
102+ try :
103+ with tarfile .open (filename , mode = "r:gz" ) as tar :
104+ for member in tar :
105+ if not member .isdir ():
106+ fname = Path (member .name ).name
107+ tar .makefile (member , os .path .join (target_dir , fname ))
108+ except tarfile .ReadError : # The model is a .pt file
109+ if rename_mapping is not None :
110+ targzfn = rename_mapping .get (targzfn , targzfn )
111+ if os .path .islink (filename ):
112+ filename_ = os .readlink (filename )
113+ if not os .path .isabs (filename_ ):
114+ filename_ = os .path .abspath (os .path .join (os .path .dirname (filename ), filename_ ))
115+ filename = filename_
116+ os .rename (filename , os .path .join (target_dir , targzfn ))
98117
99118 if remove_hf_folder :
100119 import shutil
101120
102121 shutil .rmtree (os .path .join (target_dir , hf_folder ))
122+
123+ '../../blobs/6c9c66d48f25cac9f8adaea7a485b07f4bd781ba656785bc4e077d9064e8e5df'
0 commit comments