@@ -10,9 +10,11 @@

 # pyre-unsafe

+import contextlib
 import logging
 from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple
+from unittest.mock import patch

 import torch
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
@@ -94,6 +96,7 @@ def __init__(
         verbose: bool = False,
         metadata: Optional[dict] = None,
         dynamic_shapes: Optional[Any] = None,
+        use_legacy_export: bool = False,
         save_exported_program: bool = False,
     ):
         # Store necessary constructor arguments.
@@ -114,6 +117,7 @@ def __init__(
114117 self .verbose = verbose
115118 self .metadata = metadata
116119 self .dynamic_shapes = dynamic_shapes
120+ self .use_legacy_export = use_legacy_export
117121 self .save_exported_program = save_exported_program
118122
119123 # Note: treat this as the source of truth for the result of
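The flag defaults to False, so existing call sites keep the current export behaviour; only backends that still need the legacy, functional export IR (e.g. QNN) opt in. A minimal sketch of the plumbing, using a hypothetical stand-in class rather than the real LLMEdgeManager constructor:

# Hypothetical stand-in: mirrors how the diff threads the flag through
# __init__ and merely stores it for later branching in _export().
class _ManagerSketch:
    def __init__(self, use_legacy_export: bool = False):
        self.use_legacy_export = use_legacy_export

assert _ManagerSketch().use_legacy_export is False               # default path
assert _ManagerSketch(use_legacy_export=True).use_legacy_export  # legacy path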
@@ -225,20 +229,39 @@ def _export(self, module: Optional[torch.nn.Module] = None) -> ExportedProgram:
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            if module:
-                logging.info("Re-exporting with:")
+            if self.use_legacy_export:
+                # TODO: for use cases such as qnn, which does not work with new, non-functional export IR.
+                # See issue: https://github.com/pytorch/executorch/issues/7373
+
+                with patch.object(
+                    torch._utils_internal,
+                    "export_training_ir_rollout_check",
+                    return_value=False,
+                ):
+                    # TODO: this is temporary and export_for_training doesn't work with qnn either. We need a
+                    # functional graph. See issue https://github.com/pytorch/executorch/pull/4627 for more details
+                    exported_module = torch.export.export(
+                        self.model if not module else module,
+                        self.example_inputs,
+                        self.example_kwarg_inputs,
+                        dynamic_shapes=dynamic_shape,
+                        strict=True,
+                    )
             else:
-                logging.info("Exporting with:")
-            logging.info(f"inputs: {self.example_inputs}")
-            logging.info(f"kwargs: {self.example_kwarg_inputs}")
-            logging.info(f"dynamic shapes: {dynamic_shape}")
-            exported_module = export_for_training(
-                self.model if not module else module,
-                self.example_inputs,
-                kwargs=self.example_kwarg_inputs,
-                dynamic_shapes=dynamic_shape,
-                strict=True,
-            )
+                if module:
+                    logging.info("Re-exporting with:")
+                else:
+                    logging.info("Exporting with:")
+                logging.info(f"inputs: {self.example_inputs}")
+                logging.info(f"kwargs: {self.example_kwarg_inputs}")
+                logging.info(f"dynamic shapes: {dynamic_shape}")
+                exported_module = export_for_training(
+                    self.model if not module else module,
+                    self.example_inputs,
+                    kwargs=self.example_kwarg_inputs,
+                    dynamic_shapes=dynamic_shape,
+                    strict=True,
+                )
         return exported_module

     def export(self) -> "LLMEdgeManager":
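A note on the mechanism used above: patch.object from unittest.mock swaps out an attribute for the duration of a with block and restores it on exit; passing return_value=False makes the replacement a mock that returns False. A self-contained sketch of the same trick (the config namespace and feature_enabled function are invented for illustration; in the diff the patched attribute is torch._utils_internal.export_training_ir_rollout_check):

import types
from unittest.mock import patch

# Invented stand-in for a module whose check we want to silence.
config = types.SimpleNamespace(feature_enabled=lambda: True)

assert config.feature_enabled() is True

# Inside the block, lookups of config.feature_enabled hit a mock that
# returns False -- the same way the diff steers torch.export away from
# the training IR while exporting for QNN.
with patch.object(config, "feature_enabled", return_value=False):
    assert config.feature_enabled() is False

# The original attribute is restored once the block exits.
assert config.feature_enabled() is True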
@@ -423,15 +446,24 @@ def export_to_edge(self) -> "LLMEdgeManager":
                 # Run export() if it didn't run
                 self.export()

-            self.edge_manager = export_to_edge(
-                self.pre_autograd_graph_module,  # pyre-fixme[6]
-                self.example_inputs,
-                example_kwarg_inputs=self.example_kwarg_inputs,
-                dynamic_shapes=dynamic_shape,
-                edge_constant_methods=self.metadata,
-                edge_compile_config=edge_config,
-                verbose=self.verbose,
-            )
+            override_export_behaviour = contextlib.nullcontext()
+            if self.use_legacy_export:
+                override_export_behaviour = patch.object(
+                    torch._utils_internal,
+                    "export_training_ir_rollout_check",
+                    return_value=False,
+                )
+
+            with override_export_behaviour:
+                self.edge_manager = export_to_edge(
+                    self.pre_autograd_graph_module,  # pyre-fixme[6]
+                    self.example_inputs,
+                    example_kwarg_inputs=self.example_kwarg_inputs,
+                    dynamic_shapes=dynamic_shape,
+                    edge_constant_methods=self.metadata,
+                    edge_compile_config=edge_config,
+                    verbose=self.verbose,
+                )
         return self

     def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
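The contextlib.nullcontext() assignment above is the standard idiom for a conditional context manager: the with statement is written once, and the flag only decides whether it wraps a real override or a no-op. A minimal runnable sketch of the pattern (names are illustrative):

import contextlib

@contextlib.contextmanager
def override():
    print("override active")
    try:
        yield
    finally:
        print("override restored")

def run(use_override: bool = False):
    ctx = contextlib.nullcontext()  # no-op unless the flag is set
    if use_override:
        ctx = override()
    with ctx:  # single code path either way
        print("doing work")

run()                   # -> doing work
run(use_override=True)  # -> override active / doing work / override restored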