1414import logging
1515from enum import Enum
1616from typing import Any , Callable , Dict , List , Optional , Tuple
17- from unittest .mock import patch
1817
1918import torch
2019from executorch .backends .transforms .duplicate_dynamic_quant_chain import (
@@ -96,7 +95,6 @@ def __init__(
9695 verbose : bool = False ,
9796 metadata : Optional [dict ] = None ,
9897 dynamic_shapes : Optional [Any ] = None ,
99- use_legacy_export : bool = False ,
10098 save_exported_program : bool = False ,
10199 ):
102100 # Store necessary constructor arguments.
@@ -117,7 +115,6 @@ def __init__(
117115 self .verbose = verbose
118116 self .metadata = metadata
119117 self .dynamic_shapes = dynamic_shapes
120- self .use_legacy_export = use_legacy_export
121118 self .save_exported_program = save_exported_program
122119
123120 # Note: treat this as the source of truth for the result of
@@ -228,39 +225,20 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
228225 # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
229226 # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
230227 with torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
231- if self .use_legacy_export :
232- # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
233- # See issue: https://github.com/pytorch/executorch/issues/7373
234-
235- with patch .object (
236- torch ._utils_internal ,
237- "export_training_ir_rollout_check" ,
238- return_value = False ,
239- ):
240- # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
241- # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
242- exported_module = torch .export .export (
243- self .model if not module else module ,
244- self .example_inputs ,
245- self .example_kwarg_inputs ,
246- dynamic_shapes = dynamic_shape ,
247- strict = True ,
248- )
228+ if module :
229+ logging .info ("Re-exporting with:" )
249230 else :
250- if module :
251- logging .info ("Re-exporting with:" )
252- else :
253- logging .info ("Exporting with:" )
254- logging .info (f"inputs: { self .example_inputs } " )
255- logging .info (f"kwargs: { self .example_kwarg_inputs } " )
256- logging .info (f"dynamic shapes: { dynamic_shape } " )
257- exported_module = export_for_training (
258- self .model if not module else module ,
259- self .example_inputs ,
260- kwargs = self .example_kwarg_inputs ,
261- dynamic_shapes = dynamic_shape ,
262- strict = True ,
263- )
231+ logging .info ("Exporting with:" )
232+ logging .info (f"inputs: { self .example_inputs } " )
233+ logging .info (f"kwargs: { self .example_kwarg_inputs } " )
234+ logging .info (f"dynamic shapes: { dynamic_shape } " )
235+ exported_module = export_for_training (
236+ self .model if not module else module ,
237+ self .example_inputs ,
238+ kwargs = self .example_kwarg_inputs ,
239+ dynamic_shapes = dynamic_shape ,
240+ strict = True ,
241+ )
264242 return exported_module
265243
266244 def export (self ) -> "LLMEdgeManager" :
@@ -446,13 +424,6 @@ def export_to_edge(self) -> "LLMEdgeManager":
446424 self .export ()
447425
448426 override_export_behaviour = contextlib .nullcontext ()
449- if self .use_legacy_export :
450- override_export_behaviour = patch .object (
451- torch ._utils_internal ,
452- "export_training_ir_rollout_check" ,
453- return_value = False ,
454- )
455-
456427 with override_export_behaviour :
457428 self .edge_manager = export_to_edge (
458429 self .pre_autograd_graph_module , # pyre-fixme[6]
0 commit comments