@@ -89,9 +89,10 @@ def __init__(
         dynamic_shapes: Optional[Any] = None,
     ):
         self.model = model
-        self.exported_program: Optional[ExportedProgram] = None
-        # Self.exported_program's pre-autograd graph module, for running
-        # transform passes on the graph prior to torch.export().
+        # Note: treat this as the source of truth for the result of
+        # torch.export'ing a model. If the overall ExportedProgram is needed,
+        # make sure to re-export this graph module to persist any changes. See
+        # https://github.com/pytorch/pytorch/blob/main/torch/export/exported_program.py#L921
         self.pre_autograd_graph_module: Optional[torch.nn.Module] = None
         self.modelname = modelname
         self.max_seq_len = max_seq_len
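The caveat in the new comment reflects standard torch.export behavior: mutating the nn.Module returned by ExportedProgram.module() does not update the ExportedProgram it came from, so a re-export is needed to persist graph changes. A minimal standalone sketch of that pattern (the Tiny module and inputs are illustrative, not part of this diff):

    import torch

    class Tiny(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    ep = torch.export.export(Tiny(), (torch.randn(4),))
    gm = ep.module()  # runnable nn.Module view of the exported graph
    # ...graph transformation passes would mutate `gm` in place here...
    ep_fresh = torch.export.export(gm, (torch.randn(4),))  # persists the changes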
@@ -186,21 +187,7 @@ def _get_edge_config(self) -> EdgeCompileConfig:
         )
         return edge_config
 
-    def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
-        """
-        Exports the model pre-autograd. This is not a full export, since it uses
-        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
-        The full torch.export() is called later on during to_edge() or
-        to_edge_transform_and_lower().
-
-        The optional `module` argument is included so that the user can re-export
-        an already-exported module's ExportedProgram's graph module, to persist
-        the changes into a new ExportedProgram.
-
-        Args:
-            module (Optional[torch.nn.Module]): module to export.
-
-        """
+    def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         dynamic_shape = self._get_dynamic_shape()
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
@@ -237,14 +224,22 @@ def export(self, module: Optional[torch.nn.Module] = None) -> "LLMEdgeManager":
                     kwargs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
                 )
-            self.exported_program = exported_module
-            # Need to store the graph module to record transformation passes.
-            # Persisting those changes back to the ExportedProgram will require
-            # an additional export().
-            self.pre_autograd_graph_module = exported_module.module()
-            if hasattr(self.args, "export_only") and self.args.export_only:
-                torch.export.save(exported_module, self.args.output_name)
+        return exported_module
 
+    def export(self) -> "LLMEdgeManager":
+        """
+        Exports the model pre-autograd. This is not a full export, since it uses
+        torch.export_for_training() to keep autograd-safe ops from getting decomposed.
+        The full torch.export() is called later on during to_edge() or
+        to_edge_transform_and_lower().
+        """
+        exported_module = self._export()
+        # Need to store the graph module to record transformation passes.
+        # Persisting those changes back to an ExportedProgram will require
+        # an additional export().
+        self.pre_autograd_graph_module = exported_module.module()
+        if hasattr(self.args, "export_only") and self.args.export_only:
+            torch.export.save(exported_module, self.args.output_name)
         return self
 
     def run_canonical_optimizations(self):
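Taken together, the split means callers now drive transformations through the stored graph module rather than a saved ExportedProgram. A rough caller-side sketch (the helper name and the bare partitioners=None call are illustrative, not part of this diff):

    def lower_with_transforms(manager):
        # `manager` is assumed to be an already-constructed LLMEdgeManager.
        manager.export()  # torch.export_for_training() under the hood; stores the graph module
        gm = manager.pre_autograd_graph_module
        # ...in-place FX graph passes on `gm` would go here...
        # to_edge_transform_and_lower() later re-exports `gm` (via _export()) so
        # the passes are persisted into a fresh ExportedProgram before lowering.
        manager.to_edge_transform_and_lower(partitioners=None)
        return manager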
@@ -403,7 +398,7 @@ def export_to_edge(self) -> "LLMEdgeManager":
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if self.exported_program is None:
+            if self.pre_autograd_graph_module is None:
                 # Run export() if it didn't run
                 self.export()
 
@@ -415,12 +410,9 @@ def export_to_edge(self) -> "LLMEdgeManager":
                     return_value=False,
                 )
 
-            # Prior to export, persist the changes to the pre autograd
-            # graph module back to the source-of-truth ExportedProgram.
-            self.export(self.pre_autograd_graph_module)
             with override_export_behaviour:
                 self.edge_manager = export_to_edge(
-                    self.exported_program.module(),  # pyre-fixme[6]
+                    self.pre_autograd_graph_module,  # pyre-fixme[6]
                     self.example_inputs,
                     example_kwarg_inputs=self.example_kwarg_inputs,
                     dynamic_shapes=dynamic_shape,
@@ -466,13 +458,12 @@ def to_edge_transform_and_lower(
         if partitioners is None:
             logging.info("No partitioner provided, skipping backend lowering...")
 
-        # Prior to export, persist the changes to the pre autograd
-        # graph module back to the source-of-truth ExportedProgram.
-        self.export(self.pre_autograd_graph_module)
+        # Need to construct ExportedProgram with the new transformed graph module.
+        exported_module = self._export(self.pre_autograd_graph_module)
 
         edge_config = self._get_edge_config()
         self.edge_manager = to_edge_transform_and_lower(
-            self.exported_program,
+            exported_module,
             partitioner=partitioners,
             compile_config=edge_config,
             constant_methods=self.metadata,
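Functionally, the lowering path now amounts to re-exporting the transformed graph module and handing the fresh ExportedProgram to the lowering entry point. A sketch of that shape under assumed imports (the executorch.exir entry point and the helper name/parameters are illustrative and may differ between versions):

    from executorch.exir import to_edge_transform_and_lower
    from torch.export import export_for_training

    def lower_transformed_graph(gm, example_inputs, dynamic_shapes,
                                edge_config, metadata, partitioners=None):
        # Re-export the (possibly transformed) pre-autograd graph module so the
        # changes land in a fresh ExportedProgram, then lower that program.
        fresh_ep = export_for_training(gm, example_inputs, dynamic_shapes=dynamic_shapes)
        return to_edge_transform_and_lower(
            fresh_ep,
            partitioner=partitioners,
            compile_config=edge_config,
            constant_methods=metadata,
        )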