|
13 | 13 | """Functions for generating ECR image URIs for pre-built SageMaker Docker images.""" |
14 | 14 | from __future__ import absolute_import |
15 | 15 |
|
| 16 | +import os |
16 | 17 | from typing import Optional |
| 18 | +import importlib.util |
17 | 19 |
|
18 | 20 | import urllib.request |
19 | 21 | from urllib.error import HTTPError, URLError |
@@ -123,3 +125,26 @@ def get_huggingface_model_metadata(model_id: str, hf_hub_token: Optional[str] = |
123 | 125 | "Did not find model metadata for the following HuggingFace Model ID %s" % model_id |
124 | 126 | ) |
125 | 127 | return hf_model_metadata_json |
| 128 | + |
| 129 | + |
| 130 | +def download_huggingface_model_metadata( |
| 131 | + model_id: str, model_local_path: str, hf_hub_token: Optional[str] = None |
| 132 | +) -> None: |
| 133 | + """Downloads the HuggingFace Model snapshot via HuggingFace API. |
| 134 | +
|
| 135 | + Args: |
| 136 | + model_id (str): The HuggingFace Model ID |
| 137 | + model_local_path (str): The local path to save the HuggingFace Model snapshot. |
| 138 | + hf_hub_token (str): The HuggingFace Hub Token |
| 139 | +
|
| 140 | + Raises: |
| 141 | + ImportError: If huggingface_hub is not installed. |
| 142 | + """ |
| 143 | + if not importlib.util.find_spec("huggingface_hub"): |
| 144 | + raise ImportError("Unable to import huggingface_hub, check if huggingface_hub is installed") |
| 145 | + |
| 146 | + from huggingface_hub import snapshot_download |
| 147 | + |
| 148 | + os.makedirs(model_local_path, exist_ok=True) |
| 149 | + logger.info("Downloading model %s from Hugging Face Hub to %s", model_id, model_local_path) |
| 150 | + snapshot_download(repo_id=model_id, local_dir=model_local_path, token=hf_hub_token) |
0 commit comments