Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
from .decompose_elu_pass import DecomposeEluPass # noqa
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
DecomposeDivPass,
DecomposeDivTensorModePass,
DecomposeEluPass,
DecomposeEmbeddingPass,
DecomposeExpm1Pass,
Expand Down Expand Up @@ -211,6 +212,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxPass())
self.add_pass(DecomposeGeluPass())
Expand Down Expand Up @@ -289,6 +291,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeCosineSimilarityPass())
self.add_pass(DecomposeGluPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeLinearVectorNormPass())
Expand Down
84 changes: 84 additions & 0 deletions backends/arm/_passes/decompose_div_tensor_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

edge_div_mode_ops = (exir_ops.edge.aten.div.Tensor_mode,)
aten_div_mode_ops = (torch.ops.aten.div.Tensor_mode,)

edge_unary = {
"div": exir_ops.edge.aten.div.Tensor,
"floor": exir_ops.edge.aten.floor.default,
"ceil": exir_ops.edge.aten.ceil.default,
"full": exir_ops.edge.aten.full.default,
"lt": exir_ops.edge.aten.lt.Tensor,
"where": exir_ops.edge.aten.where.self,
}

aten_unary = {
"div": torch.ops.aten.div.Tensor,
"floor": torch.ops.aten.floor.default,
"ceil": torch.ops.aten.ceil.default,
"full": torch.ops.aten.full.default,
"lt": torch.ops.aten.lt.Tensor,
"where": torch.ops.aten.where.self,
}


def _get_opset(op):
if op in edge_div_mode_ops:
return edge_unary
if op in aten_div_mode_ops:
return aten_unary
raise RuntimeError(f"div.Tensor_mode not supported for op {op}")


class DecomposeDivTensorModePass(ExportPass):
"""
Rewrites aten.div.Tensor_mode into

rounding_mode=None -> div(a, b)
rounding_mode='floor' -> floor(div(a, b))
rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b)))
"""

def call_operator(self, op, args, kwargs, meta):
if op not in (edge_div_mode_ops + aten_div_mode_ops):
return super().call_operator(op, args, kwargs, meta)

opset = _get_opset(op)

a, b = args[0], args[1]
rounding_mode = kwargs.get("rounding_mode", None)
if rounding_mode is None and len(args) > 2:
rounding_mode = args[2]

q = super().call_operator(opset["div"], (a, b), {}, meta)

if rounding_mode is None:
return q

if rounding_mode == "floor":
return super().call_operator(opset["floor"], (q,), {}, meta)

if rounding_mode == "trunc":
zero = super().call_operator(
opset["full"],
args=((1,) * len(meta["val"].size()), 0.0),
kwargs={"dtype": torch.float32},
meta=meta,
)
lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta)
ceilq = self.call_operator(opset["ceil"], (q,), {}, meta)
floorq = self.call_operator(opset["floor"], (q,), {}, meta)
return self.call_operator(opset["where"], (lt0, ceilq, floorq), {}, meta)

raise RuntimeError(
f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def is_node_supported(
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.hardswish.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.div.Tensor_mode,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.erf.default,
Expand Down
150 changes: 150 additions & 0 deletions backends/arm/test/ops/test_div_tensor_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest
import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

input_tt = Tuple[torch.Tensor, torch.Tensor]


def make_float_div_inputs(B: int = 4, T: int = 64) -> input_tt:
x = torch.randn(B, T)
# guard against zero in denominator
y = torch.randn(B, T).abs() + 1e-3
return x, y


class DivTensorModeFloat(torch.nn.Module):
"""
torch.div(x, y, rounding_mode=mode) with
mode from {None, "floor", "trunc"}.
"""

aten_ops = ["aten.div.Tensor_mode"]
aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default"]

def __init__(self, mode=None):
super().__init__()
assert mode in (None, "floor", "trunc")
self.mode = mode

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.div(x, y, rounding_mode=self.mode)


@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_tosa_FP(mode):

model = DivTensorModeFloat(mode)
inputs = make_float_div_inputs()

pipeline = TosaPipelineFP[input_tt](
model,
inputs,
aten_op=model.aten_ops,
exir_op=[],
use_to_edge_transform_and_lower=True,
)
pipeline.pop_stage("check_count.exir")
pipeline.run()


@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_tosa_INT(mode):

model = DivTensorModeFloat(mode)
inputs = make_float_div_inputs()

pipeline = TosaPipelineINT[input_tt](
model,
inputs,
aten_op=model.aten_ops_int,
exir_op=[],
use_to_edge_transform_and_lower=True,
)
pipeline.pop_stage("check_count.exir")
pipeline.run()


@pytest.mark.parametrize("mode", [None, "floor"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add @common.XfailIfNoCorstone300

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, yes, missed yhis

def test_div_tensor_mode_u55_INT(mode):

model = DivTensorModeFloat(mode)
inputs = make_float_div_inputs()

pipeline = EthosU55PipelineINT[input_tt](
model,
inputs,
aten_ops=model.aten_ops_int,
exir_ops=[],
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)
pipeline.run()


@common.XfailIfNoCorstone320
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_u85_INT(mode):

model = DivTensorModeFloat(mode)
inputs = make_float_div_inputs()

pipeline = EthosU85PipelineINT[input_tt](
model,
inputs,
aten_ops=model.aten_ops_int,
exir_ops=[],
use_to_edge_transform_and_lower=True,
run_on_fvp=True,
)
pipeline.run()


@common.SkipIfNoModelConverter
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_vgf_INT(mode):

model = DivTensorModeFloat(mode)
inputs = make_float_div_inputs()

pipeline = VgfPipeline[input_tt](
model,
inputs,
aten_op=model.aten_ops_int,
exir_op=[],
tosa_version="TOSA-1.0+INT",
use_to_edge_transform_and_lower=True,
)
pipeline.pop_stage("check_count.exir")
pipeline.run()


@common.SkipIfNoModelConverter
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_vgf_FP(mode):

model = DivTensorModeFloat(mode)
inputs = make_float_div_inputs()

pipeline = VgfPipeline[input_tt](
model,
inputs,
aten_op=model.aten_ops,
exir_op=[],
tosa_version="TOSA-1.0+FP",
use_to_edge_transform_and_lower=True,
)
pipeline.run()
Loading