Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
from .remove_noop_pass import RemoveNoopPass # noqa
from .replace_scalar_with_tensor_pass import ( # noqa
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
from .rewrite_matmul import RewriteMatmulPass # noqa
Expand Down
9 changes: 4 additions & 5 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@
QuantizeOperatorArguments,
RemoveNoopPass,
ReplaceInfValues,
ReplaceScalarWithTensorArgPassTOSABI,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
RetraceFoldedDtypesPass,
RewriteConv2dPass,
RewriteMatmulPass,
Expand Down Expand Up @@ -174,7 +173,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(CastToInt32Pass())

self.add_pass(CastBoolToInt8Pass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(ConvertELUParamsPass())
Expand Down Expand Up @@ -244,7 +243,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeSinhPass())
self.add_pass(DecomposeSignPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
Expand Down Expand Up @@ -337,7 +336,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeAddmmPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeAddSubAlphaPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_acosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -33,7 +33,7 @@ class DecomposeAcoshPass(ArmPass):
DecomposeSqrtPass,
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
MatchArgDtypePass,
}

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -72,7 +72,7 @@ class DecomposeAsinAndAcosPass(ArmPass):
ConvertFullLikeToFullPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
}

def _build_polynomial(
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_asinh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -34,7 +34,7 @@ class DecomposeAsinhPass(ArmPass):
DecomposeSqrtPass,
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
MatchArgDtypePass,
}

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_atan_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -47,7 +47,7 @@ class DecomposeAtanPass(ArmPass):
InsertTableOpsPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
}

def _rational_approximation(self, z, ops, meta):
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_atanh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -43,7 +43,7 @@ class DecomposeAtanhPass(ArmPass):
InsertTableOpsPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
}

def call_operator(self, op, args, kwargs, meta):
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_cosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -31,7 +31,7 @@ class DecomposeCoshPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
MatchArgDtypePass,
}

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_expm1_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -83,7 +83,7 @@ class DecomposeExpm1Pass(ArmPass):
ConvertIntPowToMuls,
InsertTableOpsPass,
DecomposeDivPass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
MatchArgDtypePass,
MatchArgRanksPass,
}
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_logit_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand Down Expand Up @@ -73,7 +73,7 @@ class DecomposeLogitPass(ArmPass):
InsertTableOpsPass,
MatchArgRanksPass,
MatchArgDtypePass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
}

def call_operator(self, op, args, kwargs, meta):
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/decompose_sinh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
Expand All @@ -36,7 +36,7 @@ class DecomposeSinhPass(ArmPass):
_passes_required_after: Set[Type[ExportPass]] = {
InsertTableOpsPass,
MatchArgRanksPass,
ReplaceScalarWithTensorArgPassTOSAMI,
ReplaceScalarWithTensorByProfilePass,
MatchArgDtypePass,
}

Expand Down
56 changes: 44 additions & 12 deletions backends/arm/_passes/replace_scalar_with_tensor_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Dict, Set, Type, Union

import torch

from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.backends.transforms.replace_scalar_with_tensor import (
ReplaceScalarWithTensorArgPass,
)
Expand All @@ -17,6 +19,8 @@
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass

from .arm_pass import ArmPass


# Operators that are included for both TOSA profiles
_common_ops: Dict[
Expand Down Expand Up @@ -55,23 +59,51 @@
torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor,
}

_fp_profile_ops: Dict[
Union[EdgeOpOverload, torch._ops.OpOverload],
Union[EdgeOpOverload, torch._ops.OpOverload],
] = _common_ops | {
exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
}

class ReplaceScalarWithTensorArgPassTOSAMI(ReplaceScalarWithTensorArgPass):
_passes_required_after: Set[Type[ExportPass]] = set()
_int_profile_ops: Dict[
Union[EdgeOpOverload, torch._ops.OpOverload],
Union[EdgeOpOverload, torch._ops.OpOverload],
] = _common_ops

scalar_to_tensor_ops = _common_ops | {
exir_ops.edge.aten.pow.Tensor_Scalar: exir_ops.edge.aten.pow.Tensor_Tensor,
torch.ops.aten.pow.Tensor_Scalar: torch.ops.aten.pow.Tensor_Tensor,
}
_all_ops: Dict[
Union[EdgeOpOverload, torch._ops.OpOverload],
Union[EdgeOpOverload, torch._ops.OpOverload],
] = (
_fp_profile_ops | _int_profile_ops
)

def __init__(self):
super().__init__(self.scalar_to_tensor_ops)

class ReplaceScalarWithTensorByProfilePass(ReplaceScalarWithTensorArgPass, ArmPass):
"""Profile-aware scalar-to-tensor replacement pass for binary ops."""

class ReplaceScalarWithTensorArgPassTOSABI(ReplaceScalarWithTensorArgPass):
_passes_required_after: Set[Type[ExportPass]] = set()

scalar_to_tensor_ops = _common_ops

def __init__(self):
super().__init__(self.scalar_to_tensor_ops)
# Initialize base (ReplaceScalarWithTensorArgPass) with the full
# superset which will make the superclass handle ops in _all_ops.
# Actual selection is done per-call in call_operator.
super().__init__(_all_ops)

def call_operator(self, op, args, kwargs, meta):
tosa_spec = get_context_spec()

if tosa_spec.support_integer():
included_ops = _int_profile_ops
elif tosa_spec.support_float():
included_ops = _fp_profile_ops
else:
raise ValueError("Profile must support either INT or FP")

if op in included_ops:
# Include this op based on the current profile.
return super().call_operator(op, args, kwargs, meta)
else:
# Do not handle; forward unchanged.
return ExportPass.call_operator(self, op, args, kwargs, meta)
Loading