diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 36ca0b32dc5..39c249c5b3d 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -8,7 +8,7 @@
 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes
 
 from .utils import print_ops_info
@@ -256,14 +255,20 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)
 
+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(
+        expo_program, dump_graphs, constant_methods, core_aten_exceptions
+    )
 
     return edge_prog_manager
@@ -305,14 +310,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
@@ -323,14 +321,7 @@ def export_to_cadence(
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
@@ -367,15 +358,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
 
     # Print some information to terminal
     print_ops_info(
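For reviewers, a minimal sketch of how the reworked entry point would be driven after this change. The two-stage flow (torch.ops passes on the ExportedProgram, then EXIR passes on the edge program) follows the diff above; `MyModel`, the sample inputs, and the choice of exception op are illustrative assumptions, not part of this patch.

```python
# Hypothetical driver for the new pipeline; MyModel, the inputs, and the
# chosen exception op below are illustrative assumptions.
import torch

from executorch.backends.cadence.aot.compiler import export_to_edge


class MyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


# apply_torch_ops_passes() now runs on the ExportedProgram inside
# export_to_edge() before lowering, and core_aten_exceptions is forwarded
# to _lower_ep_to_edge() so the listed ops can be exempted from the
# core-ATen checks during lowering (per the parameter name).
edge_prog_manager = export_to_edge(
    MyModel(),
    (torch.randn(4, 8),),
    core_aten_exceptions=[torch.ops.aten.mul.Tensor],
)
```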
""" edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs) - cadence_passes = get_cadence_passes(opt_level) - - # Run a couple required passes for quant/dequant ops - cadence_prog_manager = edge_prog_manager.transform( - cast( - list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes - ) - ) + cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager) return cadence_prog_manager @@ -323,14 +321,7 @@ def export_to_cadence( opt_level: int = 1, ) -> EdgeProgramManager: edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs) - cadence_passes = get_cadence_passes(opt_level) - - # Run a couple required passes for quant/dequant ops - cadence_prog_manager = edge_prog_manager.transform( - cast( - list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes - ) - ) + cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager) return cadence_prog_manager @@ -367,15 +358,8 @@ def export_to_executorch_gen_etrecord( memory_config: Optional[MemoryConfig] = None, dump_graphs: bool = False, ) -> ExecutorchProgramManager: - cadence_passes = get_cadence_passes(opt_level) edge_prog_manager = export_to_edge(model, inputs, dump_graphs) - - # Run a couple required passes for quant/dequant ops - cadence_prog_manager = edge_prog_manager.transform( - cast( - list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes - ) - ) + cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager) # Print some information to terminal print_ops_info( diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index ef42f399943..16d4dbde32b 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph: FuseCascadedTransposeOrPermuteOps, FuseCascadedViewOps, FuseQuantDequantToRequantizePass, + FuseMulTensorIntoQuantPass, FuseMulTensorIntoDequantPass, FuseMulScalarIntoDequantPass, FuseFullThenReshapePass, diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index 8355f7ef432..7d1d15de827 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, List, Optional +from typing import Any, Callable, cast, List, Optional import torch import torch.fx @@ -28,13 +28,18 @@ RemoveRedundantOps, ) from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph -from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph +from executorch.backends.cadence.aot.replace_ops import ( + CadenceReplaceOpsInGraph, + ReplaceMulTensorWithMulAndFullOpsPass, +) from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph +from executorch.exir import EdgeProgramManager from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.pass_manager import PassManager, PassType from executorch.exir.passes import dead_code_elimination_pass from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass from executorch.exir.passes.spec_prop_pass import SpecPropPass +from torch.export.exported_program import ExportedProgram @register_cadence_pass(CadencePassAttribute(opt_level=0)) @@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]: return pytree.tree_flatten(passes)[0] -def get_cadence_passes( +def apply_exir_ops_passes( opt_level: int, -) -> List[Optional[PassResult]]: + edge_prog_manager: EdgeProgramManager, +) -> 
diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py
index 8355f7ef432..7d1d15de827 100644
--- a/backends/cadence/aot/passes.py
+++ b/backends/cadence/aot/passes.py
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Any, List, Optional
+from typing import Any, Callable, cast, List, Optional
 
 import torch
 import torch.fx
@@ -28,13 +28,18 @@
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
-from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
+from executorch.backends.cadence.aot.replace_ops import (
+    CadenceReplaceOpsInGraph,
+    ReplaceMulTensorWithMulAndFullOpsPass,
+)
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from torch.export.exported_program import ExportedProgram
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
     return pytree.tree_flatten(passes)[0]
 
 
-def get_cadence_passes(
+def apply_exir_ops_passes(
     opt_level: int,
-) -> List[Optional[PassResult]]:
+    edge_prog_manager: EdgeProgramManager,
+) -> EdgeProgramManager:
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
-    filtered_passes = [
-        # pyre-ignore[20]: Expect argument graph_module
-        filtered_pass()
+    cadence_passes = [
+        (
+            lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
+                graph_module
+            )
+        )
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
+    """
+    Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
+    expo_program is expected to be the output of torch.export.export().
+    """
+
+    aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
+        ReplaceMulTensorWithMulAndFullOpsPass()
+    ]
+    # TODO(T230417247): Use PassResult which is currently ignored.
+    PassManager(aten_passes)(expo_program.graph_module)
+    return expo_program
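One subtlety in the new apply_exir_ops_passes worth calling out: each lambda binds filtered_pass through a default argument (`filtered_pass=filtered_pass`). Without that, Python's late binding would make every closure built in the comprehension instantiate whichever pass the loop variable last pointed to. A self-contained demonstration of the pitfall the idiom avoids:

```python
# Late binding: all three closures share the loop variable i...
late = [lambda: i for i in range(3)]
# ...while a default argument snapshots i at definition time.
bound = [lambda i=i: i for i in range(3)]

assert [f() for f in late] == [2, 2, 2]
assert [f() for f in bound] == [0, 1, 2]
```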