Skip to content

Commit 997120e

Browse files
committed
use hugging face's cache system instead of custom file management
1 parent 05d8580 commit 997120e

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

.gitignore

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,4 @@ lightning_logs
176176
logs
177177
.isort.cfg
178178
/.vscode
179-
/api/api_models
180-
/api/.api_models
181179
/api/.cloned_repos

api/cli.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,8 @@ def predict(smiles, smiles_file, output, model_type):
6666
current_dir = Path(__file__).resolve().parent
6767

6868
if "hugging_face" in model_config:
69-
local_file_path = download_model_files(
70-
model_config["hugging_face"],
71-
current_dir / ".api_models" / model_type,
72-
)
69+
print(f"For model type `{model_type}` following files are used:")
70+
local_file_path = download_model_files(model_config["hugging_face"])
7371
predictor_kwargs["ckpt_path"] = local_file_path["ckpt"]
7472
predictor_kwargs["target_labels_path"] = local_file_path["labels"]
7573

api/hugging_face.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,48 @@
1-
import shutil
1+
"""
2+
Hugging Face Api:
3+
- For Windows Users check: https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache#limitations
4+
5+
Refer for Hugging Face Hub caching and versioning documentation:
6+
https://huggingface.co/docs/huggingface_hub/en/guides/download
7+
https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache
8+
"""
9+
210
from pathlib import Path
311

412
from huggingface_hub import hf_hub_download
513

614

7-
def download_model_files(model_config: dict, download_path: Path):
15+
def download_model_files(
16+
model_config: dict[str, str | dict[str, str]],
17+
) -> dict[str, Path]:
18+
"""
19+
Downloads specified model files from a Hugging Face Hub repository using hf_hub_download.
20+
21+
Hugging Face Hub provides internal caching and versioning, so file management or duplication
22+
checks are not required.
23+
24+
Args:
25+
model_config (Dict[str, str | Dict[str, str]]): A dictionary containing:
26+
- 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname').
27+
- 'subfolder' (str): The subfolder within the repo where the files are located.
28+
- 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to
29+
actual file names (e.g., 'electra.ckpt', 'classes.txt').
30+
31+
Returns:
32+
Dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file.
33+
"""
834
repo_id = model_config["repo_id"]
935
subfolder = model_config["subfolder"]
1036
filenames = model_config["files"]
1137

12-
local_paths = {}
38+
local_paths: dict[str, Path] = {}
1339
for file_type, filename in filenames.items():
14-
local_file_path = download_path / filename
15-
if local_file_path.exists():
16-
print(f"File already exists: {local_file_path}")
17-
local_paths[file_type] = local_file_path
18-
continue
19-
20-
print(
21-
f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}"
22-
)
23-
downloaded_file = hf_hub_download(
40+
downloaded_file_path = hf_hub_download(
2441
repo_id=repo_id,
2542
filename=filename,
2643
subfolder=subfolder,
2744
)
28-
29-
local_file_path.parent.mkdir(parents=True, exist_ok=True)
30-
shutil.move(downloaded_file, local_file_path)
31-
print(f"Saved to: {local_file_path}")
32-
local_paths[file_type] = local_file_path
45+
local_paths[file_type] = Path(downloaded_file_path)
46+
print(f"\t Using file `{filename}` from: {downloaded_file_path}")
3347

3448
return local_paths

0 commit comments

Comments
 (0)