
Commit d538659

jackzhxng authored and facebook-github-bot committed
Make export llama checkpoint and param optional (#9456)
Summary: Refactor to make checkpoint path and param path optional.

Reviewed By: larryliu0820

Differential Revision: D71404805
1 parent 6f6fa6a commit d538659

File tree

4 files changed: +52 -44 lines


examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 26 deletions
@@ -123,26 +123,19 @@ def verbose_export():
 
 
 def build_model(
-    modelname: str = "llama3",
-    extra_opts: str = "",
-    *,
-    par_local_output: bool = False,
-    resource_pkg_name: str = __name__,
+    model: str,
+    checkpoint: str,
+    params: str,
+    output_dir: Optional[str] = ".",
+    extra_opts: Optional[str] = "",
 ) -> str:
-    if False:  # par_local_output:
-        output_dir_path = "par:."
-    else:
-        output_dir_path = "."
-
-    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    # pkg_name = resource_pkg_name
     return export_llama(args)
 
 
 def build_args_parser() -> argparse.ArgumentParser:
-    ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
     # parser.add_argument(
@@ -191,8 +184,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-c",
         "--checkpoint",
-        default=f"{ckpt_dir}/params/demo_rand_params.pth",
-        help="checkpoint path",
+        required=False,
+        help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights.",
     )
 
     parser.add_argument(
@@ -273,8 +266,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-p",
         "--params",
-        default=f"{ckpt_dir}/params/demo_config.json",
-        help="config.json",
+        required=False,
+        help="Config file for model parameters. When not provided, the model will fallback on default values defined in examples/models/llama/model_args.py.",
     )
     parser.add_argument(
         "--optimized_rotation_path",
@@ -561,7 +554,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params)
+    params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
@@ -960,7 +953,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
-    params_path: str,
+    params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -987,13 +980,6 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """
 
-    assert (
-        checkpoint or checkpoint_dir
-    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
-    logging.info(
-        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
-    )
-
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.
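
For illustration, a minimal usage sketch of the new behavior (not part of the commit itself; the import path below is an assumption): with --checkpoint and --params now optional, an export can be driven without either flag.

# Hypothetical sketch only; import path and remaining parser defaults are assumed.
import shlex

from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    export_llama,
)

# Neither -c/--checkpoint nor -p/--params is passed: per this commit, the model
# is initialized with random weights and falls back to the defaults in
# examples/models/llama/model_args.py.
args = build_args_parser().parse_args(shlex.split("--model llama3 --output-dir ."))
export_llama(args)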

examples/models/llama/model.py

Lines changed: 28 additions & 16 deletions
@@ -38,14 +38,13 @@ def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
 
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get(
-            "checkpoint", resource_dir / "demo_rand_params.pth"
-        )
-        params_path = kwargs.get("params", resource_dir / "demo_config.json")
-
+        checkpoint_path = kwargs.get("checkpoint", None)
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
 
+        # Params file.
+        params_path = kwargs.get("params", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -66,6 +65,7 @@ def __init__(self, **kwargs):
         # flake8: noqa: TOR102
         cps = []
         # Load sharded checkpoint.
+        checkpoint = {}
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -93,7 +93,7 @@ def __init__(self, **kwargs):
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
         # Load single checkpoint.
-        else:
+        elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
 
         # If given checkpoint is fairseq, convert to llama checkpoint.
@@ -123,10 +123,17 @@ def __init__(self, **kwargs):
         )
 
         # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
+        if checkpoint:
+            self.dtype = get_checkpoint_dtype(checkpoint)
+        else:
+            self.dtype = None
+
+        # Get optional params.
+        params = {}
+        if params_path:
+            with open(params_path, "r") as f:
+                params = json.loads(f.read())
 
-        with open(params_path, "r") as f:
-            params = json.loads(f.read())
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -241,16 +248,21 @@ def __init__(self, **kwargs):
             # assign=True: load params/buffers by assignment instead of performing an in-place copy.
             # Because we are using device="meta", tensors do not have memory associated with them
             # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-            missing, unexpected = self.model_.load_state_dict(
-                checkpoint,
-                strict=False,
-                assign=True,
-            )  # self.model_ = Transformer(gptconf)
+            if checkpoint:
+                missing, unexpected = self.model_.load_state_dict(
+                    checkpoint,
+                    strict=False,
+                    assign=True,
+                )  # self.model_ = Transformer(gptconf)
+            else:
+                print(
+                    "Checkpoint not provided, defaulting to random uninitialized weights."
+                )
+                self.model_.to_empty(device="cpu")
         except RuntimeError as e:
             print(
-                "Could not load checkpoint into mode, defaulting to random uninitialized weights."
+                f"Could not load checkpoint into mode and will default to random uninitialized weights due to error: {e}."
            )
-            print(f"Error: {e}")
             # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
             self.model_.to_empty(device="cpu")
 
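
For illustration, a sketch of how model.py now behaves when the optional inputs are omitted (the import path and constructor usage here are assumptions, not taken from the commit):

# Hypothetical sketch; assumes Llama2Model is importable from this path and
# accepts the kwargs shown in the diff above.
from executorch.examples.models.llama.model import Llama2Model

# No "checkpoint" and no "params": checkpoint stays an empty dict, params fall
# back to the ModelArgs defaults, and the weights are left random/uninitialized.
model = Llama2Model(use_kv_cache=False)

# Only "params" provided: the config is read from the JSON file, but the
# weights remain random because no checkpoint was given.
model_with_config = Llama2Model(params="path/to/params.json", use_kv_cache=False)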

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ class ModelArgs:
     n_layers: int = 32
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
-    vocab_size: int = -1  # defined later by tokenizer
+    vocab_size: int = 512  # Arbitrary value, should be defined later by tokenizer.
     hidden_dim: Optional[int] = None
     head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
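
A small sketch of what the new default means in practice (the import path is assumed for illustration):

# Hypothetical sketch; import path assumed.
from executorch.examples.models.llama.model_args import ModelArgs

# With no params.json supplied, ModelArgs() now yields a usable placeholder
# vocab_size instead of the previous sentinel of -1.
args = ModelArgs()
assert args.vocab_size == 512  # arbitrary value; a tokenizer should override it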

examples/models/llava/export_llava.py

Lines changed: 11 additions & 1 deletion
@@ -98,7 +98,17 @@ def forward(self, input_pos, embeddings):
     dtype_override = DType.fp32
     parser = build_args_parser()
     args = parser.parse_args(
-        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
+        [
+            "-p",
+            "params.json",
+            "-X",
+            "-qmode",
+            "8da4w",
+            "--group_size",
+            "128",
+            "--embedding-quantize",
+            "4,32",
+        ]
     )
     quant_transform = get_quant_weight_transform(args, dtype_override, False)
     _, quantizers, _ = get_quantizer_and_quant_params(args)
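
Since --params no longer carries a default, the Llava export now passes its params.json explicitly; presumably this keeps the text-model config coming from that file rather than falling back to the ModelArgs defaults.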
