1212
1313import json
1414import os
15+ import tarfile
16+ from pathlib import Path
17+
18+ from huggingface_hub import hf_hub_download
19+ from ruamel .yaml .comments import CommentedBase
1520
1621# just expand this list when adding new models:
1722MODELOPTIONS = [
@@ -52,34 +57,54 @@ def parse_available_supermodels():
5257 return json .load (file )
5358
5459
60+ def _handle_downloaded_file (
61+ file_path : str , target_dir : str , rename_mapping : dict | None = None
62+ ):
63+ """Handle the downloaded file from HuggingFace"""
64+ file_name = os .path .basename (file_path )
65+ try :
66+ with tarfile .open (file_path , mode = "r:gz" ) as tar :
67+ for member in tar :
68+ if not member .isdir ():
69+ fname = Path (member .name ).name
70+ tar .makefile (member , os .path .join (target_dir , fname ))
71+ except tarfile .ReadError : # The model is a .pt file
72+ if rename_mapping is not None :
73+ file_name = rename_mapping .get (file_name , file_name )
74+ if os .path .islink (file_path ):
75+ file_path_ = os .readlink (file_path )
76+ if not os .path .isabs (file_path_ ):
77+ file_path_ = os .path .abspath (
78+ os .path .join (os .path .dirname (file_path ), file_path_ )
79+ )
80+ file_path = file_path_
81+ os .rename (file_path , os .path .join (target_dir , file_name ))
82+
83+
5584def download_huggingface_model (
56- modelname , target_dir = "." , remove_hf_folder = True , rename_mapping : dict | None = None
85+ model_name : str ,
86+ target_dir : str = "." ,
87+ remove_hf_folder : bool = True ,
88+ rename_mapping : dict | None = None ,
5789):
5890 """
59- Download a DeepLabCut Model Zoo Project from Hugging Face
60-
61- Parameters
62- ----------
63- modelname : string
64- Name of the ModelZoo model. For visualizations see: http://www.mackenziemathislab.org/dlc-modelzoo
65- target_dir : directory (as string)
66- Directory where to store the model weights and pose_cfg.yaml file
67- remove_hf_folder : bool, default True
68- Whether to remove the directory structure provided by HuggingFace after downloading and decompressing data into DeepLabCut format.
69- rename_mapping : dict, default None
70- Dictionary to rename the downloaded file. If None, the original filename is used.
91+ Downloads a DeepLabCut Model Zoo Project from Hugging Face.
92+
93+ Args:
94+ model_name (str): Name of the ModelZoo model.
95+ For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
96+ target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
97+ remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
98+ after downloading and decompressing the data into DeepLabCut format. Defaults to True.
99+ rename_mapping (dict, optional): A dictionary to rename the downloaded file.
100+ If None, the original filename is used. Defaults to None.
71101 """
72- from huggingface_hub import hf_hub_download
73- import tarfile
74- from pathlib import Path
75- from ruamel .yaml .comments import CommentedBase
76-
77- neturls = _load_model_names ()
78- if modelname not in neturls :
79- raise ValueError (f"`modelname` should be one of: { ', ' .join (modelname )} ." )
102+ net_urls = _load_model_names ()
103+ if model_name not in net_urls :
104+ raise ValueError (f"`modelname` should be one of: { ', ' .join (net_urls )} ." )
80105
81- print ("Loading...." , modelname )
82- urls = neturls [ modelname ]
106+ print ("Loading...." , model_name )
107+ urls = net_urls [ model_name ]
83108 if isinstance (urls , CommentedBase ):
84109 urls = list (urls )
85110 else :
@@ -98,26 +123,10 @@ def download_huggingface_model(
98123 hf_folder = f"models--{ url [0 ]} --{ url [1 ]} "
99124 path_ = os .path .join (target_dir , hf_folder , "snapshots" )
100125 commit = os .listdir (path_ )[0 ]
101- filename = os .path .join (path_ , commit , targzfn )
102- try :
103- with tarfile .open (filename , mode = "r:gz" ) as tar :
104- for member in tar :
105- if not member .isdir ():
106- fname = Path (member .name ).name
107- tar .makefile (member , os .path .join (target_dir , fname ))
108- except tarfile .ReadError : # The model is a .pt file
109- if rename_mapping is not None :
110- targzfn = rename_mapping .get (targzfn , targzfn )
111- if os .path .islink (filename ):
112- filename_ = os .readlink (filename )
113- if not os .path .isabs (filename_ ):
114- filename_ = os .path .abspath (os .path .join (os .path .dirname (filename ), filename_ ))
115- filename = filename_
116- os .rename (filename , os .path .join (target_dir , targzfn ))
126+ file_name = os .path .join (path_ , commit , targzfn )
127+ _handle_downloaded_file (file_name , target_dir , rename_mapping )
117128
118129 if remove_hf_folder :
119130 import shutil
120131
121132 shutil .rmtree (os .path .join (target_dir , hf_folder ))
122-
123- '../../blobs/6c9c66d48f25cac9f8adaea7a485b07f4bd781ba656785bc4e077d9064e8e5df'
0 commit comments