
Commit 8a5525f

Remove sharded ckpt from export_llama
Sharded checkpoint isn't used anymore; removing it and simplifying export_llama. Differential Revision: [D87828518](https://our.internmc.facebook.com/intern/diff/D87828518/) [ghstack-poisoned]
1 parent 350ea3c commit 8a5525f

3 files changed (+1, -49 lines)

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 6 deletions
@@ -229,12 +229,6 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights.",
     )
 
-    parser.add_argument(
-        "--checkpoint_dir",
-        default=None,
-        help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
-    )
-
     parser.add_argument(
         "--adapter_checkpoint",
         required=False,
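
With the flag gone, a single consolidated .pth file is passed via --checkpoint, and anything still passing --checkpoint_dir is rejected as an unrecognized argument. A minimal sketch of the new surface, assuming build_args_parser is importable as examples.models.llama.export_llama_lib and that no other flag is strictly required just to parse (the checkpoint path is a placeholder):

    from examples.models.llama.export_llama_lib import build_args_parser

    parser = build_args_parser()

    # --checkpoint still exists and names a single .pth file.
    args = parser.parse_args(["--checkpoint", "/path/to/consolidated.00.pth"])
    print(args.checkpoint)

    # --checkpoint_dir was removed, so this would now exit with
    # "error: unrecognized arguments: --checkpoint_dir":
    # parser.parse_args(["--checkpoint_dir", "/path/to/shards"])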

examples/models/llama/model.py

Lines changed: 1 addition & 32 deletions
@@ -71,38 +71,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
         device = "cpu"
-        # flake8: noqa: TOR102
-        cps = []
-        # Load sharded checkpoint.
-        checkpoint = {}
-        if checkpoint_dir is not None:
-            # Load multiple checkpoint; ignore the single path.
-            checkpoint_path = None
-            for i in range(4):
-                cp_name = f"consolidated.{i}.pth"
-                print(f"Loading {cp_name}")
-                cps.append(
-                    torch.load(
-                        os.path.join(checkpoint_dir, cp_name),
-                        map_location=device,
-                        mmap=True,
-                    )
-                )
-            checkpoint = {}
-            for key in cps[0].keys():
-                if not torch.allclose(cps[0][key], cps[1][key]):
-                    values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
-                    if "wo" in key or "w2" in key:
-                        # Concat on dim=1 for "wo" and "w2".
-                        checkpoint[key] = torch.cat(values, dim=1)
-                    else:
-                        # Concat on dim=0 for everything else.
-                        checkpoint[key] = torch.cat(values, dim=0)
-                else:
-                    # Do not duplicate layers shared between each checkpoint.
-                    checkpoint[key] = cps[0][key]
-        # Load single checkpoint.
-        elif checkpoint_path:
+        if checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
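
For anyone who still has a 4-way sharded checkpoint, the merge that used to happen here can be done once, offline, and the result passed in as a regular single checkpoint. A minimal sketch mirroring the deleted logic (the consolidated.{i}.pth naming, the dim=1 concat for keys containing "wo" or "w2", and the dim=0 concat for everything else come from the removed code; the consolidate() helper and the output filename are hypothetical):

    import os

    import torch


    def consolidate(checkpoint_dir: str, out_path: str, num_shards: int = 4) -> None:
        # Load every shard on CPU, mmap'ed to keep peak memory low.
        cps = [
            torch.load(
                os.path.join(checkpoint_dir, f"consolidated.{i}.pth"),
                map_location="cpu",
                mmap=True,
            )
            for i in range(num_shards)
        ]
        merged = {}
        for key in cps[0].keys():
            if torch.allclose(cps[0][key], cps[1][key]):
                # Do not duplicate tensors shared between shards.
                merged[key] = cps[0][key]
            elif "wo" in key or "w2" in key:
                # As in the removed code: concat on dim=1 for "wo" and "w2".
                merged[key] = torch.cat([cp[key] for cp in cps], dim=1)
            else:
                # Concat on dim=0 for everything else.
                merged[key] = torch.cat([cp[key] for cp in cps], dim=0)
        torch.save(merged, out_path)


    consolidate("/path/to/sharded/checkpoints", "consolidated.merged.pth")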

extension/llm/export/config/llm_config.py

Lines changed: 0 additions & 11 deletions
@@ -76,7 +76,6 @@ class BaseConfig:
        If left empty, the model will either be initialized with random weights
        if it is a Llama model or the weights will be downloaded from HuggingFace
        if it is a non-Llama model.
-    checkpoint_dir: Path to directory containing sharded checkpoint files.
     adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
        the model has trained LoRA adapters. Must provide
        adapter_config.json.

@@ -87,10 +86,6 @@
        e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
     use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled
        if set to 0.
-    fairseq2: For legacy internal use cases, this is safe to ignore.
-    preq_mode: Legacy option to specify how prequantized weights are loaded.
-       Going forward, ExecuTorch supports loading weights prequantized through
-       TorchAo as-is, without any special handling.
     preq_group_size: Legacy option to specify the group size of prequantized weights.
     preq_embedding_quantize: Legacy option to specify how prequantized embeddings
        are loaded.

@@ -99,13 +94,11 @@
     model_class: ModelType = ModelType.llama3
     params: Optional[str] = None
     checkpoint: Optional[str] = None
-    checkpoint_dir: Optional[str] = None
     adapter_checkpoint: Optional[str] = None
     adapter_config: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
-    fairseq2: bool = False
     preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
     preq_embedding_quantize: str = "8,0"

@@ -527,8 +520,6 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             llm_config.base.params = args.params
         if hasattr(args, "checkpoint"):
             llm_config.base.checkpoint = args.checkpoint
-        if hasattr(args, "checkpoint_dir"):
-            llm_config.base.checkpoint_dir = args.checkpoint_dir
         if hasattr(args, "adapter_checkpoint"):
             llm_config.base.adapter_checkpoint = args.adapter_checkpoint
         if hasattr(args, "adapter_config"):

@@ -539,8 +530,6 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
             llm_config.base.metadata = args.metadata
         if hasattr(args, "use_lora"):
             llm_config.base.use_lora = args.use_lora
-        if hasattr(args, "fairseq2"):
-            llm_config.base.fairseq2 = args.fairseq2
 
         # PreqMode settings
         if hasattr(args, "preq_mode") and args.preq_mode:
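
After this change the base section only ever names a single checkpoint file; checkpoint_dir and fairseq2 are no longer fields on BaseConfig. A minimal sketch of the programmatic equivalent, assuming LlmConfig can be constructed with its defaults and exposes the base section shown above (the import path simply mirrors the file path in this diff, and the checkpoint path is a placeholder):

    from extension.llm.export.config.llm_config import LlmConfig

    llm_config = LlmConfig()

    # The single-file field is the only checkpoint knob left.
    llm_config.base.checkpoint = "/path/to/consolidated.00.pth"

    # No longer defined on BaseConfig after this commit:
    # llm_config.base.checkpoint_dir = "/path/to/shards"
    # llm_config.base.fairseq2 = True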
