@@ -89,7 +89,9 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.pre_autograd_exported_program: Optional[ExportedProgram] = None
+        self.exported_program: Optional[ExportedProgram] = None
+        # Self.exported_program's pre-autograd graph module, for running
+        # transform passes on the graph prior to torch.export().
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
@@ -184,7 +186,21 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config

-    def export(self) -> "LLMEdgeManager":
+    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on, during to_edge() or
+        to_edge_transform_and_lower().
+
+        The optional `module` argument lets the caller re-export an already-exported
+        module's ExportedProgram graph module, to persist the changes into a new
+        ExportedProgram.
+
+        Args:
+            module (Optional[torch.nn.Module]): module to export.
+
+        """
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -201,25 +217,30 @@ def export(self) -> "LLMEdgeManager":
             # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
             # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
             exported_module = torch.export.export(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
                 strict=True,
             )
         else:
-            logging.info("Exporting with:")
+            if module:
+                logging.info("Re-exporting with:")
+            else:
+                logging.info("Exporting with:")
             logging.info(f"inputs: {self.example_inputs}")
             logging.info(f"kwargs: {self.example_kwarg_inputs}")
             logging.info(f"dynamic shapes: {dynamic_shape}")
             exported_module = export_for_training(
-                self.model,
+                self.model if not module else module,
                 self.example_inputs,
                 kwargs=self.example_kwarg_inputs,
                 dynamic_shapes=dynamic_shape,
             )
-        # `Module`.
-        self.pre_autograd_exported_program = exported_module
+        self.exported_program = exported_module
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to the ExportedProgram will require
+        # an additional export().
         self.pre_autograd_graph_module = exported_module.module()
         if hasattr(self.args, "export_only") and self.args.export_only:
             torch.export.save(exported_module, self.args.output_name)
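
For context, a minimal sketch of the workflow this optional `module` argument enables (the `manager` object and `my_transform_pass` below are placeholders for illustration, not part of this diff): run the pre-autograd export, mutate the resulting graph module with transform passes, then call `export()` again with that module to persist the changes into a fresh ExportedProgram.

    # Sketch only: `manager` is an already-constructed LLMEdgeManager and
    # `my_transform_pass` is a placeholder graph-to-graph transform.
    manager.export()  # pre-autograd export via export_for_training()

    # Transform passes run against the pre-autograd graph module...
    manager.pre_autograd_graph_module = my_transform_pass(
        manager.pre_autograd_graph_module
    )

    # ...and are persisted into a new ExportedProgram by re-exporting.
    # export_to_edge() and to_edge_transform_and_lower() below also do this
    # re-export internally before lowering.
    manager.export(manager.pre_autograd_graph_module)
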
@@ -382,7 +403,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.pre_autograd_graph_module is None:
+            if self.exported_program is None:
                 # Run export() if it didn't run
                 self.export()
@@ -394,9 +415,12 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 return_value=False,
             )

+            # Prior to the full export, persist the changes to the pre-autograd
+            # graph module back to the source-of-truth ExportedProgram.
+            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.pre_autograd_graph_module,  # pyre-fixme[6]
+                    self.exported_program.module(),  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -441,9 +465,14 @@ def to_edge_transform_and_lower(
     ) -> "LLMEdgeManager":
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
+
+        # Persist the changes to the pre-autograd graph module back to the
+        # source-of-truth ExportedProgram by re-exporting.
+        self.export(self.pre_autograd_graph_module)
+
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.pre_autograd_exported_program,
+            self.exported_program,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
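
A rough usage sketch for this lowering path (again with placeholder names, `manager` and `example_partitioner`, that are not part of this diff): `to_edge_transform_and_lower()` now refreshes `self.exported_program` from the transformed pre-autograd graph module and hands that ExportedProgram, rather than a raw graph module, to the lowering API.

    # Sketch only: `manager` and `example_partitioner` are placeholders for an
    # already-constructed LLMEdgeManager and a backend partitioner instance.
    manager.export()
    # ... transform passes mutate manager.pre_autograd_graph_module here ...
    manager.to_edge_transform_and_lower(partitioners=[example_partitioner])
    # Internally, this re-exported the graph module into manager.exported_program
    # and lowered that ExportedProgram with the given partitioner.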