|
98 | 98 | "phi_4_mini", |
99 | 99 | ] |
100 | 100 | TORCHTUNE_DEFINED_MODELS = ["llama3_2_vision"] |
| 101 | +HUGGING_FACE_REPO_IDS = { |
| 102 | + "qwen2_5": "Qwen/Qwen2.5-1.5B", |
| 103 | + "phi_4_mini": "microsoft/Phi-4-mini-instruct", |
| 104 | +} |
101 | 105 |
|
102 | 106 |
|
103 | 107 | class WeightType(Enum): |
@@ -519,7 +523,53 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: |
519 | 523 | return return_val |
520 | 524 |
|
521 | 525 |
|
| 526 | +def download_and_convert_hf_checkpoint(modelname: str) -> str: |
| 527 | + """ |
| 528 | + Downloads and converts to Meta format a HuggingFace checkpoint. |
| 529 | + """ |
| 530 | + # Build cache path |
| 531 | + cache_subdir = "meta_checkpoints" |
| 532 | + cache_dir = Path.home() / ".cache" / cache_subdir |
| 533 | + cache_dir.mkdir(parents=True, exist_ok=True) |
| 534 | + |
| 535 | + # Use repo name to name the converted file |
| 536 | + repo_id = HUGGING_FACE_REPO_IDS[modelname] |
| 537 | + model_name = repo_id.replace( |
| 538 | + "/", "_" |
| 539 | + ) # e.g., "bert-base-uncased" or "facebook/bart" → safe filename |
| 540 | + converted_path = cache_dir / f"{model_name}.pth" |
| 541 | + |
| 542 | + if converted_path.exists(): |
| 543 | + print(f"✔ Using cached converted model: {converted_path}") |
| 544 | + return converted_path |
| 545 | + |
| 546 | + # 1. Download weights from Hugging Face. |
| 547 | + print("⬇ Downloading and converting checkpoint...") |
| 548 | + from huggingface_hub import snapshot_download |
| 549 | + |
| 550 | + checkpoint_path = snapshot_download( |
| 551 | + repo_id=repo_id, |
| 552 | + ) |
| 553 | + |
| 554 | + # 2. Convert weights to Meta format. |
| 555 | + if modelname == "qwen2_5": |
| 556 | + from executorch.examples.models.qwen2_5 import convert_weights |
| 557 | + |
| 558 | + convert_weights(checkpoint_path, converted_path) |
| 559 | + elif modelname == "phi_4_mini": |
| 560 | + from executorch.examples.models.phi_4_mini import convert_weights |
| 561 | + |
| 562 | + convert_weights(checkpoint_path, converted_path) |
| 563 | + elif modelname == "smollm2": |
| 564 | + pass |
| 565 | + |
| 566 | + return converted_path |
| 567 | + |
| 568 | + |
522 | 569 | def export_llama(args) -> str: |
| 570 | + if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS: |
| 571 | + args.checkpoint = download_and_convert_hf_checkpoint(args.model) |
| 572 | + |
523 | 573 | if args.profile_path is not None: |
524 | 574 | try: |
525 | 575 | from executorch.util.python_profiler import CProfilerFlameGraph |
|
0 commit comments