Commit 144a37d

Remove sharded ckpt from export_llama (#15968)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #15968

Sharded checkpoint isn't used anymore; removing it and simplifying export_llama.

Differential Revision: [D87828518](https://our.internmc.facebook.com/intern/diff/D87828518/)
1 parent 33ec615 commit 144a37d

3 files changed: +2, -52 lines

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 12 deletions

@@ -229,12 +229,6 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights.",
     )
 
-    parser.add_argument(
-        "--checkpoint_dir",
-        default=None,
-        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
-    )
-
     parser.add_argument(
         "--adapter_checkpoint",
         required=False,
@@ -678,18 +672,12 @@ def _prepare_for_llama_export(llm_config: LlmConfig) -> LLMEdgeManager:
         if llm_config.base.checkpoint
         else None
     )
-    checkpoint_dir = (
-        canonical_path(llm_config.base.checkpoint_dir)
-        if llm_config.base.checkpoint_dir
-        else None
-    )
     params_path = (
         canonical_path(llm_config.base.params) if llm_config.base.params else None
     )
     output_dir_path = canonical_path(llm_config.export.output_dir, dir=True)
 
     llm_config.base.checkpoint = checkpoint_path
-    llm_config.base.checkpoint_dir = checkpoint_dir
     llm_config.base.params = params_path
     llm_config.export.output_dir = output_dir_path
 
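
After this change, the only checkpoint input that `_prepare_for_llama_export` resolves is the single `.pth` path on `llm_config.base.checkpoint`. A minimal sketch of setting that up programmatically, using only the config fields visible in the diff above (the file paths are hypothetical placeholders):

```python
# Minimal sketch: configure a single-checkpoint export via LlmConfig.
# Paths are placeholders; only base.checkpoint, base.params, and
# export.output_dir from the diff above are touched here.
from executorch.extension.llm.export.config.llm_config import LlmConfig

llm_config = LlmConfig()
llm_config.base.checkpoint = "/path/to/consolidated.00.pth"  # single .pth file
llm_config.base.params = "/path/to/params.json"
llm_config.export.output_dir = "."
```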

examples/models/llama/model.py

Lines changed: 2 additions & 36 deletions

@@ -7,8 +7,7 @@
 # pyre-unsafe
 
 import json
-import os
-from typing import Dict, Optional, Tuple
+from typing import Optional
 
 import torch
 from executorch.examples.models.checkpoint import (
@@ -18,7 +17,6 @@
 
 from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.rope import Rope
 
 from executorch.extension.llm.export.config.llm_config import LlmConfig
 from torchao.utils import TorchAOBaseTensor
@@ -39,12 +37,9 @@ def convert_to_llama_checkpoint(**kwargs):
 
 class Llama2Model(EagerModelBase):
     def __init__(self, llm_config: Optional[LlmConfig] = None):
-        resource_dir = get_default_model_resource_dir(__file__)
-
         self.llm_config = llm_config if llm_config else LlmConfig()
 
         checkpoint_path = self.llm_config.base.checkpoint
-        checkpoint_dir = self.llm_config.base.checkpoint_dir
         params_path = self.llm_config.base.params
 
         # Adapter checkpoint and config.
@@ -72,37 +67,8 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
         # flake8: noqa: TOR102
-        cps = []
-        # Load sharded checkpoint.
         checkpoint = {}
-        if checkpoint_dir is not None:
-            # Load multiple checkpoint; ignore the single path.
-            checkpoint_path = None
-            for i in range(4):
-                cp_name = f"consolidated.{i}.pth"
-                print(f"Loading {cp_name}")
-                cps.append(
-                    torch.load(
-                        os.path.join(checkpoint_dir, cp_name),
-                        map_location=device,
-                        mmap=True,
-                    )
-                )
-            checkpoint = {}
-            for key in cps[0].keys():
-                if not torch.allclose(cps[0][key], cps[1][key]):
-                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
-                    if "wo" in key or "w2" in key:
-                        # Concat on dim=1 for "wo" and "w2".
-                        checkpoint[key] = torch.cat(values, dim=1)
-                    else:
-                        # Concat on dim=0 for everything else.
-                        checkpoint[key] = torch.cat(values, dim=0)
-                else:
-                    # Do not duplicate layers shared between each checkpoint.
-                    checkpoint[key] = cps[0][key]
-        # Load single checkpoint.
-        elif checkpoint_path:
+        if checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
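
With this path gone, a sharded Meta Llama checkpoint has to be merged into a single `.pth` file before it is handed to export_llama. A rough offline sketch that mirrors the merge logic removed above (the helper name, shard-count default, and paths are assumptions, not repo APIs):

```python
# Hypothetical helper (not part of the repo): merge a 4-way sharded Meta Llama
# checkpoint into a single state dict, mirroring the logic removed above.
import os

import torch


def consolidate_sharded_checkpoint(checkpoint_dir: str, num_shards: int = 4) -> dict:
    # Load every consolidated.{i}.pth shard on CPU with mmap to limit RAM use.
    cps = [
        torch.load(
            os.path.join(checkpoint_dir, f"consolidated.{i}.pth"),
            map_location="cpu",
            mmap=True,
        )
        for i in range(num_shards)
    ]
    checkpoint = {}
    for key in cps[0].keys():
        if not torch.allclose(cps[0][key], cps[1][key]):
            # Shard-specific weights: concat on dim=1 for "wo"/"w2", dim=0 otherwise.
            values = tuple(cp[key] for cp in cps)
            dim = 1 if ("wo" in key or "w2" in key) else 0
            checkpoint[key] = torch.cat(values, dim=dim)
        else:
            # Layers replicated across shards are taken from shard 0 only.
            checkpoint[key] = cps[0][key]
    return checkpoint


if __name__ == "__main__":
    # Placeholder paths; pass the merged file to export_llama via --checkpoint.
    merged = consolidate_sharded_checkpoint("/path/to/llama/shards")
    torch.save(merged, "/path/to/llama/consolidated_merged.pth")
```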

extension/llm/export/config/llm_config.py

Lines changed: 0 additions & 4 deletions

@@ -76,7 +76,6 @@ class BaseConfig:
             If left empty, the model will either be initialized with random weights
             if it is a Llama model or the weights will be downloaded from HuggingFace
             if it is a non-Llama model.
-        checkpoint_dir: Path to directory containing sharded checkpoint files.
         adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
             the model has trained LoRA adapters. Must provide
             adapter_config.json.
@@ -99,7 +98,6 @@ class BaseConfig:
     model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
-    checkpoint_dir: Optional[str] = None
     adapter_checkpoint: Optional[str] = None
     adapter_config: Optional[str] = None
     tokenizer_path: Optional[str] = None
@@ -527,8 +525,6 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
             llm_config.base.checkpoint = args.checkpoint
-        if hasattr(args, "checkpoint_dir"):
-            llm_config.base.checkpoint_dir = args.checkpoint_dir
         if hasattr(args, "adapter_checkpoint"):
             llm_config.base.adapter_checkpoint = args.adapter_checkpoint
         if hasattr(args, "adapter_config"):
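
A small sanity-check sketch of the resulting behavior, assuming the `hasattr` guards shown above are the only places these args are read (the namespace below is hand-built rather than produced by the real parser, and the path is a placeholder):

```python
# Hedged sketch: checkpoint still flows from argparse into LlmConfig,
# while checkpoint_dir is no longer a recognized flag or config field.
import argparse

from executorch.extension.llm.export.config.llm_config import LlmConfig

args = argparse.Namespace(checkpoint="/path/to/consolidated.00.pth")  # placeholder path
llm_config = LlmConfig.from_args(args)

assert llm_config.base.checkpoint == "/path/to/consolidated.00.pth"
assert not hasattr(llm_config.base, "checkpoint_dir")
```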
