Skip to content

Commit 44d7528

Browse files
committed
Arm backend: Add support for div.Tensor_mode op
Signed-off-by: Elena Zhelezina <[email protected]> Change-Id: I6477ffcf8164e40d34c1af5e72229ae5ef45b1ac
1 parent 899d7e5 commit 44d7528

File tree

5 files changed

+240
-0
lines changed

5 files changed

+240
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
3838
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
3939
from .decompose_div_pass import DecomposeDivPass # noqa
40+
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
4041
from .decompose_elu_pass import DecomposeEluPass # noqa
4142
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4243
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
DecomposeCosineSimilarityPass,
4343
DecomposeCumsumPass,
4444
DecomposeDivPass,
45+
DecomposeDivTensorModePass,
4546
DecomposeEluPass,
4647
DecomposeEmbeddingPass,
4748
DecomposeExpm1Pass,
@@ -210,6 +211,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
210211
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
211212
)
212213
self.add_pass(DecomposeNotEqualPass())
214+
self.add_pass(DecomposeDivTensorModePass())
213215
self.add_pass(DecomposeDivPass())
214216
self.add_pass(DecomposeSoftmaxPass())
215217
self.add_pass(DecomposeGeluPass())
@@ -288,6 +290,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
288290
self.add_pass(DecomposeNotEqualPass())
289291
self.add_pass(DecomposeCosineSimilarityPass())
290292
self.add_pass(DecomposeGluPass())
293+
self.add_pass(DecomposeDivTensorModePass())
291294
self.add_pass(DecomposeDivPass())
292295
self.add_pass(DecomposeLeakyReLUPass())
293296
self.add_pass(DecomposeLinearVectorNormPass())
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass
12+
13+
# Overloads of div.Tensor_mode handled by this pass, one tuple per dialect.
edge_div_mode_ops = (exir_ops.edge.aten.div.Tensor_mode,)
aten_div_mode_ops = (torch.ops.aten.div.Tensor_mode,)


def _build_opset(aten_namespace):
    """Return the name -> overload table for the given ``aten`` op namespace.

    The same six overloads are looked up for both the edge dialect and the
    plain ATen dialect, so the table is built from a single definition.
    """
    return {
        "div": aten_namespace.div.Tensor,
        "floor": aten_namespace.floor.default,
        "ceil": aten_namespace.ceil.default,
        "full": aten_namespace.full.default,
        "lt": aten_namespace.lt.Tensor,
        "where": aten_namespace.where.self,
    }


# NOTE: despite the "_unary" suffix, these tables also contain non-unary ops
# (div, lt, where); the names are kept because other code refers to them.
edge_unary = _build_opset(exir_ops.edge.aten)

aten_unary = _build_opset(torch.ops.aten)
33+
34+
35+
def _get_opset(op):
    """Return the op table matching the dialect that ``op`` belongs to.

    Args:
        op: The operator overload being rewritten.

    Returns:
        ``edge_unary`` when ``op`` is an edge-dialect div.Tensor_mode,
        ``aten_unary`` when it is the ATen-dialect one.

    Raises:
        RuntimeError: If ``op`` is in neither tuple of supported overloads.
    """
    # Edge dialect is checked first, then ATen.
    dialect_tables = (
        (edge_div_mode_ops, edge_unary),
        (aten_div_mode_ops, aten_unary),
    )
    for supported_ops, table in dialect_tables:
        if op in supported_ops:
            return table
    raise RuntimeError(f"div.Tensor_mode not supported for op {op}")
41+
42+
43+
class DecomposeDivTensorModePass(ExportPass):
    """
    Rewrites aten.div.Tensor_mode into

    rounding_mode=None    -> div(a, b)
    rounding_mode='floor' -> floor(div(a, b))
    rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b)))
    """

    def call_operator(self, op, args, kwargs, meta):
        """Decompose a div.Tensor_mode call; forward every other op unchanged.

        Args:
            op: Operator overload of the node being visited.
            args: Positional arguments of the node; ``args[0]`` and ``args[1]``
                are the dividend and divisor, ``args[2]`` (if present) is the
                rounding mode.
            kwargs: Keyword arguments of the node; may carry ``rounding_mode``.
            meta: Node metadata, reused for every node emitted here.

        Returns:
            The value produced by the last emitted operator.

        Raises:
            RuntimeError: If the rounding mode is not None, 'floor' or 'trunc'.
        """
        if op not in (edge_div_mode_ops + aten_div_mode_ops):
            return super().call_operator(op, args, kwargs, meta)

        opset = _get_opset(op)

        a, b = args[0], args[1]
        # rounding_mode may arrive as a keyword or as the third positional arg.
        rounding_mode = kwargs.get("rounding_mode", None)
        if rounding_mode is None and len(args) > 2:
            rounding_mode = args[2]

        q = super().call_operator(opset["div"], (a, b), {}, meta)

        if rounding_mode is None:
            return q

        if rounding_mode == "floor":
            return super().call_operator(opset["floor"], (q,), {}, meta)

        if rounding_mode == "trunc":
            out_val = meta["val"]
            # Give the zero constant the output's floating dtype instead of a
            # hard-coded float32, so half-precision quotients are not compared
            # against a float32 tensor. Fall back to float32 otherwise.
            # NOTE(review): assumes meta["val"] describes the div output.
            zero_dtype = (
                out_val.dtype if out_val.dtype.is_floating_point else torch.float32
            )
            # Rank-matched tensor of ones-shape so it broadcasts against q.
            zero = super().call_operator(
                opset["full"],
                args=((1,) * len(out_val.size()), 0.0),
                kwargs={"dtype": zero_dtype},
                meta=meta,
            )
            # Emitted ops are never div.Tensor_mode, so go straight to the
            # base class like the rest of this method does.
            lt0 = super().call_operator(opset["lt"], (q, zero), {}, meta)
            ceilq = super().call_operator(opset["ceil"], (q,), {}, meta)
            floorq = super().call_operator(opset["floor"], (q,), {}, meta)
            # Truncation rounds toward zero: ceil for negatives, floor otherwise.
            return super().call_operator(
                opset["where"], (lt0, ceilq, floorq), {}, meta
            )

        raise RuntimeError(
            f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}"
        )

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def is_node_supported(
176176
exir_ops.edge.aten.hardtanh.default,
177177
exir_ops.edge.aten.hardswish.default,
178178
exir_ops.edge.aten.div.Tensor,
179+
exir_ops.edge.aten.div.Tensor_mode,
179180
exir_ops.edge.aten.eq.Tensor,
180181
exir_ops.edge.aten.eq.Scalar,
181182
exir_ops.edge.aten.erf.default,
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# This source code is licensed under the BSD-style license found in the
3+
# LICENSE file in the root directory of this source tree.
4+
5+
from typing import Tuple
6+
7+
import pytest
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
16+
VgfPipeline,
17+
)
18+
19+
# (dividend, divisor) pair fed to the test pipelines.
input_tt = Tuple[torch.Tensor, torch.Tensor]


def make_float_div_inputs(B: int = 4, T: int = 64) -> input_tt:
    """Return a random ``(x, y)`` pair of shape ``(B, T)`` for division tests.

    ``x`` is drawn from ``torch.randn``; ``y`` is the absolute value of a
    second ``torch.randn`` draw shifted by ``1e-3`` so it is never zero.
    """
    shape = (B, T)
    numerator = torch.randn(shape)
    # Strictly positive denominator: avoids division by zero.
    denominator = torch.abs(torch.randn(shape)) + 1e-3
    return numerator, denominator
27+
28+
29+
class DivTensorModeFloat(torch.nn.Module):
    """
    Module computing torch.div(x, y, rounding_mode=mode), where
    mode is one of None, "floor" or "trunc".
    """

    # Op names passed as the expected ATen ops to the FP pipelines.
    aten_ops = ["aten.div.Tensor_mode"]
    # Op names passed as the expected ATen ops to the INT pipelines.
    aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default"]

    def __init__(self, mode=None):
        """Store the rounding mode; only None, "floor" and "trunc" are accepted."""
        super().__init__()
        assert mode in (None, "floor", "trunc")
        self.mode = mode

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by ``y`` using the configured rounding mode."""
        quotient = torch.div(x, y, rounding_mode=self.mode)
        return quotient
45+
46+
47+
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
48+
def test_div_tensor_mode_tosa_FP(mode):
49+
50+
model = DivTensorModeFloat(mode)
51+
inputs = make_float_div_inputs()
52+
53+
pipeline = TosaPipelineFP[input_tt](
54+
model,
55+
inputs,
56+
aten_op=model.aten_ops,
57+
exir_op=[],
58+
use_to_edge_transform_and_lower=True,
59+
)
60+
pipeline.pop_stage("check_count.exir")
61+
pipeline.run()
62+
63+
64+
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
65+
def test_div_tensor_mode_tosa_INT(mode):
66+
67+
model = DivTensorModeFloat(mode)
68+
inputs = make_float_div_inputs()
69+
70+
pipeline = TosaPipelineINT[input_tt](
71+
model,
72+
inputs,
73+
aten_op=model.aten_ops_int,
74+
exir_op=[],
75+
use_to_edge_transform_and_lower=True,
76+
)
77+
pipeline.pop_stage("check_count.exir")
78+
pipeline.run()
79+
80+
81+
@pytest.mark.parametrize("mode", [None, "floor"])
82+
def test_div_tensor_mode_u55_INT(mode):
83+
84+
model = DivTensorModeFloat(mode)
85+
inputs = make_float_div_inputs()
86+
87+
pipeline = EthosU55PipelineINT[input_tt](
88+
model,
89+
inputs,
90+
aten_ops=model.aten_ops_int,
91+
exir_ops=[],
92+
use_to_edge_transform_and_lower=True,
93+
run_on_fvp=True,
94+
)
95+
pipeline.run()
96+
97+
98+
@common.XfailIfNoCorstone320
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_u85_INT(mode):
    """Run div.Tensor_mode through the Ethos-U85 INT pipeline with run_on_fvp enabled."""
    module = DivTensorModeFloat(mode)

    u85_pipeline = EthosU85PipelineINT[input_tt](
        module,
        make_float_div_inputs(),
        aten_ops=module.aten_ops_int,
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )
    u85_pipeline.run()
114+
115+
116+
@common.SkipIfNoModelConverter
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_vgf_INT(mode):
    """Run div.Tensor_mode through the VGF pipeline targeting TOSA-1.0+INT."""
    module = DivTensorModeFloat(mode)

    vgf_pipeline = VgfPipeline[input_tt](
        module,
        make_float_div_inputs(),
        aten_op=module.aten_ops_int,
        exir_op=[],
        tosa_version="TOSA-1.0+INT",
        use_to_edge_transform_and_lower=True,
    )
    # The exir-level op count stage is removed before running.
    vgf_pipeline.pop_stage("check_count.exir")
    vgf_pipeline.run()
133+
134+
135+
@common.SkipIfNoModelConverter
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_vgf_FP(mode):
    """Run div.Tensor_mode through the VGF pipeline targeting TOSA-1.0+FP."""
    module = DivTensorModeFloat(mode)

    vgf_pipeline = VgfPipeline[input_tt](
        module,
        make_float_div_inputs(),
        aten_op=module.aten_ops,
        exir_op=[],
        tosa_version="TOSA-1.0+FP",
        use_to_edge_transform_and_lower=True,
    )
    vgf_pipeline.run()

0 commit comments

Comments
 (0)