Skip to content

Commit 02c5409

Browse files
committed
api code to download model from hugging face
1 parent 6300bff commit 02c5409

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
/api/api_models

api/hugging_face.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import shutil
2+
from pathlib import Path
3+
4+
from huggingface_hub import hf_hub_download
5+
6+
# Updated registry: use a list of filenames if you're downloading a folder
7+
MODEL_REGISTRY = {
8+
"electra": {
9+
"repo_id": "aditya0by0/python-chebifier",
10+
"subfolder": "electra",
11+
"filenames": ["electra.ckpt", "classes.txt"],
12+
}
13+
}
14+
15+
DOWNLOAD_PATH = Path(__file__).resolve().parent / "api_models"
16+
17+
18+
def download_model(model_name):
19+
if model_name not in MODEL_REGISTRY:
20+
raise ValueError(
21+
f"Unknown model name. Available models: {list(MODEL_REGISTRY.keys())}"
22+
)
23+
24+
model_info = MODEL_REGISTRY[model_name]
25+
repo_id = model_info["repo_id"]
26+
subfolder = model_info["subfolder"]
27+
filenames = model_info["filenames"]
28+
29+
local_paths = []
30+
for filename in filenames:
31+
local_model_path = DOWNLOAD_PATH / model_name / filename
32+
if local_model_path.exists():
33+
print(f"File already exists: {local_model_path}")
34+
local_paths.append(local_model_path)
35+
continue
36+
37+
print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})")
38+
downloaded_file = hf_hub_download(
39+
repo_id=repo_id,
40+
filename=filename,
41+
subfolder=subfolder,
42+
)
43+
44+
local_model_path.parent.mkdir(parents=True, exist_ok=True)
45+
shutil.move(downloaded_file, local_model_path)
46+
print(f"Saved to: {local_model_path}")
47+
local_paths.append(local_model_path)
48+
49+
return local_paths
50+
51+
52+
if __name__ == "__main__":
53+
paths = download_model("electra")
54+
print("Downloaded files:")
55+
for p in paths:
56+
print(p)

0 commit comments

Comments
 (0)