@@ -89,7 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -184,7 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def export(self) -> "LLMEdgeManager":
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,29 +204,42 @@ def export(self) -> "LLMEdgeManager":
                 # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
                 # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
                 exported_module = torch.export.export(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                     strict=True,
                 )
             else:
-                logging.info("Exporting with:")
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
                 logging.info(f"inputs: {self.example_inputs}")
                 logging.info(f"kwargs: {self.example_kwarg_inputs}")
                 logging.info(f"dynamic shapes: {dynamic_shape}")
                 exported_module = export_for_training(
-                    self.model,
+                    self.model if not module else module,
                     self.example_inputs,
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            # `Module`.
-            self.pre_autograd_exported_program = exported_module
-            self.pre_autograd_graph_module = exported_module.module()
-            if hasattr(self.args, "export_only") and self.args.export_only:
-                torch.export.save(exported_module, self.args.output_name)
+            return exported_module

+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later, during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
         return self

     def run_canonical_optimizations(self):
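To make the pre-autograd step in the new `export()` concrete, here is a minimal, self-contained sketch (not part of this PR), assuming PyTorch 2.5+ where `torch.export.export_for_training` is available; `TinyModel` and its example inputs are invented for illustration.

```python
import torch


class TinyModel(torch.nn.Module):
    """Stand-in model; only here to make the sketch runnable."""

    def forward(self, x):
        return torch.nn.functional.silu(x @ x.t())


model = TinyModel()
example_inputs = (torch.randn(4, 4),)

# Pre-autograd export: the graph is traced, but inference-time decompositions
# have not been applied yet.
pre_autograd_ep = torch.export.export_for_training(model, example_inputs)

# The unlifted nn.Module view of that graph. This is what the manager keeps
# around as pre_autograd_graph_module so later passes can mutate it in place.
graph_module = pre_autograd_ep.module()
print(type(graph_module))
```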
@@ -441,9 +457,13 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
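The lowering side of the change, sketched end to end under the same assumptions (PyTorch 2.5+, an executorch install, and a made-up `TinyModel`): once passes have mutated the graph module, it has to be re-exported before being handed to `to_edge_transform_and_lower`, because those mutations never flow back into the original ExportedProgram. This is a rough sketch of the flow, not the PR's code.

```python
import torch
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x + 1)


example_inputs = (torch.randn(2, 8),)

# Pre-autograd export, as in export() above.
pre_autograd_ep = torch.export.export_for_training(TinyModel(), example_inputs)
graph_module = pre_autograd_ep.module()

# ... transformation passes would mutate graph_module in place here ...

# Re-export the (possibly transformed) graph module so the changes land in a
# fresh ExportedProgram; this mirrors the _export(self.pre_autograd_graph_module)
# call added to to_edge_transform_and_lower() in the diff.
exported_module = torch.export.export_for_training(graph_module, example_inputs)

# Lower to the edge dialect. No partitioner is passed, so no backend lowering
# happens, matching the "skipping backend lowering" branch above.
edge_manager = to_edge_transform_and_lower(exported_module)
print(edge_manager.exported_program())
```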