@@ -20,7 +20,7 @@
     DuplicateDynamicQuantChainPass,
 )
 from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
-from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
+from executorch.exir import EdgeProgramManager, to_edge, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner
 
 from executorch.exir.backend.utils import format_delegated_graph
@@ -125,6 +125,7 @@ def __init__(
         # make sure to re-export this graph module to persist any changes. See
         # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
+        self.exported_module: Optional[torch.export.ExportedProgram] = None
         self.edge_manager: Optional[EdgeProgramManager] = None
         self.canonical_passes = [
             RemoveRedundantTransposes()
@@ -235,17 +236,23 @@ def export(self) -> "LLMEdgeManager":
         The full torch.export() if called later on during to_edge() or
         to_edge_transform_and_lower().
         """
-        exported_module = self._export()
+        self.exported_module = self._export()
         # Need to store the graph module to record transformation passes.
         # Persisting those changes back to an ExportedProgram will require
         # an additional export().
-        self.pre_autograd_graph_module = exported_module.module()
+        self.pre_autograd_graph_module = self.exported_module.module()
         if self.save_exported_program:
             export_output = f"{self.modelname}.pt2"
             logging.info(f"Saving torch.export() result to {export_output}")
-            torch.export.save(exported_module, export_output)
+            torch.export.save(self.exported_module, export_output)
         return self
 
+    def run_decompositions(self) -> "LLMEdgeManager":
+        # Re-export to capture any pending changes to pre_autograd_graph_module
+        self.exported_module = self._export(self.pre_autograd_graph_module)
+        self.exported_module = self.exported_module.run_decompositions({})
+        return self
+
     def run_canonical_optimizations(self):
         """
         Run canonical optimizations (at the moment removing redundant permutes) on the model.
@@ -256,6 +263,8 @@ def run_canonical_optimizations(self):
             res = pass_instance(self.pre_autograd_graph_module)
             assert res.graph_module is not None, "Pass returned None"
             self.pre_autograd_graph_module = res.graph_module
+        # Re-export to capture changes to pre_autograd_graph_module
+        self.exported_module = self._export(self.pre_autograd_graph_module)
 
     def pt2e_calibrate(
         self,
@@ -389,6 +398,8 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
             m = convert_pt2e(m)
             DuplicateDynamicQuantChainPass()(m)
             self.pre_autograd_graph_module = m
+            # Re-export to capture changes to pre_autograd_graph_module
+            self.exported_module = self._export(self.pre_autograd_graph_module)
             return self
         else:
             logging.info("No quantizer provided, passing...")
@@ -398,7 +409,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
         """
         Export the model to Edge dialect and retrieve a LLMEdgeManager.
         """
-        dynamic_shape = self._get_dynamic_shape()
         edge_config = self._get_edge_config()
 
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
@@ -408,17 +418,11 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 # Run export() if it didn't run
                 self.export()
 
-            override_export_behaviour = contextlib.nullcontext()
-            with override_export_behaviour:
-                self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
-                    self.example_inputs,
-                    example_kwarg_inputs=self.example_kwarg_inputs,
-                    dynamic_shapes=dynamic_shape,
-                    edge_constant_methods=self.metadata,
-                    edge_compile_config=edge_config,
-                    verbose=self.verbose,
-                )
+            self.edge_manager = to_edge(
+                self.exported_module,
+                constant_methods=self.metadata,
+                compile_config=edge_config,
+            )
         return self
 
     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
@@ -457,12 +461,9 @@ def to_edge_transform_and_lower(
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
 
-        # Need to construct ExportedProgram with the new transformed graph module.
-        exported_module = self._export(self.pre_autograd_graph_module)
-
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            exported_module,
+            self.exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
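
For reviewers tracing the new control flow, here is a minimal sketch of how the reworked chain might be driven by caller code after this change. Only the methods touched in this diff are assumed; `build_manager()`, the `None` quantizers, and the `None` partitioners are hypothetical placeholders for whatever the surrounding export script actually constructs.

```python
# Sketch only: assumes an already-constructed LLMEdgeManager (here via a
# hypothetical build_manager() helper) and no quantizers/partitioners.
manager = build_manager()  # hypothetical: returns an LLMEdgeManager

manager = (
    manager.export()        # stores both exported_module and pre_autograd_graph_module
    .pt2e_quantize(None)    # with real quantizers this converts and re-exports the graph
    .run_decompositions()   # new step: re-export, then run decompositions on the program
    .export_to_edge()       # now builds the EdgeProgramManager via to_edge(exported_module)
    .to_backend(None)       # pass partitioners here to delegate parts of the graph
)
```

The underlying shift is that every mutation of `pre_autograd_graph_module` now re-exports into `self.exported_module`, so `to_edge()` and `to_edge_transform_and_lower()` consume a single up-to-date ExportedProgram instead of re-tracing the graph module themselves.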