From c8b4d272b0413edf455534281a05f7f48b04a7e8 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Wed, 6 Nov 2024 10:10:48 -0800
Subject: [PATCH 1/3] Export only llama arg

---
 examples/models/llama/export_llama_lib.py | 16 +++++++++++++++-
 extension/llm/export/builder.py           | 11 +++++++----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index a0b44fb9652..350f4584157 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -443,6 +443,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=None,
         help="path to the input pruning token mapping file (token_map.json)",
     )
+
+    parser.add_argument(
+        "--export_only",
+        default=False,
+        action="store_true",
+        help="If true, stops right after torch.export() and saves the exported model.",
+    )
 
     return parser
 
@@ -587,9 +594,16 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
-    builder_exported_to_edge = (
+    builder_exported = (
         _prepare_for_llama_export(modelname, args)
         .export()
+    )
+
+    if args.export_only:
+        exit()
+
+    builder_exported_to_edge = (
+        builder_exported
         .pt2e_quantize(quantizers)
         .export_to_edge()
     )
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index bd12c374b51..50b1cad560a 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -186,22 +186,25 @@ def export(self) -> "LLMEdgeManager":
             #  functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
             # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
-            self.pre_autograd_graph_module = torch.export.export(
+            exported_module = torch.export.export(
                 self.model,
                 self.example_inputs,
                 self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
-            ).module()
+            )
         else:
             # pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
             #  `Module`.
-            self.pre_autograd_graph_module = export_for_training(
+            exported_module = export_for_training(
                 self.model,
                 self.example_inputs,
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
-            ).module()
+            )
+        self.pre_autograd_graph_module = exported_module.module()
+        if self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
 
         return self
 

From 2230f91e9c0944c643ce06f92c809ca35335e428 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Wed, 6 Nov 2024 10:14:25 -0800
Subject: [PATCH 2/3] Lint

---
 examples/models/llama/export_llama_lib.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 350f4584157..825cf923e7c 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -594,19 +594,14 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
 
     # export_to_edge
-    builder_exported = (
-        _prepare_for_llama_export(modelname, args)
-        .export()
-    )
+    builder_exported = _prepare_for_llama_export(modelname, args).export()
 
     if args.export_only:
         exit()
 
-    builder_exported_to_edge = (
-        builder_exported
-        .pt2e_quantize(quantizers)
-        .export_to_edge()
-    )
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
 
     modelname = builder_exported_to_edge.modelname
 

From a6a716239d408ab40fe2f1df79a9e5b4510fca0a Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Wed, 6 Nov 2024 20:02:14 -0500
Subject: [PATCH 3/3] Fix test

---
 extension/llm/export/builder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 50b1cad560a..311e788797f 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -203,7 +203,7 @@ def export(self) -> "LLMEdgeManager":
                 dynamic_shapes=dynamic_shape,
             )
         self.pre_autograd_graph_module = exported_module.module()
-        if self.args.export_only:
+        if hasattr(self.args, "export_only") and self.args.export_only:
             torch.export.save(exported_module, self.args.output_name)
 
         return self
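
For reference, a minimal sketch of loading back the artifact that
--export_only writes. The path "llama_exported.pt2" is hypothetical (it
stands in for whatever --output_name was passed), and the sketch uses only
the public torch.export save/load API that the patches above already call:

    import torch

    # Load the ExportedProgram that LLMEdgeManager.export() saved via
    # torch.export.save(); "llama_exported.pt2" is a hypothetical
    # --output_name value.
    exported_program = torch.export.load("llama_exported.pt2")

    # Recover a callable graph module, mirroring what export() assigns
    # to self.pre_autograd_graph_module before lowering to edge.
    graph_module = exported_program.module()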