1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -37,6 +37,7 @@
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
from .decompose_fmod_pass import DecomposeFmodPass # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
12 changes: 8 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
@@ -42,6 +42,7 @@
DecomposeDivPass,
DecomposeEmbeddingPass,
DecomposeExpm1Pass,
DecomposeFmodPass,
DecomposeGeluPass,
DecomposeGluPass,
DecomposeGroupedConv,
@@ -185,6 +186,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeSinhPass())
self.add_pass(DecomposeSignPass())
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
self.add_pass(DecomposeFmodPass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
@@ -275,6 +277,12 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeCosineSimilarityPass())
self.add_pass(DecomposeGluPass())

if not self.tosa_spec.is_U55_subset:
# Uses where which is not supported on Ethos-U55
self.add_pass(DecomposeMaskedFill())
self.add_pass(DecomposeFmodPass())

self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeLinearVectorNormPass())
@@ -292,8 +300,4 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(ReplaceInfValues())
self.add_pass(DecomposeSumPass())

if not self.tosa_spec.is_U55_subset:
# Uses where which is not supported on Ethos-U55
self.add_pass(DecomposeMaskedFill())

return self._transform(graph_module)
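Note (context, not part of the diff): fmod uses truncated division, i.e. the quotient is rounded toward zero, which differs from floor division whenever x / y is negative. That difference is what the new pass reproduces with floor/ceil/where in the file below. A small illustrative sketch of the distinction, using only stock PyTorch ops:

```python
import torch

x = torch.tensor([5.0, -5.0])
y = torch.tensor([3.0, 3.0])

# trunc rounds toward zero, floor rounds toward -inf; they differ when x / y < 0.
print(torch.div(x, y, rounding_mode="trunc"))   # tensor([ 1., -1.])
print(torch.div(x, y, rounding_mode="floor"))   # tensor([ 1., -2.])

# fmod keeps the sign of the dividend:      x - trunc(x / y) * y
print(torch.fmod(x, y))        # tensor([ 2., -2.])
# remainder keeps the sign of the divisor:  x - floor(x / y) * y
print(torch.remainder(x, y))   # tensor([ 2.,  1.])
```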
79 changes: 79 additions & 0 deletions backends/arm/_passes/decompose_fmod_pass.py
@@ -0,0 +1,79 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch

from executorch.backends.arm._passes import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops

exir_op = (exir_ops.edge.aten.fmod.Tensor,)
aten_op = (torch.ops.aten.fmod.Tensor,)


def _get_decomposition(op) -> tuple:
if op in exir_op:
return (
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.ceil.default,
exir_ops.edge.aten.floor.default,
exir_ops.edge.aten.where.self,
exir_ops.edge.aten.lt.Tensor,
exir_ops.edge.aten.full_like.default,
)
if op in aten_op:
return (
torch.ops.aten.sub.Tensor,
torch.ops.aten.div.Tensor,
torch.ops.aten.mul.Tensor,
torch.ops.aten.ceil.default,
torch.ops.aten.floor.default,
torch.ops.aten.where.self,
torch.ops.aten.lt.Tensor,
torch.ops.aten.full_like.default,
)

raise Exception(f"Unable to get decomposition for {op}")


class DecomposeFmodPass(ArmPass):
"""
Decomposes the fmod operator according to the following formula:
fmod(x, y) = x - x.div(y, rounding_mode='trunc') * y
"""

def call_operator(self, op, args, kwargs, meta, updated=False):
if op not in (exir_op + aten_op):
return super().call_operator(op, args, kwargs, meta, updated)

sub_op, div_op, mul_op, ceil_op, floor_op, where_op, lt_op, full_like_op = (
_get_decomposition(op)
)

x, y = args

div = super().call_operator(div_op, (x, y), {}, meta, True)

floor_round = super().call_operator(floor_op, (div,), {}, meta, True)
ceil_round = super().call_operator(ceil_op, (div,), {}, meta, True)

# Create a mask to determine which values are negative
# and use it to select the appropriate rounding method
# If the value is negative, use ceil, otherwise use floor
zeros = super().call_operator(full_like_op, (div, 0.0), {}, meta, True)
mask = super().call_operator(lt_op, (div, zeros), {}, meta, True)

rounded_values = super().call_operator(
where_op, (mask, ceil_round, floor_round), {}, meta, True
)

mul = super().call_operator(mul_op, (rounded_values, y), {}, meta, True)

out = super().call_operator(sub_op, (x, mul), {}, meta, True)

return out
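As a sanity check, the decomposition above can be mirrored with plain tensor ops and compared against torch.fmod. This is an illustrative sketch only; the pass itself emits the corresponding edge/aten operators through call_operator:

```python
import torch

def fmod_decomposed(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Same structure as the pass: full_like -> lt -> where selects ceil for
    # negative quotients and floor otherwise, i.e. rounding toward zero.
    div = x / y
    zeros = torch.zeros_like(div)
    rounded = torch.where(div < zeros, torch.ceil(div), torch.floor(div))
    return x - rounded * y

x = torch.tensor([5.5, -5.5, 7.0, -7.0])
y = torch.tensor([2.0, 2.0, -3.0, -3.0])
assert torch.allclose(fmod_decomposed(x, y), torch.fmod(x, y))
```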
2 changes: 2 additions & 0 deletions backends/arm/_passes/replace_scalar_with_tensor_pass.py
@@ -37,6 +37,7 @@
exir_ops.edge.aten.bitwise_and.Scalar: exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Scalar: exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Scalar: exir_ops.edge.aten.bitwise_xor.Tensor,
exir_ops.edge.aten.fmod.Scalar: exir_ops.edge.aten.fmod.Tensor,
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
@@ -52,6 +53,7 @@
torch.ops.aten.bitwise_and.Scalar: torch.ops.aten.bitwise_and.Tensor,
torch.ops.aten.bitwise_or.Scalar: torch.ops.aten.bitwise_or.Tensor,
torch.ops.aten.bitwise_xor.Scalar: torch.ops.aten.bitwise_xor.Tensor,
torch.ops.aten.fmod.Scalar: torch.ops.aten.fmod.Tensor,
}


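The two new entries route aten.fmod.Scalar to aten.fmod.Tensor, so DecomposeFmodPass only has to handle the tensor overload. A quick illustrative check that the two overloads agree (assuming torch.fmod dispatches to fmod.Scalar when given a Python number):

```python
import torch

x = torch.tensor([5.5, -5.5, 7.0])

scalar_result = torch.fmod(x, 2.5)                # Python number -> fmod.Scalar
tensor_result = torch.fmod(x, torch.tensor(2.5))  # tensor divisor -> fmod.Tensor
assert torch.equal(scalar_result, tensor_result)
```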
4 changes: 4 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -263,6 +263,8 @@ def is_node_supported(
exir_ops.edge.aten.glu.default,
exir_ops.edge.aten.logit.default,
exir_ops.edge.aten.acos.default,
exir_ops.edge.aten.fmod.Tensor,
exir_ops.edge.aten.fmod.Scalar,
]

return supported
@@ -306,6 +308,8 @@ def is_node_supported(
exir_ops.edge.aten.addmm.default: None,
exir_ops.edge.aten.glu.default: None,
exir_ops.edge.aten.logit.default: None,
exir_ops.edge.aten.fmod.Scalar: None,
exir_ops.edge.aten.fmod.Tensor: None,
}

if node.target in needs_decomp_dict: