diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index a9abb4ab183..cb90eef01d1 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -37,6 +37,7 @@ from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa from .decompose_cumsum_pass import DecomposeCumsumPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa +from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa from .decompose_elu_pass import DecomposeEluPass # noqa from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa from .decompose_expm1_pass import DecomposeExpm1Pass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 8727e339b53..6aae943881d 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -42,6 +42,7 @@ DecomposeCosineSimilarityPass, DecomposeCumsumPass, DecomposeDivPass, + DecomposeDivTensorModePass, DecomposeEluPass, DecomposeEmbeddingPass, DecomposeExpm1Pass, @@ -211,6 +212,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) ) self.add_pass(DecomposeNotEqualPass()) + self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeSoftmaxPass()) self.add_pass(DecomposeGeluPass()) @@ -289,6 +291,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeNotEqualPass()) self.add_pass(DecomposeCosineSimilarityPass()) self.add_pass(DecomposeGluPass()) + self.add_pass(DecomposeDivTensorModePass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeLinearVectorNormPass()) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py new file mode 100644 index 00000000000..0e6b40afbb2 --- /dev/null +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -0,0 +1,84 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +edge_div_mode_ops = (exir_ops.edge.aten.div.Tensor_mode,) +aten_div_mode_ops = (torch.ops.aten.div.Tensor_mode,) + +edge_unary = { + "div": exir_ops.edge.aten.div.Tensor, + "floor": exir_ops.edge.aten.floor.default, + "ceil": exir_ops.edge.aten.ceil.default, + "full": exir_ops.edge.aten.full.default, + "lt": exir_ops.edge.aten.lt.Tensor, + "where": exir_ops.edge.aten.where.self, +} + +aten_unary = { + "div": torch.ops.aten.div.Tensor, + "floor": torch.ops.aten.floor.default, + "ceil": torch.ops.aten.ceil.default, + "full": torch.ops.aten.full.default, + "lt": torch.ops.aten.lt.Tensor, + "where": torch.ops.aten.where.self, +} + + +def _get_opset(op): + if op in edge_div_mode_ops: + return edge_unary + if op in aten_div_mode_ops: + return aten_unary + raise RuntimeError(f"div.Tensor_mode not supported for op {op}") + + +class DecomposeDivTensorModePass(ExportPass): + """ + Rewrites aten.div.Tensor_mode into + + rounding_mode=None -> div(a, b) + rounding_mode='floor' -> floor(div(a, b)) + rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b))) + """ + + def call_operator(self, op, args, kwargs, meta): + if op not in (edge_div_mode_ops + aten_div_mode_ops): + return super().call_operator(op, args, kwargs, meta) + + opset = _get_opset(op) + + a, b = args[0], args[1] + rounding_mode = kwargs.get("rounding_mode", None) + if rounding_mode is None and len(args) > 2: + rounding_mode = args[2] + + q = super().call_operator(opset["div"], (a, b), {}, meta) + + if rounding_mode is None: + return q + + if rounding_mode == "floor": + return super().call_operator(opset["floor"], (q,), {}, meta) + + if rounding_mode == "trunc": + zero = super().call_operator( + opset["full"], + args=((1,) * len(meta["val"].size()), 0.0), + kwargs={"dtype": torch.float32}, + meta=meta, + ) + lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta) + ceilq = self.call_operator(opset["ceil"], (q,), {}, meta) + floorq = self.call_operator(opset["floor"], (q,), {}, meta) + return self.call_operator(opset["where"], (lt0, ceilq, floorq), {}, meta) + + raise RuntimeError( + f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}" + ) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 0645fde725e..9fa84d051d5 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -176,6 +176,7 @@ def is_node_supported( exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardswish.default, exir_ops.edge.aten.div.Tensor, + exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.erf.default, diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py new file mode 100644 index 00000000000..f78aca85bcd --- /dev/null +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -0,0 +1,151 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import pytest +import torch + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + EthosU85PipelineINT, + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +input_tt = Tuple[torch.Tensor, torch.Tensor] + + +def make_float_div_inputs(B: int = 4, T: int = 64) -> input_tt: + x = torch.randn(B, T) + # guard against zero in denominator + y = torch.randn(B, T).abs() + 1e-3 + return x, y + + +class DivTensorModeFloat(torch.nn.Module): + """ + torch.div(x, y, rounding_mode=mode) with + mode from {None, "floor", "trunc"}. + """ + + aten_ops = ["aten.div.Tensor_mode"] + aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default"] + + def __init__(self, mode=None): + super().__init__() + assert mode in (None, "floor", "trunc") + self.mode = mode + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.div(x, y, rounding_mode=self.mode) + + +@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) +def test_div_tensor_mode_tosa_FP(mode): + + model = DivTensorModeFloat(mode) + inputs = make_float_div_inputs() + + pipeline = TosaPipelineFP[input_tt]( + model, + inputs, + aten_op=model.aten_ops, + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.pop_stage("check_count.exir") + pipeline.run() + + +@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) +def test_div_tensor_mode_tosa_INT(mode): + + model = DivTensorModeFloat(mode) + inputs = make_float_div_inputs() + + pipeline = TosaPipelineINT[input_tt]( + model, + inputs, + aten_op=model.aten_ops_int, + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.pop_stage("check_count.exir") + pipeline.run() + + +@common.XfailIfNoCorstone300 +@pytest.mark.parametrize("mode", [None, "floor"]) +def test_div_tensor_mode_u55_INT(mode): + + model = DivTensorModeFloat(mode) + inputs = make_float_div_inputs() + + pipeline = EthosU55PipelineINT[input_tt]( + model, + inputs, + aten_ops=model.aten_ops_int, + exir_ops=[], + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + pipeline.run() + + +@common.XfailIfNoCorstone320 +@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) +def test_div_tensor_mode_u85_INT(mode): + + model = DivTensorModeFloat(mode) + inputs = make_float_div_inputs() + + pipeline = EthosU85PipelineINT[input_tt]( + model, + inputs, + aten_ops=model.aten_ops_int, + exir_ops=[], + use_to_edge_transform_and_lower=True, + run_on_fvp=True, + ) + pipeline.run() + + +@common.SkipIfNoModelConverter +@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) +def test_div_tensor_mode_vgf_INT(mode): + + model = DivTensorModeFloat(mode) + inputs = make_float_div_inputs() + + pipeline = VgfPipeline[input_tt]( + model, + inputs, + aten_op=model.aten_ops_int, + exir_op=[], + tosa_version="TOSA-1.0+INT", + use_to_edge_transform_and_lower=True, + ) + pipeline.pop_stage("check_count.exir") + pipeline.run() + + +@common.SkipIfNoModelConverter +@pytest.mark.parametrize("mode", [None, "floor", "trunc"]) +def test_div_tensor_mode_vgf_FP(mode): + + model = DivTensorModeFloat(mode) + inputs = make_float_div_inputs() + + pipeline = VgfPipeline[input_tt]( + model, + inputs, + aten_op=model.aten_ops, + exir_op=[], + tosa_version="TOSA-1.0+FP", + use_to_edge_transform_and_lower=True, + ) + pipeline.run()