|
1 | | -import shutil |
| 1 | +""" |
| 2 | +Hugging Face API: |
| 3 | + - For Windows users, see: https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache#limitations |
| 4 | +
|
| 5 | + Refer to the Hugging Face Hub caching and versioning documentation: |
| 6 | + https://huggingface.co/docs/huggingface_hub/en/guides/download |
| 7 | + https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache |
| 8 | +""" |
| 9 | + |
2 | 10 | from pathlib import Path |
3 | 11 |
|
4 | 12 | from huggingface_hub import hf_hub_download |
5 | 13 |
|
6 | 14 |
|
7 | | -def download_model_files(model_config: dict, download_path: Path): |
| 15 | +def download_model_files( |
| 16 | + model_config: dict[str, str | dict[str, str]], |
| 17 | +) -> dict[str, Path]: |
| 18 | + """ |
| 19 | + Downloads specified model files from a Hugging Face Hub repository using hf_hub_download. |
| 20 | +
|
| 21 | + Hugging Face Hub provides internal caching and versioning, so file management or duplication |
| 22 | + checks are not required. |
| 23 | +
|
| 24 | + Args: |
| 25 | + model_config (dict[str, str | dict[str, str]]): A dictionary containing: |
| 26 | + - 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname'). |
| 27 | + - 'subfolder' (str): The subfolder within the repo where the files are located. |
| 28 | + - 'files' (dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to |
| 29 | + actual file names (e.g., 'electra.ckpt', 'classes.txt'). |
| 30 | +
|
| 31 | + Returns: |
| 32 | + dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file. |
| 33 | + """ |
8 | 34 | repo_id = model_config["repo_id"] |
9 | 35 | subfolder = model_config["subfolder"] |
10 | 36 | filenames = model_config["files"] |
11 | 37 |
|
12 | | - local_paths = {} |
| 38 | + local_paths: dict[str, Path] = {} |
13 | 39 | for file_type, filename in filenames.items(): |
14 | | - local_file_path = download_path / filename |
15 | | - if local_file_path.exists(): |
16 | | - print(f"File already exists: {local_file_path}") |
17 | | - local_paths[file_type] = local_file_path |
18 | | - continue |
19 | | - |
20 | | - print( |
21 | | - f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}" |
22 | | - ) |
23 | | - downloaded_file = hf_hub_download( |
| 40 | + downloaded_file_path = hf_hub_download( |
24 | 41 | repo_id=repo_id, |
25 | 42 | filename=filename, |
26 | 43 | subfolder=subfolder, |
27 | 44 | ) |
28 | | - |
29 | | - local_file_path.parent.mkdir(parents=True, exist_ok=True) |
30 | | - shutil.move(downloaded_file, local_file_path) |
31 | | - print(f"Saved to: {local_file_path}") |
32 | | - local_paths[file_type] = local_file_path |
| 45 | + local_paths[file_type] = Path(downloaded_file_path) |
| 46 | + print(f"\t Using file `{filename}` from: {downloaded_file_path}") |
33 | 47 |
|
34 | 48 | return local_paths |
0 commit comments