88
99import logging
1010from pathlib import Path
11- from typing import Callable , cast , Optional
11+ from typing import Optional
1212
1313import executorch .backends .cadence .aot .ops_registrations # noqa
1414import torch
3232 ExecutorchBackendConfig ,
3333 ExecutorchProgramManager ,
3434)
35- from executorch .exir .pass_base import PassResult
3635from executorch .exir .passes import ToOutVarPass
3736from executorch .exir .passes .sym_shape_eval_pass import HintBasedSymShapeEvalPass
3837from executorch .exir .program ._program import to_edge_with_preserved_ops
4140from torch .export .exported_program import ExportedProgram
4241from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
4342
44- from .passes import get_cadence_passes
43+ from .passes import apply_exir_ops_passes , apply_torch_ops_passes
4544
4645from .utils import print_ops_info
4746
@@ -210,6 +209,21 @@ def quantize_pt2(
210209 return program
211210
212211
212+ TO_EDGE_OP_EXCEPTION_LIST : list [torch ._ops .OpOverload ] = [
213+ torch .ops .aten ._linalg_det .default ,
214+ torch .ops .aten ._linalg_svd .default ,
215+ torch .ops .aten ._native_batch_norm_legit_functional .default ,
216+ torch .ops .aten .linear .default ,
217+ torch .ops .aten .linalg_vector_norm .default ,
218+ torch .ops .aten .unfold .default ,
219+ torch .ops .aten .angle .default ,
220+ torch .ops .aten .rms_norm .default ,
221+ ]
222+ TO_EDGE_PRESERVE_OPS : tuple [torch ._ops .OpOverload , ...] = (
223+ torch .ops .aten .rms_norm .default ,
224+ )
225+
226+
213227def _lower_ep_to_edge (
214228 expo_program : ExportedProgram ,
215229 dump_graphs : bool = False ,
@@ -226,20 +240,11 @@ def _lower_ep_to_edge(
226240 compile_config = EdgeCompileConfig (
227241 _skip_dim_order = True ,
228242 # Allow specific non-core aten ops in the IR.
229- _core_aten_ops_exception_list = [
230- torch .ops .aten ._linalg_det .default ,
231- torch .ops .aten ._linalg_svd .default ,
232- torch .ops .aten ._native_batch_norm_legit_functional .default ,
233- torch .ops .aten .linear .default ,
234- torch .ops .aten .linalg_vector_norm .default ,
235- torch .ops .aten .unfold .default ,
236- torch .ops .aten .angle .default ,
237- torch .ops .aten .rms_norm .default ,
238- ]
243+ _core_aten_ops_exception_list = TO_EDGE_OP_EXCEPTION_LIST
239244 + (core_aten_exceptions or []),
240245 ),
241246 constant_methods = constant_methods ,
242- preserve_ops = ( torch . ops . aten . rms_norm . default ,) ,
247+ preserve_ops = TO_EDGE_PRESERVE_OPS ,
243248 )
244249
245250 if dump_graphs :
@@ -256,14 +261,20 @@ def export_to_edge(
256261 inputs : tuple [object , ...],
257262 dump_graphs : bool = False ,
258263 constant_methods : Optional [dict [str , object ]] = None ,
264+ core_aten_exceptions : Optional [list [torch ._ops .OpOverload ]] = None ,
259265) -> EdgeProgramManager :
260266 assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
261267
262268 # Export the model into an ExportedProgram.
263269 expo_program = trace (model , inputs )
264270
271+ # Apply passes which transform the ExportedProgram before it gets lowered to edge.
272+ expo_program = apply_torch_ops_passes (expo_program )
273+
265274 # Lower the model to edge IR.
266- edge_prog_manager = _lower_ep_to_edge (expo_program , dump_graphs , constant_methods )
275+ edge_prog_manager = _lower_ep_to_edge (
276+ expo_program , dump_graphs , constant_methods , core_aten_exceptions
277+ )
267278
268279 return edge_prog_manager
269280
@@ -305,14 +316,7 @@ def _lower_ep_to_cadence(
305316 Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
306317 """
307318 edge_prog_manager = _lower_ep_to_edge (program , dump_graphs = dump_graphs )
308- cadence_passes = get_cadence_passes (opt_level )
309-
310- # Run a couple required passes for quant/dequant ops
311- cadence_prog_manager = edge_prog_manager .transform (
312- cast (
313- list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
314- )
315- )
319+ cadence_prog_manager = apply_exir_ops_passes (opt_level , edge_prog_manager )
316320 return cadence_prog_manager
317321
318322
@@ -323,14 +327,7 @@ def export_to_cadence(
323327 opt_level : int = 1 ,
324328) -> EdgeProgramManager :
325329 edge_prog_manager = export_to_edge (model , inputs , dump_graphs = dump_graphs )
326- cadence_passes = get_cadence_passes (opt_level )
327-
328- # Run a couple required passes for quant/dequant ops
329- cadence_prog_manager = edge_prog_manager .transform (
330- cast (
331- list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
332- )
333- )
330+ cadence_prog_manager = apply_exir_ops_passes (opt_level , edge_prog_manager )
334331 return cadence_prog_manager
335332
336333
@@ -367,15 +364,8 @@ def export_to_executorch_gen_etrecord(
367364 memory_config : Optional [MemoryConfig ] = None ,
368365 dump_graphs : bool = False ,
369366) -> ExecutorchProgramManager :
370- cadence_passes = get_cadence_passes (opt_level )
371367 edge_prog_manager = export_to_edge (model , inputs , dump_graphs )
372-
373- # Run a couple required passes for quant/dequant ops
374- cadence_prog_manager = edge_prog_manager .transform (
375- cast (
376- list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
377- )
378- )
368+ cadence_prog_manager = apply_exir_ops_passes (opt_level , edge_prog_manager )
379369
380370 # Print some information to terminal
381371 print_ops_info (
0 commit comments