Skip to content

Commit a4f5f85

Browse files
committed
add hugging face api
1 parent 6faf3bd commit a4f5f85

File tree

1 file changed

+18
-40
lines changed

1 file changed

+18
-40
lines changed

api/hugging_face.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,54 +3,32 @@
33

44
from huggingface_hub import hf_hub_download
55

6-
# Updated registry: use a list of filenames if you're downloading a folder
7-
MODEL_REGISTRY = {
8-
"electra": {
9-
"repo_id": "aditya0by0/python-chebifier",
10-
"subfolder": "electra",
11-
"filenames": ["electra.ckpt", "classes.txt"],
12-
}
13-
}
146

15-
DOWNLOAD_PATH = Path(__file__).resolve().parent / "api_models"
16-
17-
18-
def download_model(model_name):
19-
if model_name not in MODEL_REGISTRY:
20-
raise ValueError(
21-
f"Unknown model name. Available models: {list(MODEL_REGISTRY.keys())}"
22-
)
23-
24-
model_info = MODEL_REGISTRY[model_name]
25-
repo_id = model_info["repo_id"]
26-
subfolder = model_info["subfolder"]
27-
filenames = model_info["filenames"]
28-
29-
local_paths = []
30-
for filename in filenames:
31-
local_model_path = DOWNLOAD_PATH / model_name / filename
32-
if local_model_path.exists():
33-
print(f"File already exists: {local_model_path}")
34-
local_paths.append(local_model_path)
7+
def download_model_files(model_config: dict, download_path: Path):
8+
repo_id = model_config["repo_id"]
9+
subfolder = model_config["subfolder"]
10+
filenames = model_config["files"]
11+
12+
local_paths = {}
13+
for file_type, filename in filenames.items():
14+
local_file_path = download_path / filename
15+
if local_file_path.exists():
16+
print(f"File already exists: {local_file_path}")
17+
local_paths[file_type] = local_file_path
3518
continue
3619

37-
print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})")
20+
print(
21+
f"Downloading file from: https://huggingface.co/{repo_id}/{subfolder}/{filename}"
22+
)
3823
downloaded_file = hf_hub_download(
3924
repo_id=repo_id,
4025
filename=filename,
4126
subfolder=subfolder,
4227
)
4328

44-
local_model_path.parent.mkdir(parents=True, exist_ok=True)
45-
shutil.move(downloaded_file, local_model_path)
46-
print(f"Saved to: {local_model_path}")
47-
local_paths.append(local_model_path)
29+
local_file_path.parent.mkdir(parents=True, exist_ok=True)
30+
shutil.move(downloaded_file, local_file_path)
31+
print(f"Saved to: {local_file_path}")
32+
local_paths[file_type] = local_file_path
4833

4934
return local_paths
50-
51-
52-
if __name__ == "__main__":
53-
paths = download_model("electra")
54-
print("Downloaded files:")
55-
for p in paths:
56-
print(p)

0 commit comments

Comments
 (0)