Skip to content

Commit 7a7e939

Browse files
authored
Arm backend: Add decomposition of div.Tensor_mode (pytorch#13940)
Add decomposition of div.Tensor_mode Signed-off-by: Elena Zhelezina <[email protected]>
1 parent 2261848 commit 7a7e939

File tree

5 files changed

+240
-0
lines changed

5 files changed

+240
-0
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
3838
from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
3939
from .decompose_div_pass import DecomposeDivPass # noqa
40+
from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa
4041
from .decompose_elu_pass import DecomposeEluPass # noqa
4142
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
4243
from .decompose_expm1_pass import DecomposeExpm1Pass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
DecomposeCosineSimilarityPass,
4343
DecomposeCumsumPass,
4444
DecomposeDivPass,
45+
DecomposeDivTensorModePass,
4546
DecomposeEluPass,
4647
DecomposeEmbeddingPass,
4748
DecomposeExpm1Pass,
@@ -211,6 +212,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
211212
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
212213
)
213214
self.add_pass(DecomposeNotEqualPass())
215+
self.add_pass(DecomposeDivTensorModePass())
214216
self.add_pass(DecomposeDivPass())
215217
self.add_pass(DecomposeSoftmaxPass())
216218
self.add_pass(DecomposeGeluPass())
@@ -289,6 +291,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
289291
self.add_pass(DecomposeNotEqualPass())
290292
self.add_pass(DecomposeCosineSimilarityPass())
291293
self.add_pass(DecomposeGluPass())
294+
self.add_pass(DecomposeDivTensorModePass())
292295
self.add_pass(DecomposeDivPass())
293296
self.add_pass(DecomposeLeakyReLUPass())
294297
self.add_pass(DecomposeLinearVectorNormPass())
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import torch
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass
11+
12+
# div overloads that carry a rounding_mode argument; these are the ops the
# decomposition pass rewrites, one tuple per dialect.
edge_div_mode_ops = (exir_ops.edge.aten.div.Tensor_mode,)
aten_div_mode_ops = (torch.ops.aten.div.Tensor_mode,)

# Replacement operators keyed by short name (edge dialect).
# NOTE: despite the "_unary" suffix the table also holds non-unary ops
# (div, lt, where); the name is kept because other code refers to it.
edge_unary = dict(
    div=exir_ops.edge.aten.div.Tensor,
    floor=exir_ops.edge.aten.floor.default,
    ceil=exir_ops.edge.aten.ceil.default,
    full=exir_ops.edge.aten.full.default,
    lt=exir_ops.edge.aten.lt.Tensor,
    where=exir_ops.edge.aten.where.self,
)

# Same table for the aten dialect; keys mirror edge_unary.
aten_unary = dict(
    div=torch.ops.aten.div.Tensor,
    floor=torch.ops.aten.floor.default,
    ceil=torch.ops.aten.ceil.default,
    full=torch.ops.aten.full.default,
    lt=torch.ops.aten.lt.Tensor,
    where=torch.ops.aten.where.self,
)
32+
33+
34+
def _get_opset(op):
    """Return the operator table matching the dialect of ``op``.

    ``op`` is looked up in the edge and aten div.Tensor_mode tuples (in that
    order) and the corresponding name -> operator dict is returned.

    Raises:
        RuntimeError: if ``op`` is in neither tuple.
    """
    dialect_tables = (
        (edge_div_mode_ops, edge_unary),
        (aten_div_mode_ops, aten_unary),
    )
    for div_mode_ops, table in dialect_tables:
        if op in div_mode_ops:
            return table
    raise RuntimeError(f"div.Tensor_mode not supported for op {op}")
40+
41+
42+
class DecomposeDivTensorModePass(ExportPass):
    """
    Rewrites aten.div.Tensor_mode into

    rounding_mode=None    -> div(a, b)
    rounding_mode='floor' -> floor(div(a, b))
    rounding_mode='trunc' -> where(div(a,b) < 0, ceil(div(a,b)), floor(div(a,b)))

    Any other rounding_mode raises RuntimeError. Operators other than
    div.Tensor_mode are forwarded to the base class untouched.
    """

    def call_operator(self, op, args, kwargs, meta):
        """Decompose a div.Tensor_mode call; pass every other op through.

        Args:
            op: operator being visited.
            args: positional arguments; args[0] and args[1] are the dividend
                and divisor, args[2] (if present) is the rounding mode.
            kwargs: keyword arguments; ``rounding_mode`` is read from here
                first.
            meta: node metadata; ``meta["val"]`` is used for the rank (and
                dtype) of the zero constant in the 'trunc' branch.

        Returns:
            The value produced by the replacement operator(s).

        Raises:
            RuntimeError: if rounding_mode is not None, 'floor' or 'trunc'.
        """
        if op not in (edge_div_mode_ops + aten_div_mode_ops):
            return super().call_operator(op, args, kwargs, meta)

        opset = _get_opset(op)

        a, b = args[0], args[1]
        # rounding_mode may arrive as a keyword or as the third positional arg.
        rounding_mode = kwargs.get("rounding_mode", None)
        if rounding_mode is None and len(args) > 2:
            rounding_mode = args[2]

        # Validate up front so an unsupported mode fails before any
        # replacement node has been emitted.
        if rounding_mode not in (None, "floor", "trunc"):
            raise RuntimeError(
                f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}"
            )

        q = super().call_operator(opset["div"], (a, b), {}, meta)

        if rounding_mode is None:
            return q

        if rounding_mode == "floor":
            return super().call_operator(opset["floor"], (q,), {}, meta)

        # rounding_mode == "trunc": round towards zero, i.e. ceil for negative
        # quotients and floor otherwise.
        val = meta["val"]
        # Keep the zero constant in the quotient's floating-point dtype so the
        # comparison does not mix dtypes; fall back to float32 (the previous
        # fixed choice) when the recorded dtype is not floating point.
        zero_dtype = val.dtype if val.dtype.is_floating_point else torch.float32
        zero = super().call_operator(
            opset["full"],
            args=((1,) * len(val.size()), 0.0),  # rank-matched, broadcastable
            kwargs={"dtype": zero_dtype},
            meta=meta,
        )
        # None of the ops below is a div.Tensor_mode, so going through super()
        # is equivalent to re-entering this override and keeps the calls
        # uniform.
        lt0 = super().call_operator(opset["lt"], (q, zero), {}, meta)
        ceilq = super().call_operator(opset["ceil"], (q,), {}, meta)
        floorq = super().call_operator(opset["floor"], (q,), {}, meta)
        return super().call_operator(opset["where"], (lt0, ceilq, floorq), {}, meta)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def is_node_supported(
176176
exir_ops.edge.aten.hardtanh.default,
177177
exir_ops.edge.aten.hardswish.default,
178178
exir_ops.edge.aten.div.Tensor,
179+
exir_ops.edge.aten.div.Tensor_mode,
179180
exir_ops.edge.aten.eq.Tensor,
180181
exir_ops.edge.aten.eq.Scalar,
181182
exir_ops.edge.aten.erf.default,
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
# This source code is licensed under the BSD-style license found in the
3+
# LICENSE file in the root directory of this source tree.
4+
5+
from typing import Tuple
6+
7+
import pytest
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineINT,
13+
EthosU85PipelineINT,
14+
TosaPipelineFP,
15+
TosaPipelineINT,
16+
VgfPipeline,
17+
)
18+
19+
input_tt = Tuple[torch.Tensor, torch.Tensor]


def make_float_div_inputs(B: int = 4, T: int = 64) -> input_tt:
    """Return a random (dividend, divisor) pair, each of shape (B, T).

    The divisor is |randn| + 1e-3, so it is strictly positive and the
    division never hits zero.
    """
    shape = (B, T)
    dividend = torch.randn(*shape)
    # guard against zero in denominator
    divisor = torch.abs(torch.randn(*shape)) + 1e-3
    return dividend, divisor
27+
28+
29+
class DivTensorModeFloat(torch.nn.Module):
    """
    torch.div(x, y, rounding_mode=mode) with
    mode from {None, "floor", "trunc"}.
    """

    # Op-name lists the tests in this file hand to the pipelines:
    # aten_ops to the FP variants, aten_ops_int to the INT variants.
    aten_ops = ["aten.div.Tensor_mode"]
    aten_ops_int = ["aten.mul.Tensor", "aten.reciprocal.default"]

    def __init__(self, mode=None):
        """Store the rounding mode.

        Raises:
            ValueError: if ``mode`` is not None, "floor" or "trunc".
        """
        super().__init__()
        # Explicit raise instead of assert: asserts are stripped under -O.
        if mode not in (None, "floor", "trunc"):
            raise ValueError(
                f"mode must be None, 'floor' or 'trunc', got {mode!r}"
            )
        self.mode = mode

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return x / y rounded according to the stored mode."""
        return torch.div(x, y, rounding_mode=self.mode)
45+
46+
47+
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_tosa_FP(mode):
    """Run torch.div with each rounding mode through TosaPipelineFP."""
    module = DivTensorModeFloat(mode)
    tosa_fp = TosaPipelineFP[input_tt](
        module,
        make_float_div_inputs(),
        aten_op=module.aten_ops,
        exir_op=[],
        use_to_edge_transform_and_lower=True,
    )
    # The "check_count.exir" stage is removed before the pipeline runs.
    tosa_fp.pop_stage("check_count.exir")
    tosa_fp.run()
62+
63+
64+
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_tosa_INT(mode):
    """Run torch.div with each rounding mode through TosaPipelineINT."""
    module = DivTensorModeFloat(mode)
    tosa_int = TosaPipelineINT[input_tt](
        module,
        make_float_div_inputs(),
        aten_op=module.aten_ops_int,
        exir_op=[],
        use_to_edge_transform_and_lower=True,
    )
    # The "check_count.exir" stage is removed before the pipeline runs.
    tosa_int.pop_stage("check_count.exir")
    tosa_int.run()
79+
80+
81+
@common.XfailIfNoCorstone300
@pytest.mark.parametrize("mode", [None, "floor"])
def test_div_tensor_mode_u55_INT(mode):
    """Run torch.div (modes None and "floor" only) through EthosU55PipelineINT on the FVP."""
    module = DivTensorModeFloat(mode)
    u55 = EthosU55PipelineINT[input_tt](
        module,
        make_float_div_inputs(),
        aten_ops=module.aten_ops_int,
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )
    u55.run()
97+
98+
99+
@common.XfailIfNoCorstone320
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_u85_INT(mode):
    """Run torch.div with each rounding mode through EthosU85PipelineINT on the FVP."""
    module = DivTensorModeFloat(mode)
    u85 = EthosU85PipelineINT[input_tt](
        module,
        make_float_div_inputs(),
        aten_ops=module.aten_ops_int,
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )
    u85.run()
115+
116+
117+
@common.SkipIfNoModelConverter
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_vgf_INT(mode):
    """Run torch.div with each rounding mode through VgfPipeline at TOSA-1.0+INT."""
    module = DivTensorModeFloat(mode)
    vgf_int = VgfPipeline[input_tt](
        module,
        make_float_div_inputs(),
        aten_op=module.aten_ops_int,
        exir_op=[],
        tosa_version="TOSA-1.0+INT",
        use_to_edge_transform_and_lower=True,
    )
    # The "check_count.exir" stage is removed before the pipeline runs.
    vgf_int.pop_stage("check_count.exir")
    vgf_int.run()
134+
135+
136+
@common.SkipIfNoModelConverter
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
def test_div_tensor_mode_vgf_FP(mode):
    """Run torch.div with each rounding mode through VgfPipeline at TOSA-1.0+FP."""
    module = DivTensorModeFloat(mode)
    vgf_fp = VgfPipeline[input_tt](
        module,
        make_float_div_inputs(),
        aten_op=module.aten_ops,
        exir_op=[],
        tosa_version="TOSA-1.0+FP",
        use_to_edge_transform_and_lower=True,
    )
    vgf_fp.run()

0 commit comments

Comments
 (0)