Commit 503799e

Update on "Remove sharded ckpt from export_llama"
The sharded checkpoint isn't used anymore; remove it and simplify export_llama.

Differential Revision: [D87828518](https://our.internmc.facebook.com/intern/diff/D87828518/)

[ghstack-poisoned]
1 parent 8c7c450 commit 503799e

2 files changed (+1, -6 lines)
examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 6 deletions
@@ -672,18 +672,12 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         if llm_config.base.checkpoint
         else None
     )
-    checkpoint_dir = (
-        canonical_path(llm_config.base.checkpoint_dir)
-        if llm_config.base.checkpoint_dir
-        else None
-    )
     params_path = (
         canonical_path(llm_config.base.params) if llm_config.base.params else None
     )
     output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
 
     llm_config.base.checkpoint = checkpoint_path
-    llm_config.base.checkpoint_dir = checkpoint_dir
     llm_config.base.params = params_path
     llm_config.export.output_dir = output_dir_path

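With the sharded checkpoint_dir gone, _prepare_for_llama_export only canonicalizes the checkpoint, params, and output_dir paths. A minimal sketch of the "canonicalize if set, else None" pattern kept in the hunk above; canonical_path here is a hypothetical stand-in, not the real helper from export_llama_lib.py:

import os
from typing import Optional

def canonical_path(path: str, dir: bool = False) -> str:
    # Stand-in for the real helper: expand "~" and resolve to an absolute path.
    return os.path.abspath(os.path.expanduser(path))

def resolve(path: Optional[str]) -> Optional[str]:
    # Only canonicalize when the config field is actually set.
    return canonical_path(path) if path else None

print(resolve("~/llama/consolidated.00.pth"))  # -> absolute path
print(resolve(None))                           # -> None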
examples/models/llama/model.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
+        checkpoint = {}
         if checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

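The one-line addition above keeps checkpoint defined when no checkpoint path is supplied, now that the removed sharded-checkpoint branch no longer populates it. A minimal sketch of that pattern, using illustrative names rather than the real model.py code:

import torch

def load_checkpoint(checkpoint_path=None, device="cpu"):
    checkpoint = {}  # default keeps the variable defined when no path is given
    if checkpoint_path:
        # mmap=True maps the file instead of reading it fully into memory
        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
    return checkpoint

# Without a path, callers get an empty state dict rather than a NameError.
assert load_checkpoint() == {}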