diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index de9a793b9aa..55daf92a5a9 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -88,8 +88,7 @@ from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa from .replace_scalar_with_tensor_pass import ( # noqa - ReplaceScalarWithTensorArgPassTOSABI, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from .rewrite_conv2d_pass import RewriteConv2dPass # noqa from .rewrite_matmul import RewriteMatmulPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b579d910752..b491b445cc3 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -89,8 +89,7 @@ QuantizeOperatorArguments, RemoveNoopPass, ReplaceInfValues, - ReplaceScalarWithTensorArgPassTOSABI, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, RetraceFoldedDtypesPass, RewriteConv2dPass, RewriteMatmulPass, @@ -174,7 +173,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(CastToInt32Pass()) self.add_pass(CastBoolToInt8Pass()) - self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) + self.add_pass(ReplaceScalarWithTensorByProfilePass()) self.add_pass(AnnotateDecomposedMatmulPass()) self.add_pass(QuantizeOperatorArguments()) self.add_pass(ConvertELUParamsPass()) @@ -244,7 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(DecomposeSinhPass()) self.add_pass(DecomposeSignPass()) self.add_pass(DecomposeDivTensorModePass()) - self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) + self.add_pass(ReplaceScalarWithTensorByProfilePass()) self.add_pass(DecomposeEmbeddingPass()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) @@ -337,7 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeAddmmPass()) self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeAddSubAlphaPass()) - self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) + self.add_pass(ReplaceScalarWithTensorByProfilePass()) self.add_pass(ScalarsToAttributePass()) self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) diff --git a/backends/arm/_passes/decompose_acosh_pass.py b/backends/arm/_passes/decompose_acosh_pass.py index 509849fce4e..1d7929e11ed 100644 --- a/backends/arm/_passes/decompose_acosh_pass.py +++ b/backends/arm/_passes/decompose_acosh_pass.py @@ -13,7 +13,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -33,7 +33,7 @@ class DecomposeAcoshPass(ArmPass): DecomposeSqrtPass, InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } diff --git a/backends/arm/_passes/decompose_asin_and_acos_pass.py b/backends/arm/_passes/decompose_asin_and_acos_pass.py index 5b1c575e9c9..b29833190f2 100644 --- a/backends/arm/_passes/decompose_asin_and_acos_pass.py +++ b/backends/arm/_passes/decompose_asin_and_acos_pass.py @@ -20,7 +20,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -72,7 +72,7 @@ class DecomposeAsinAndAcosPass(ArmPass): ConvertFullLikeToFullPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, } def _build_polynomial( diff --git a/backends/arm/_passes/decompose_asinh_pass.py b/backends/arm/_passes/decompose_asinh_pass.py index 088230ca4b2..8b59b50fca8 100644 --- a/backends/arm/_passes/decompose_asinh_pass.py +++ b/backends/arm/_passes/decompose_asinh_pass.py @@ -14,7 +14,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -34,7 +34,7 @@ class DecomposeAsinhPass(ArmPass): DecomposeSqrtPass, InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } diff --git a/backends/arm/_passes/decompose_atan_pass.py b/backends/arm/_passes/decompose_atan_pass.py index 03ed62e7870..6f1adccd257 100644 --- a/backends/arm/_passes/decompose_atan_pass.py +++ b/backends/arm/_passes/decompose_atan_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass): InsertTableOpsPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, } def _rational_approximation(self, z, ops, meta): diff --git a/backends/arm/_passes/decompose_atanh_pass.py b/backends/arm/_passes/decompose_atanh_pass.py index 2c8347e7e9f..1a41e77eacc 100644 --- a/backends/arm/_passes/decompose_atanh_pass.py +++ b/backends/arm/_passes/decompose_atanh_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass): InsertTableOpsPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, } def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_cosh_pass.py b/backends/arm/_passes/decompose_cosh_pass.py index cbfbd5783e2..6716ba499ad 100644 --- a/backends/arm/_passes/decompose_cosh_pass.py +++ b/backends/arm/_passes/decompose_cosh_pass.py @@ -10,7 +10,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } diff --git a/backends/arm/_passes/decompose_expm1_pass.py b/backends/arm/_passes/decompose_expm1_pass.py index 5de03cbf102..0fe95d37ba2 100644 --- a/backends/arm/_passes/decompose_expm1_pass.py +++ b/backends/arm/_passes/decompose_expm1_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass): ConvertIntPowToMuls, InsertTableOpsPass, DecomposeDivPass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, MatchArgRanksPass, } diff --git a/backends/arm/_passes/decompose_logit_pass.py b/backends/arm/_passes/decompose_logit_pass.py index 213b8f038e8..69a250b41cb 100644 --- a/backends/arm/_passes/decompose_logit_pass.py +++ b/backends/arm/_passes/decompose_logit_pass.py @@ -12,7 +12,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -73,7 +73,7 @@ class DecomposeLogitPass(ArmPass): InsertTableOpsPass, MatchArgRanksPass, MatchArgDtypePass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, } def call_operator(self, op, args, kwargs, meta): diff --git a/backends/arm/_passes/decompose_sinh_pass.py b/backends/arm/_passes/decompose_sinh_pass.py index acb18df3134..772cc7c4741 100644 --- a/backends/arm/_passes/decompose_sinh_pass.py +++ b/backends/arm/_passes/decompose_sinh_pass.py @@ -11,7 +11,7 @@ from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -36,7 +36,7 @@ class DecomposeSinhPass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = { InsertTableOpsPass, MatchArgRanksPass, - ReplaceScalarWithTensorArgPassTOSAMI, + ReplaceScalarWithTensorByProfilePass, MatchArgDtypePass, } diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index f6ef056f677..e1eee568d39 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -9,6 +9,8 @@ from typing import Dict, Set, Type, Union import torch + +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, ) @@ -17,6 +19,8 @@ from executorch.exir.dialects.edge._ops import EdgeOpOverload from executorch.exir.pass_base import ExportPass +from .arm_pass import ArmPass + # Operators that are included for both TOSA profiles _common_ops: Dict[ @@ -55,23 +59,51 @@ torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor, } +_fp_profile_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = _common_ops | { + exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, + torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, +} -class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass): - _passes_required_after: Set[Type[ExportPass]] = set() +_int_profile_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = _common_ops - scalar_to_tensor_ops = _common_ops | { - exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor, - torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor, - } +_all_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = ( + _fp_profile_ops | _int_profile_ops +) - def __init__(self): - super().__init__(self.scalar_to_tensor_ops) +class ReplaceScalarWithTensorByProfilePass(ReplaceScalarWithTensorArgPass, ArmPass): + """Profile-aware scalar-to-tensor replacement pass for binary ops.""" -class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass): _passes_required_after: Set[Type[ExportPass]] = set() - scalar_to_tensor_ops = _common_ops - def __init__(self): - super().__init__(self.scalar_to_tensor_ops) + # Initialize base (ReplaceScalarWithTensorArgPass) with the full + # superset which will make the superclass handle ops in _all_ops. + # Actual selection is done per-call in call_operator. + super().__init__(_all_ops) + + def call_operator(self, op, args, kwargs, meta): + tosa_spec = get_context_spec() + + if tosa_spec.support_integer(): + included_ops = _int_profile_ops + elif tosa_spec.support_float(): + included_ops = _fp_profile_ops + else: + raise ValueError("Profile must support either INT or FP") + + if op in included_ops: + # Include this op based on the current profile. + return super().call_operator(op, args, kwargs, meta) + else: + # Do not handle; forward unchanged. + return ExportPass.call_operator(self, op, args, kwargs, meta)