Skip to content

Commit c8b4d27

Browse files
committed
Export-only llama arg
1 parent 8f9fb7e commit c8b4d27

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -443,6 +443,13 @@ def build_args_parser() -> argparse.ArgumentParser:
443443
default=None,
444444
help="path to the input pruning token mapping file (token_map.json)",
445445
)
446+
447+
parser.add_argument(
448+
"--export_only",
449+
default=False,
450+
action="store_true",
451+
help="If true, stops right after torch.export() and saves the exported model.",
452+
)
446453
return parser
447454

448455

@@ -587,9 +594,16 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
587594
pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
588595

589596
# export_to_edge
590-
builder_exported_to_edge = (
597+
builder_exported = (
591598
_prepare_for_llama_export(modelname, args)
592599
.export()
600+
)
601+
602+
if args.export_only:
603+
exit()
604+
605+
builder_exported_to_edge = (
606+
builder_exported
593607
.pt2e_quantize(quantizers)
594608
.export_to_edge()
595609
)

extension/llm/export/builder.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -186,22 +186,25 @@ def export(self) -> "LLMEdgeManager":
186186
# functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
187187
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
188188
# `Module`.
189-
self.pre_autograd_graph_module = torch.export.export(
189+
exported_module = torch.export.export(
190190
self.model,
191191
self.example_inputs,
192192
self.example_kwarg_inputs,
193193
dynamic_shapes=dynamic_shape,
194194
strict=True,
195-
).module()
195+
)
196196
else:
197197
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
198198
# `Module`.
199-
self.pre_autograd_graph_module = export_for_training(
199+
exported_module = export_for_training(
200200
self.model,
201201
self.example_inputs,
202202
kwargs=self.example_kwarg_inputs,
203203
dynamic_shapes=dynamic_shape,
204-
).module()
204+
)
205+
self.pre_autograd_graph_module = exported_module.module()
206+
if self.args.export_only:
207+
torch.export.save(exported_module, self.args.output_name)
205208

206209
return self
207210

0 commit comments

Comments (0)