5252logger = Logger ()
5353
5454
55- def _download_file_from_hf_if_necessary (
56- file_path : str , local_dir : str , repo_id : str
57- ) -> bool :
58- if not os .path .exists (file_path ):
55+ def _download_file_from_hf_if_necessary (local_dir : str , repo_id : str ) -> bool :
56+ weights_path = os . path . join ( local_dir , MODEL_WEIGHTS_FILE_IN_SAFETENSORS )
57+ config_path = os . path . join ( local_dir , MODEL_CONFIG_FILE_IN_JSON )
58+ if not os .path .exists (weights_path ):
5959 logger .info (
60- f"Model file not found at { file_path } , downloading from HuggingFace..."
60+ f"Model weights file not found at { weights_path } , downloading from HuggingFace..."
6161 )
6262 try :
6363 hf_hub_download (
6464 repo_id = repo_id ,
6565 filename = MODEL_WEIGHTS_FILE_IN_SAFETENSORS ,
6666 local_dir = local_dir ,
6767 )
68- logger .info (f"Got file to { file_path } " )
68+ logger .info (f"Got file to { weights_path } " )
6969 except Exception as e :
70- logger .error (f"Failed to download model file to { local_dir } due to { e } " )
70+ logger .error (
71+ f"Failed to download model weights file to { local_dir } due to { e } "
72+ )
73+ return False
74+ if not os .path .exists (config_path ):
75+ logger .info (
76+ f"Model config file not found at { config_path } , downloading from HuggingFace..."
77+ )
78+ try :
79+ hf_hub_download (
80+ repo_id = repo_id ,
81+ filename = MODEL_CONFIG_FILE_IN_JSON ,
82+ local_dir = local_dir ,
83+ )
84+ logger .info (f"Got file to { config_path } " )
85+ except Exception as e :
86+ logger .error (
87+ f"Failed to download model config file to { local_dir } due to { e } "
88+ )
7189 return False
7290 return True
7391
@@ -82,11 +100,7 @@ def download_built_in_ltsm_from_hf_if_necessary(
82100 bool: True if the model is existed or downloaded successfully, False otherwise.
83101 """
84102 repo_id = TIMER_REPO_ID [model_type ]
85- weights_path = os .path .join (local_dir , MODEL_WEIGHTS_FILE_IN_SAFETENSORS )
86- if not _download_file_from_hf_if_necessary (weights_path , local_dir , repo_id ):
87- return False
88- config_path = os .path .join (local_dir , MODEL_CONFIG_FILE_IN_JSON )
89- if not _download_file_from_hf_if_necessary (config_path , local_dir , repo_id ):
103+ if not _download_file_from_hf_if_necessary (local_dir , repo_id ):
90104 return False
91105 return True
92106
0 commit comments