
Commit de0f6f1

Make export llama checkpoint and param optional
Differential Revision: D71404805
Pull Request resolved: #9456
1 parent 38851a1 commit de0f6f1

4 files changed (+51 / -44 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 26 deletions
@@ -124,26 +124,19 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "llama3",
-    extra_opts: str = "",
-    *,
-    par_local_output: bool = False,
-    resource_pkg_name: str = __name__,
+    model: str,
+    checkpoint: str,
+    params: str,
+    output_dir: Optional[str] = ".",
+    extra_opts: Optional[str] = "",
 ) -> str:
-    if False:  # par_local_output:
-        output_dir_path = "par:."
-    else:
-        output_dir_path = "."
-
-    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    # pkg_name = resource_pkg_name
     return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
-    ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
     # parser.add_argument(
@@ -192,8 +185,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-c",
         "--checkpoint",
-        default=f"{ckpt_dir}/params/demo_rand_params.pth",
-        help="checkpoint path",
+        required=False,
+        help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights.",
     )
 
     parser.add_argument(
@@ -274,8 +267,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-p",
         "--params",
-        default=f"{ckpt_dir}/params/demo_config.json",
-        help="config.json",
+        required=False,
+        help="Config file for model parameters. When not provided, the model will fallback on default values defined in examples/models/llama/model_args.py.",
     )
     parser.add_argument(
         "--optimized_rotation_path",
@@ -562,7 +555,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params)
+    params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
@@ -985,7 +978,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
-    params_path: str,
+    params_path: Optional[str] = None,
    use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -1012,13 +1005,6 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    assert (
-        checkpoint or checkpoint_dir
-    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
-    logging.info(
-        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
-    )
-
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
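
With this change, `build_model` takes the model name, checkpoint, and params paths explicitly, and the `--checkpoint`/`--params` CLI flags lose their demo-file defaults and may simply be omitted. A minimal usage sketch of the new signature, assuming the module is importable as `executorch.examples.models.llama.export_llama_lib`; the paths and the `extra_opts` value are placeholders, not part of this commit:

```python
from executorch.examples.models.llama.export_llama_lib import build_model

# Hypothetical call against the new keyword layout; paths are placeholders.
pte_path = build_model(
    model="llama3",
    checkpoint="/tmp/model_ckpt.pt",     # no longer hard-coded to par:model_ckpt.pt
    params="/tmp/model_params.json",     # no longer hard-coded to par:model_params.json
    output_dir=".",                      # optional, defaults to "."
    extra_opts="--use_kv_cache",         # optional extra flags forwarded to the arg parser
)
```

Internally this still just formats a CLI string, runs it through `build_args_parser()`, and hands the parsed args to `export_llama(args)`.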

examples/models/llama/model.py

Lines changed: 27 additions & 16 deletions
@@ -38,14 +38,13 @@ def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
 
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get(
-            "checkpoint", resource_dir / "demo_rand_params.pth"
-        )
-        params_path = kwargs.get("params", resource_dir / "demo_config.json")
-
+        checkpoint_path = kwargs.get("checkpoint", None)
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
 
+        # Params file.
+        params_path = kwargs.get("params", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -66,6 +65,7 @@ def __init__(self, **kwargs):
         # flake8: noqa: TOR102
         cps = []
         # Load sharded checkpoint.
+        checkpoint = {}
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -93,7 +93,7 @@ def __init__(self, **kwargs):
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
         # Load single checkpoint.
-        else:
+        elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
@@ -122,8 +122,12 @@ def __init__(self, **kwargs):
                 """
             )
 
-        with open(params_path, "r") as f:
-            params = json.loads(f.read())
+        # Get optional params.
+        params = {}
+        if params_path:
+            with open(params_path, "r") as f:
+                params = json.loads(f.read())
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -170,7 +174,11 @@ def __init__(self, **kwargs):
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
-        self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
+        # Get checkpoint dtype.
+        if checkpoint:
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
+        else:
+            self.model_.checkpoint_dtype = None
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -244,16 +252,19 @@ def __init__(self, **kwargs):
             # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
             # by default initialized to fp32. This is fine because every other supported type
             # losslessly converts to fp32, so we don't lose precision here.
-            missing, unexpected = self.model_.load_state_dict(
-                checkpoint,
-                strict=False,
-                assign=True,
-            )  # self.model_ = Transformer(gptconf)
+            if checkpoint:
+                missing, unexpected = self.model_.load_state_dict(
+                    checkpoint,
+                    strict=False,
+                    assign=True,
+                )  # self.model_ = Transformer(gptconf)
+            else:
+                print("Checkpoint not provided, defaulting to uninitialized weights.")
+                self.model_.to_empty(device="cpu")
         except RuntimeError as e:
             print(
-                "Could not load checkpoint into mode, defaulting to random uninitialized weights."
+                f"Could not load checkpoint into mode and will default to uninitialized weights due to error: {e}."
             )
-            print(f"Error: {e}")
             # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
             self.model_.to_empty(device="cpu")
 
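
The core of the model.py change is that `checkpoint` and `params` now start as empty dicts and every consumer guards on them, so a missing checkpoint falls through to `to_empty()` instead of failing. A standalone sketch of that fallback pattern using a generic `nn.Module` (not the actual `Llama2Model` code):

```python
import torch
import torch.nn as nn


def materialize(model: nn.Module, checkpoint: dict) -> nn.Module:
    """Load weights when a checkpoint dict is present; otherwise back the
    meta-initialized parameters with empty CPU storage, as this commit does."""
    if checkpoint:
        model.load_state_dict(checkpoint, strict=False, assign=True)
    else:
        print("Checkpoint not provided, defaulting to uninitialized weights.")
        model.to_empty(device="cpu")
    return model


# Build on the meta device (no real storage), then materialize without a checkpoint.
with torch.device("meta"):
    model = nn.Linear(16, 16)
model = materialize(model, checkpoint={})
```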

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ class ModelArgs:
     n_layers: int = 32
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
-    vocab_size: int = -1  # defined later by tokenizer
+    vocab_size: int = 512  # Arbitrary value, should be defined later by tokenizer.
     hidden_dim: Optional[int] = None
     head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
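
The `vocab_size` default switches from the `-1` sentinel to an arbitrary but usable `512`, so a model can be instantiated from defaults alone when no params JSON is given. Roughly, assuming `ModelArgs` is importable from this module path:

```python
from executorch.examples.models.llama.model_args import ModelArgs

# With no params.json supplied, the dataclass defaults now describe a buildable model.
args = ModelArgs()
assert args.vocab_size == 512  # previously -1, which cannot size an embedding table
```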

examples/models/llava/export_llava.py

Lines changed: 11 additions & 1 deletion
@@ -98,7 +98,17 @@ def forward(self, input_pos, embeddings):
     dtype_override = DType.fp32
     parser = build_args_parser()
     args = parser.parse_args(
-        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
+        [
+            "-p",
+            "params.json",
+            "-X",
+            "-qmode",
+            "8da4w",
+            "--group_size",
+            "128",
+            "--embedding-quantize",
+            "4,32",
+        ]
     )
     quant_transform = get_quant_weight_transform(args, dtype_override)
     _, quantizers, _ = get_quantizer_and_quant_params(args)
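
Because `--params` no longer falls back to a bundled demo config, the LLaVA export now names its params file explicitly; omitted flags parse to `None`. A small sketch of that behavior, assuming no other flag in `build_args_parser` is required:

```python
from executorch.examples.models.llama.export_llama_lib import build_args_parser

parser = build_args_parser()

# Omitting the now-optional flags leaves them unset instead of pointing at demo files.
args = parser.parse_args([])
assert args.checkpoint is None and args.params is None

# LLaVA's export therefore passes its params file explicitly.
args = parser.parse_args(["-p", "params.json"])
assert args.params == "params.json"
```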
