diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 219cc71ded1..06238573096 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -879,16 +879,21 @@ def _to_edge_and_lower_llama_xnnpack(
         builder_exported.generate_etrecord = True
 
     builder = builder_exported.pt2e_quantize(quantizers)
+
+    # TODO: is an explicit re-export required here?
+    # Run decompositions before tagging the exported program.
+    builder = builder.run_decompositions()
+
     # tag EP
     if gen_tag_fn is not None:
         from executorch.exir.passes.external_constants_pass import (
-            delegate_external_constants_pass_unlifted,
+            delegate_external_constants_pass_lifted,
         )
 
         assert (
             builder_exported.pre_autograd_graph_module is not None
         ), "pre_autograd_graph_module shouldn't be None here"
-        delegate_external_constants_pass_unlifted(
-            module=builder_exported.pre_autograd_graph_module,
+        delegate_external_constants_pass_lifted(
+            ep=builder_exported.exported_module,
             gen_tag_fn=gen_tag_fn,
         )
diff --git a/exir/passes/external_constants_pass.py b/exir/passes/external_constants_pass.py
index ba1d85adb4f..e012a7bdb79 100644
--- a/exir/passes/external_constants_pass.py
+++ b/exir/passes/external_constants_pass.py
@@ -111,3 +111,23 @@ def delegate_external_constants_pass_unlifted(
                 node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
                 mutated = True
     return PassResult(module, mutated)
+
+def delegate_external_constants_pass_lifted(
+    ep: ExportedProgram,
+    gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
+) -> PassResult:
+    """
+    Tag constants in an ExportedProgram for external storage.
+    Operates on the lifted graph directly, so no re-export is needed.
+    """
+    mutated = False
+    gm = ep.graph_module
+
+    for node in gm.graph.nodes:
+        if node.op == "placeholder" and is_param_node(ep, node):
+            if gen_tag_fn is not None:
+                node.meta.setdefault("custom", {})
+                node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
+                mutated = True
+
+    return PassResult(gm, mutated)
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 675c0179ebb..5843c46ed2c 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -20,7 +20,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
+from executorch.exir import EdgeProgramManager, to_edge, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 from executorch.exir.backend.utils import format_delegated_graph
 
@@ -125,6 +125,7 @@ def __init__(
         # make sure to re-export this graph module to persist any changes. See
         # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.exported_module: Optional[torch.export.ExportedProgram] = None
         self.edge_manager: Optional[EdgeProgramManager] = None
         self.canonical_passes = [
             RemoveRedundantTransposes()
@@ -235,17 +236,23 @@ def export(self) -> "LLMEdgeManager":
             The full torch.export() if called later on during to_edge() or
             to_edge_transform_and_lower().
         """
-        exported_module = self._export()
+        self.exported_module = self._export()
         # Need to store the graph module to record transformation passes.
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
-        self.pre_autograd_graph_module = exported_module.module()
+        self.pre_autograd_graph_module = self.exported_module.module()
         if self.save_exported_program:
             export_output = f"{self.modelname}.pt2"
             logging.info(f"Saving torch.export() result to {export_output}")
-            torch.export.save(exported_module, export_output)
+            torch.export.save(self.exported_module, export_output)
         return self
 
+    def run_decompositions(self) -> "LLMEdgeManager":
+        # Re-export to capture any pending changes to pre_autograd_graph_module
+        self.exported_module = self._export(self.pre_autograd_graph_module)
+        self.exported_module = self.exported_module.run_decompositions({})
+        return self
+
     def run_canonical_optimizations(self):
         """
         Run canonical optimizations (at the moment removing redundant permutes) on the model.
@@ -256,6 +263,8 @@ def run_canonical_optimizations(self):
             res = pass_instance(self.pre_autograd_graph_module)
             assert res.graph_module is not None, "Pass returned None"
             self.pre_autograd_graph_module = res.graph_module
+        # Re-export to capture changes to pre_autograd_graph_module
+        self.exported_module = self._export(self.pre_autograd_graph_module)
 
     def pt2e_calibrate(
         self,
@@ -389,6 +398,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
+            # Re-export to capture changes to pre_autograd_graph_module
+            self.exported_module = self._export(self.pre_autograd_graph_module)
             return self
         else:
             logging.info("No quantizer provided, passing...")
@@ -398,7 +409,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
         """
         Export the model to Edge dialect and retrieve a LLMEdgeManager.
         """
-        dynamic_shape = self._get_dynamic_shape()
         edge_config = self._get_edge_config()
 
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
@@ -408,17 +418,11 @@ def export_to_edge(self) -> "LLMEdgeManager":
             # Run export() if it didn't run
             self.export()
 
-        override_export_behaviour = contextlib.nullcontext()
-        with override_export_behaviour:
-            self.edge_manager = export_to_edge(
-                self.pre_autograd_graph_module,  # pyre-fixme[6]
-                self.example_inputs,
-                example_kwarg_inputs=self.example_kwarg_inputs,
-                dynamic_shapes=dynamic_shape,
-                edge_constant_methods=self.metadata,
-                edge_compile_config=edge_config,
-                verbose=self.verbose,
-            )
+        self.edge_manager = to_edge(
+            self.exported_module,
+            constant_methods=self.metadata,
+            compile_config=edge_config,
+        )
         return self
 
     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
@@ -457,12 +461,9 @@ def to_edge_transform_and_lower(
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
 
-        # Need to construct ExportedProgram with the new transformed graph module.
-        exported_module = self._export(self.pre_autograd_graph_module)
-
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            exported_module,
+            self.exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
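
For illustration only, a minimal sketch of how the new delegate_external_constants_pass_lifted could be exercised on a plain torch.export() result. The TinyLinear module and the "model_weights" tag name are made-up stand-ins, not part of this patch; the only things taken from the diff are the pass's (ep, gen_tag_fn) signature and the fact that it writes tags into node.meta["custom"]["delegate_constant_tag"] on lifted parameter placeholders.

import torch

from executorch.exir.passes.external_constants_pass import (
    delegate_external_constants_pass_lifted,
)


class TinyLinear(torch.nn.Module):
    # Hypothetical toy model; any module with parameters works.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# Lift the module into an ExportedProgram; the new pass operates on this
# lifted graph directly, so no re-export is needed after tagging.
ep = torch.export.export(TinyLinear(), (torch.randn(1, 4),))

# Tag every lifted parameter placeholder for external storage. The tag name
# "model_weights" is an arbitrary example value.
delegate_external_constants_pass_lifted(
    ep=ep,
    gen_tag_fn=lambda node: "model_weights",
)

# Inspect the tags the pass wrote into node.meta.
for node in ep.graph_module.graph.nodes:
    tag = node.meta.get("custom", {}).get("delegate_constant_tag")
    if tag is not None:
        print(f"{node.name} -> {tag}")

This mirrors the updated call site in _to_edge_and_lower_llama_xnnpack, which now passes builder_exported.exported_module (populated by export(), pt2e_quantize(), and run_decompositions()) instead of the unlifted pre_autograd_graph_module.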