88
99import logging
1010from pathlib import Path
11- from typing import Callable , cast , Optional
11+ from typing import Optional
1212
1313import executorch .backends .cadence .aot .ops_registrations # noqa
1414import torch
3232 ExecutorchBackendConfig ,
3333 ExecutorchProgramManager ,
3434)
35- from executorch .exir .pass_base import PassResult
3635from executorch .exir .passes import ToOutVarPass
3736from executorch .exir .passes .sym_shape_eval_pass import HintBasedSymShapeEvalPass
3837from executorch .exir .program ._program import to_edge_with_preserved_ops
4140from torch .export .exported_program import ExportedProgram
4241from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
4342
44- from .passes import get_cadence_passes
43+ from .passes import apply_exir_ops_passes , apply_torch_ops_passes
4544
4645from .utils import print_ops_info
4746
@@ -262,14 +261,20 @@ def export_to_edge(
262261 inputs : tuple [object , ...],
263262 dump_graphs : bool = False ,
264263 constant_methods : Optional [dict [str , object ]] = None ,
264+ core_aten_exceptions : Optional [list [torch ._ops .OpOverload ]] = None ,
265265) -> EdgeProgramManager :
266266 assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
267267
268268 # Export the model into an ExportedProgram.
269269 expo_program = trace (model , inputs )
270270
271+ # Apply passes which transform the ExportedProgram before it gets lowered to edge.
272+ expo_program = apply_torch_ops_passes (expo_program )
273+
271274 # Lower the model to edge IR.
272- edge_prog_manager = _lower_ep_to_edge (expo_program , dump_graphs , constant_methods )
275+ edge_prog_manager = _lower_ep_to_edge (
276+ expo_program , dump_graphs , constant_methods , core_aten_exceptions
277+ )
273278
274279 return edge_prog_manager
275280
@@ -311,14 +316,7 @@ def _lower_ep_to_cadence(
311316 Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
312317 """
313318 edge_prog_manager = _lower_ep_to_edge (program , dump_graphs = dump_graphs )
314- cadence_passes = get_cadence_passes (opt_level )
315-
316- # Run a couple required passes for quant/dequant ops
317- cadence_prog_manager = edge_prog_manager .transform (
318- cast (
319- list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
320- )
321- )
319+ cadence_prog_manager = apply_exir_ops_passes (opt_level , edge_prog_manager )
322320 return cadence_prog_manager
323321
324322
@@ -329,14 +327,7 @@ def export_to_cadence(
329327 opt_level : int = 1 ,
330328) -> EdgeProgramManager :
331329 edge_prog_manager = export_to_edge (model , inputs , dump_graphs = dump_graphs )
332- cadence_passes = get_cadence_passes (opt_level )
333-
334- # Run a couple required passes for quant/dequant ops
335- cadence_prog_manager = edge_prog_manager .transform (
336- cast (
337- list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
338- )
339- )
330+ cadence_prog_manager = apply_exir_ops_passes (opt_level , edge_prog_manager )
340331 return cadence_prog_manager
341332
342333
@@ -373,15 +364,8 @@ def export_to_executorch_gen_etrecord(
373364 memory_config : Optional [MemoryConfig ] = None ,
374365 dump_graphs : bool = False ,
375366) -> ExecutorchProgramManager :
376- cadence_passes = get_cadence_passes (opt_level )
377367 edge_prog_manager = export_to_edge (model , inputs , dump_graphs )
378-
379- # Run a couple required passes for quant/dequant ops
380- cadence_prog_manager = edge_prog_manager .transform (
381- cast (
382- list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
383- )
384- )
368+ cadence_prog_manager = apply_exir_ops_passes (opt_level , edge_prog_manager )
385369
386370 # Print some information to terminal
387371 print_ops_info (
0 commit comments