Skip to content

Commit e1cbccf

Browse files
authored
[Bug Fix] Download config.json for Large TimeSeries Model (#16295)
1 parent 8e8c1a6 commit e1cbccf

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

iotdb-core/ainode/ainode/core/model/built_in_model_factory.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,40 @@
5252
logger = Logger()
5353

5454

55-
def _download_file_from_hf_if_necessary(
56-
file_path: str, local_dir: str, repo_id: str
57-
) -> bool:
58-
if not os.path.exists(file_path):
55+
def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool:
56+
weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
57+
config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON)
58+
if not os.path.exists(weights_path):
5959
logger.info(
60-
f"Model file not found at {file_path}, downloading from HuggingFace..."
60+
f"Model weights file not found at {weights_path}, downloading from HuggingFace..."
6161
)
6262
try:
6363
hf_hub_download(
6464
repo_id=repo_id,
6565
filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
6666
local_dir=local_dir,
6767
)
68-
logger.info(f"Got file to {file_path}")
68+
logger.info(f"Got file to {weights_path}")
6969
except Exception as e:
70-
logger.error(f"Failed to download model file to {local_dir} due to {e}")
70+
logger.error(
71+
f"Failed to download model weights file to {local_dir} due to {e}"
72+
)
73+
return False
74+
if not os.path.exists(config_path):
75+
logger.info(
76+
f"Model config file not found at {config_path}, downloading from HuggingFace..."
77+
)
78+
try:
79+
hf_hub_download(
80+
repo_id=repo_id,
81+
filename=MODEL_CONFIG_FILE_IN_JSON,
82+
local_dir=local_dir,
83+
)
84+
logger.info(f"Got file to {config_path}")
85+
except Exception as e:
86+
logger.error(
87+
f"Failed to download model config file to {local_dir} due to {e}"
88+
)
7189
return False
7290
return True
7391

@@ -82,11 +100,7 @@ def download_built_in_ltsm_from_hf_if_necessary(
82100
bool: True if the model is existed or downloaded successfully, False otherwise.
83101
"""
84102
repo_id = TIMER_REPO_ID[model_type]
85-
weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
86-
if not _download_file_from_hf_if_necessary(weights_path, local_dir, repo_id):
87-
return False
88-
config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON)
89-
if not _download_file_from_hf_if_necessary(config_path, local_dir, repo_id):
103+
if not _download_file_from_hf_if_necessary(local_dir, repo_id):
90104
return False
91105
return True
92106

0 commit comments

Comments
 (0)