diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 9157f7d6a4b..5ebcf8b43b1 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -159,7 +159,12 @@ def _transform(self, graph_module: GraphModule): def _tosa_pipeline( self, exported_program: ExportedProgram, graph_module: GraphModule ) -> GraphModule: + # Preprocessing passes + self.add_pass(AnnotateOutputDimOrderPass()) + + # Node transformation passes (pre q/dq folding) + self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) self.add_pass(ConvertToClampPass()) @@ -174,8 +179,19 @@ def _tosa_pipeline( self.add_pass(ConvertELUParamsPass()) self.add_pass(ConvertSplitToSlicePass()) self.add_pass(QuantizeOperatorArguments()) + + # Fold Q/DQ nodes, insert INT8/INT32 rescales. + self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg] self.add_pass(FuseDuplicateUsersPass()) + # TODO: DecomposeLinearPass should run after InsertRescaleInt32Pass or + # before FoldAndAnnotateQParamsPass but is unable to at the moment. + # Ticket: MLETORCH-1539 + self.add_pass(DecomposeLinearPass()) + self.add_pass(InsertRescaleInt32Pass()) + + # Node transformation passes (post q/dq folding) + self.add_pass(DecomposeExpm1Pass()) self.add_pass(DecomposeLogitPass()) self.add_pass(DecomposeMaskedFill()) @@ -196,57 +212,67 @@ def _tosa_pipeline( self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeFloorDividePass()) self.add_pass(DecomposeDivTensorModePass()) + self.add_pass(DecomposeGeluPass()) + self.add_pass(DecomposeAddSubAlphaPass()) + self.add_pass(DecomposeGroupedConv()) + self.add_pass(Conv1dUnsqueezePass()) + + # Scalars -> tensors, match tensor dtypes and ranks. + self.add_pass(ReplaceScalarWithTensorByProfilePass()) + self.add_pass(ConvertFullLikeToFullPass()) + self.add_pass(MatchArgDtypePass()) + self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) + # TODO: Move DecomposeNotEqualPass to before or after this block of + # passes. Ticket: MLETORCH-1540 + self.add_pass(DecomposeNotEqualPass()) + self.add_pass(MatchArgRanksPass(exported_program)) + self.add_pass(FuseConstantArgsPass(exported_program)) + + # Node transformation passes (post scalar-removal) + self.add_pass(DecomposeRemainderPass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseBatchnorm2DPass(exported_program)) self.add_pass(ConvertMmToBmmPass()) self.add_pass(DecomposeGluPass()) - self.add_pass(DecomposeLinearPass()) self.add_pass(DecomposeLeakyReLUPass()) - self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeDivPass()) - self.add_pass(DecomposeAddSubAlphaPass()) self.add_pass(DecomposeSoftmaxPass()) - self.add_pass(DecomposeGeluPass()) - self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertMinMaxPass()) self.add_pass(ConvertAnyDefaultDimDimsPass()) - self.add_pass(MatchArgDtypePass()) - self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) - self.add_pass(MatchArgRanksPass(exported_program)) self.add_pass(DecomposeAdaptiveAvgPool2dPass()) self.add_pass(DecomposeAvgPool2d()) self.add_pass( DecorateFp32toInt32CastingPass() ) # Require that no new fp32->int32 is introduced after this pass self.add_pass(ComputeConstantOpsAOT(exported_program)) - - self.add_pass(DecomposeGroupedConv()) self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(DecomposeCumsumPass(exported_program)) - self.add_pass(Conv1dUnsqueezePass()) self.add_pass(DecomposeMaxPool2DPass()) self.add_pass(SizeAdjustInputPass()) self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) self.add_pass(CastToInt32Pass()) self.add_pass(BroadcastArgsPass()) - self.add_pass(ConvertPermuteSingletonToViewPass()) self.add_pass(FuseViewCopyTransform()) - self.add_pass(FuseConstantArgsPass(exported_program)) self.add_pass(DecomposeConv2dWithInt16ActivationPass()) - self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) + self.add_pass(DecomposeSumPass()) self.add_pass(InsertTableOpsPass(exported_program)) + + # Aten -> TOSA transformation passes + self.add_pass(RewriteUpsamplePass()) self.add_pass(RewriteConv2dPass(exported_program)) self.add_pass(RewriteMatmulPass()) + + # Postprocessing/cleanup passes + + self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) - self.add_pass(InsertRescaleInt32Pass()) - self.add_pass(DecomposeSumPass()) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) diff --git a/backends/arm/_passes/decompose_linear_pass.py b/backends/arm/_passes/decompose_linear_pass.py index ffe63f8cb65..a04a87775a3 100644 --- a/backends/arm/_passes/decompose_linear_pass.py +++ b/backends/arm/_passes/decompose_linear_pass.py @@ -12,6 +12,7 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm._passes.insert_rescales_pass import InsertRescaleInt32Pass from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -26,7 +27,7 @@ class DecomposeLinearPass(ArmPass): output = view(conv2d) """ - _passes_required_after: Set[Type[ExportPass]] = set() + _passes_required_after: Set[Type[ExportPass]] = {InsertRescaleInt32Pass} def call(self, graph_module): for node in graph_module.graph.nodes: diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index c09df48f7be..1a833daf1de 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -57,6 +57,7 @@ def __init__(self, exported_program: ExportedProgram) -> None: exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, + exir_ops.edge.aten.remainder.Tensor, exir_ops.edge.aten.where.self, exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor,