|
| 1 | +import shutil |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +from huggingface_hub import hf_hub_download |
| 5 | + |
| 6 | +# Updated registry: use a list of filenames if you're downloading a folder |
| 7 | +MODEL_REGISTRY = { |
| 8 | + "electra": { |
| 9 | + "repo_id": "aditya0by0/python-chebifier", |
| 10 | + "subfolder": "electra", |
| 11 | + "filenames": ["electra.ckpt", "classes.txt"], |
| 12 | + } |
| 13 | +} |
| 14 | + |
| 15 | +DOWNLOAD_PATH = Path(__file__).resolve().parent / "api_models" |
| 16 | + |
| 17 | + |
| 18 | +def download_model(model_name): |
| 19 | + if model_name not in MODEL_REGISTRY: |
| 20 | + raise ValueError( |
| 21 | + f"Unknown model name. Available models: {list(MODEL_REGISTRY.keys())}" |
| 22 | + ) |
| 23 | + |
| 24 | + model_info = MODEL_REGISTRY[model_name] |
| 25 | + repo_id = model_info["repo_id"] |
| 26 | + subfolder = model_info["subfolder"] |
| 27 | + filenames = model_info["filenames"] |
| 28 | + |
| 29 | + local_paths = [] |
| 30 | + for filename in filenames: |
| 31 | + local_model_path = DOWNLOAD_PATH / model_name / filename |
| 32 | + if local_model_path.exists(): |
| 33 | + print(f"File already exists: {local_model_path}") |
| 34 | + local_paths.append(local_model_path) |
| 35 | + continue |
| 36 | + |
| 37 | + print(f"Downloading: {repo_id}/{filename} (subfolder: {subfolder})") |
| 38 | + downloaded_file = hf_hub_download( |
| 39 | + repo_id=repo_id, |
| 40 | + filename=filename, |
| 41 | + subfolder=subfolder, |
| 42 | + ) |
| 43 | + |
| 44 | + local_model_path.parent.mkdir(parents=True, exist_ok=True) |
| 45 | + shutil.move(downloaded_file, local_model_path) |
| 46 | + print(f"Saved to: {local_model_path}") |
| 47 | + local_paths.append(local_model_path) |
| 48 | + |
| 49 | + return local_paths |
| 50 | + |
| 51 | + |
| 52 | +if __name__ == "__main__": |
| 53 | + paths = download_model("electra") |
| 54 | + print("Downloaded files:") |
| 55 | + for p in paths: |
| 56 | + print(p) |
0 commit comments