@@ -96,7 +96,6 @@ def __init__(
9696 verbose : bool = False ,
9797 metadata : Optional [dict ] = None ,
9898 dynamic_shapes : Optional [Any ] = None ,
99- use_legacy_export : bool = False ,
10099 save_exported_program : bool = False ,
101100 ):
102101 # Store necessary constructor arguments.
@@ -117,7 +116,6 @@ def __init__(
117116 self .verbose = verbose
118117 self .metadata = metadata
119118 self .dynamic_shapes = dynamic_shapes
120- self .use_legacy_export = use_legacy_export
121119 self .save_exported_program = save_exported_program
122120
123121 # Note: treat this as the source of truth for the result of
@@ -228,39 +226,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
228226 # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
229227 # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
230228 with torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
231- if self .use_legacy_export :
232- # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
233- # See issue: https://github.com/pytorch/executorch/issues/7373
234-
235- with patch .object (
236- torch ._utils_internal ,
237- "export_training_ir_rollout_check" ,
238- return_value = False ,
239- ):
240- # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
241- # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
242- exported_module = torch .export .export (
243- self .model if not module else module ,
244- self .example_inputs ,
245- self .example_kwarg_inputs ,
246- dynamic_shapes = dynamic_shape ,
247- strict = True ,
248- )
229+ if module :
230+ logging .info ("Re-exporting with:" )
249231 else :
250- if module :
251- logging .info ("Re-exporting with:" )
252- else :
253- logging .info ("Exporting with:" )
254- logging .info (f"inputs: { self .example_inputs } " )
255- logging .info (f"kwargs: { self .example_kwarg_inputs } " )
256- logging .info (f"dynamic shapes: { dynamic_shape } " )
257- exported_module = export_for_training (
258- self .model if not module else module ,
259- self .example_inputs ,
260- kwargs = self .example_kwarg_inputs ,
261- dynamic_shapes = dynamic_shape ,
262- strict = True ,
263- )
232+ logging .info ("Exporting with:" )
233+ logging .info (f"inputs: { self .example_inputs } " )
234+ logging .info (f"kwargs: { self .example_kwarg_inputs } " )
235+ logging .info (f"dynamic shapes: { dynamic_shape } " )
236+ exported_module = export_for_training (
237+ self .model if not module else module ,
238+ self .example_inputs ,
239+ kwargs = self .example_kwarg_inputs ,
240+ dynamic_shapes = dynamic_shape ,
241+ strict = True ,
242+ )
264243 return exported_module
265244
266245 def export (self ) -> "LLMEdgeManager" :
@@ -446,13 +425,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
446425 self .export ()
447426
448427 override_export_behaviour = contextlib .nullcontext ()
449- if self .use_legacy_export :
450- override_export_behaviour = patch .object (
451- torch ._utils_internal ,
452- "export_training_ir_rollout_check" ,
453- return_value = False ,
454- )
455-
456428 with override_export_behaviour :
457429 self .edge_manager = export_to_edge (
458430 self .pre_autograd_graph_module , # pyre-fixme[6]
0 commit comments