Commit 4be51f9

jackzhxng authored and facebook-github-bot committed
Make export llama checkpoint and param optional (#9456)
Summary: Refactor to make checkpoint path and param path optional.

Reviewed By: larryliu0820

Differential Revision: D71404805
1 parent 0dd7e4e commit 4be51f9

4 files changed: +51 −44 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 26 deletions
@@ -123,26 +123,19 @@ def verbose_export():


 def build_model(
-    modelname: str = "llama3",
-    extra_opts: str = "",
-    *,
-    par_local_output: bool = False,
-    resource_pkg_name: str = __name__,
+    model: str,
+    checkpoint: str,
+    params: str,
+    output_dir: Optional[str] = ".",
+    extra_opts: Optional[str] = "",
 ) -> str:
-    if False: # par_local_output:
-        output_dir_path = "par:."
-    else:
-        output_dir_path = "."
-
-    argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}"
+    argString = f"--model {model} --checkpoint {checkpoint} --params {params} {extra_opts} --output-dir {output_dir}"
     parser = build_args_parser()
     args = parser.parse_args(shlex.split(argString))
-    # pkg_name = resource_pkg_name
     return export_llama(args)


 def build_args_parser() -> argparse.ArgumentParser:
-    ckpt_dir = f"{Path(__file__).absolute().parent.as_posix()}"
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
     # parser.add_argument(
@@ -191,8 +184,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-c",
         "--checkpoint",
-        default=f"{ckpt_dir}/params/demo_rand_params.pth",
-        help="checkpoint path",
+        required=False,
+        help="Path to the checkpoint .pth file. When not provided, the model will be initialized with random weights.",
     )

     parser.add_argument(
@@ -273,8 +266,8 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "-p",
         "--params",
-        default=f"{ckpt_dir}/params/demo_config.json",
-        help="config.json",
+        required=False,
+        help="Config file for model parameters. When not provided, the model will fallback on default values defined in examples/models/llama/model_args.py.",
     )
     parser.add_argument(
         "--optimized_rotation_path",
@@ -561,7 +554,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     checkpoint_dir = (
         canonical_path(args.checkpoint_dir) if args.checkpoint_dir else None
     )
-    params_path = canonical_path(args.params)
+    params_path = canonical_path(args.params) if args.params else None
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

@@ -959,7 +952,7 @@ def _load_llama_model(
     *,
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
-    params_path: str,
+    params_path: Optional[str] = None,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -986,13 +979,6 @@ def _load_llama_model(
         An instance of LLMEdgeManager which contains the eager mode model.
     """

-    assert (
-        checkpoint or checkpoint_dir
-    ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty"
-    logging.info(
-        f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}"
-    )
-
     if modelname in EXECUTORCH_DEFINED_MODELS:
         module_name = "llama"
         model_class_name = "Llama2Model"  # TODO: Change to "LlamaModel" in examples/models/llama/model.py.

examples/models/llama/model.py

Lines changed: 27 additions & 16 deletions
@@ -38,14 +38,13 @@ def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)

         # Use single checkpoint file.
-        checkpoint_path = kwargs.get(
-            "checkpoint", resource_dir / "demo_rand_params.pth"
-        )
-        params_path = kwargs.get("params", resource_dir / "demo_config.json")
-
+        checkpoint_path = kwargs.get("checkpoint", None)
         # Check if checkpoint_dir was provided for a sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)

+        # Params file.
+        params_path = kwargs.get("params", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -66,6 +65,7 @@ def __init__(self, **kwargs):
         # flake8: noqa: TOR102
         cps = []
         # Load sharded checkpoint.
+        checkpoint = {}
         if checkpoint_dir is not None:
             # Load multiple checkpoint; ignore the single path.
             checkpoint_path = None
@@ -93,7 +93,7 @@ def __init__(self, **kwargs):
                     # Do not duplicate layers shared between each checkpoint.
                     checkpoint[key] = cps[0][key]
         # Load single checkpoint.
-        else:
+        elif checkpoint_path:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)

         # If given checkpoint is fairseq, convert to llama checkpoint.
@@ -122,8 +122,12 @@ def __init__(self, **kwargs):
             """
             )

-        with open(params_path, "r") as f:
-            params = json.loads(f.read())
+        # Get optional params.
+        params = {}
+        if params_path:
+            with open(params_path, "r") as f:
+                params = json.loads(f.read())
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -170,7 +174,11 @@ def __init__(self, **kwargs):
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
-            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
+            # Get checkpoint dtype.
+            if checkpoint:
+                self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
+            else:
+                self.model_.checkpoint_dtype = None

         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -244,16 +252,19 @@ def __init__(self, **kwargs):
             # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
             # by default initialized to fp32. This is fine because every other supported type
             # losslessly converts to fp32, so we don't lose precision here.
-            missing, unexpected = self.model_.load_state_dict(
-                checkpoint,
-                strict=False,
-                assign=True,
-            )  # self.model_ = Transformer(gptconf)
+            if checkpoint:
+                missing, unexpected = self.model_.load_state_dict(
+                    checkpoint,
+                    strict=False,
+                    assign=True,
+                )  # self.model_ = Transformer(gptconf)
+            else:
+                print("Checkpoint not provided, defaulting to uninitialized weights.")
+                self.model_.to_empty(device="cpu")
         except RuntimeError as e:
             print(
-                "Could not load checkpoint into mode, defaulting to random uninitialized weights."
+                f"Could not load checkpoint into mode and will default to uninitialized weights due to error: {e}."
             )
-            print(f"Error: {e}")
             # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
             self.model_.to_empty(device="cpu")
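
The optional-checkpoint handling above follows a common meta-device pattern: build the module on the meta device, then either load real weights with assign=True or materialize uninitialized storage with to_empty(). A self-contained sketch of that pattern with a toy module (an illustration of the technique, not ExecuTorch code):

import torch
import torch.nn as nn
from typing import Optional


def build_toy_model(checkpoint: Optional[dict]) -> nn.Module:
    # Construct on the meta device: tensors carry shape/dtype but no storage.
    with torch.device("meta"):
        model = nn.Linear(16, 4)

    if checkpoint:
        # assign=True replaces the meta tensors with the checkpoint tensors.
        missing, unexpected = model.load_state_dict(
            checkpoint, strict=False, assign=True
        )
    else:
        # No checkpoint: allocate real but uninitialized storage on CPU,
        # mirroring the "defaulting to uninitialized weights" branch above.
        model.to_empty(device="cpu")
    return model


# Usage: with and without a checkpoint.
with_ckpt = build_toy_model(nn.Linear(16, 4).state_dict())
without_ckpt = build_toy_model(None)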

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ class ModelArgs:
     n_layers: int = 32
     n_heads: int = 32
     n_kv_heads: Optional[int] = None
-    vocab_size: int = -1  # defined later by tokenizer
+    vocab_size: int = 512  # Arbitrary value, should be defined later by tokenizer.
     hidden_dim: Optional[int] = None
     head_dim: Optional[int] = None  # Optional customized head_dim
     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
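
The new vocab_size default matters because --params is now optional: ModelArgs may be constructed from an empty dict instead of a parsed params.json. A rough illustration of that fallback, using a simplified stand-in that reproduces only the fields visible in this diff:

import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:  # simplified stand-in for examples/models/llama/model_args.py
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = 512  # Arbitrary value, should be defined later by tokenizer.
    hidden_dim: Optional[int] = None
    head_dim: Optional[int] = None
    multiple_of: int = 256


def load_model_args(params_path: Optional[str]) -> ModelArgs:
    # Mirrors model.py above: a missing params file means "use the defaults".
    params = {}
    if params_path:
        with open(params_path, "r") as f:
            params = json.loads(f.read())
    return ModelArgs(**params)


print(load_model_args(None).vocab_size)  # 512, instead of the old sentinel -1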

examples/models/llava/export_llava.py

Lines changed: 11 additions & 1 deletion
@@ -98,7 +98,17 @@ def forward(self, input_pos, embeddings):
     dtype_override = DType.fp32
     parser = build_args_parser()
     args = parser.parse_args(
-        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
+        [
+            "-p",
+            "params.json",
+            "-X",
+            "-qmode",
+            "8da4w",
+            "--group_size",
+            "128",
+            "--embedding-quantize",
+            "4,32",
+        ]
     )
     quant_transform = get_quant_weight_transform(args, dtype_override, False)
     _, quantizers, _ = get_quantizer_and_quant_params(args)
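
export_llava.py now passes -p explicitly because --params no longer carries a default path in the llama argument parser; any caller that relied on that default must do the same. A small sketch of the resulting argparse behaviour, using only the flag definitions visible in this diff (the real parser has many more arguments):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-c", "--checkpoint", required=False)  # no default path anymore
parser.add_argument("-p", "--params", required=False)      # no default path anymore

args = parser.parse_args(["-p", "params.json"])
print(args.checkpoint)  # None -> model falls back to uninitialized/random weights
print(args.params)      # "params.json"

args = parser.parse_args([])
print(args.params)      # None -> ModelArgs defaults are used instead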
