Skip to content

Commit c6e52f3

Browse files
Arm backend: Fix tensor_div_mode decomp for scalar input (#14114)
Signed-off-by: Adrian Lundell <[email protected]>
1 parent 397b8df commit c6e52f3

File tree

2 files changed

+35
-33
lines changed

2 files changed

+35
-33
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
197197
self.add_pass(CastBoolToInt8Pass())
198198
self.add_pass(DecomposeSinhPass())
199199
self.add_pass(DecomposeSignPass())
200+
self.add_pass(DecomposeDivTensorModePass())
200201
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
201202
self.add_pass(DecomposeEmbeddingPass())
202203
self.add_pass(FuseQuantizedActivationPass())
@@ -215,7 +216,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
215216
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
216217
)
217218
self.add_pass(DecomposeNotEqualPass())
218-
self.add_pass(DecomposeDivTensorModePass())
219219
self.add_pass(DecomposeDivPass())
220220
self.add_pass(DecomposeSoftmaxPass())
221221
self.add_pass(DecomposeGeluPass())
@@ -285,6 +285,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
285285
self.add_pass(CastBoolToInt8Pass())
286286
self.add_pass(DecomposeSignPass())
287287
self.add_pass(DecomposeAddmmPass())
288+
self.add_pass(DecomposeDivTensorModePass())
288289
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
289290
self.add_pass(ScalarsToAttributePass())
290291
self.add_pass(DecomposeGroupNormPass())
@@ -294,7 +295,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
294295
self.add_pass(DecomposeNotEqualPass())
295296
self.add_pass(DecomposeCosineSimilarityPass())
296297
self.add_pass(DecomposeGluPass())
297-
self.add_pass(DecomposeDivTensorModePass())
298298
self.add_pass(DecomposeDivPass())
299299
self.add_pass(DecomposeLeakyReLUPass())
300300
self.add_pass(DecomposeLinearVectorNormPass())

backends/arm/test/ops/test_div_tensor_mode.py

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from typing import Tuple
66

7-
import pytest
87
import torch
98

109
from executorch.backends.arm.test import common
@@ -19,13 +18,6 @@
1918
input_tt = Tuple[torch.Tensor, torch.Tensor]
2019

2120

22-
def make_float_div_inputs(B: int = 4, T: int = 64) -> input_tt:
23-
x = torch.randn(B, T)
24-
# guard against zero in denominator
25-
y = torch.randn(B, T).abs() + 1e-3
26-
return x, y
27-
28-
2921
class DivTensorModeFloat(torch.nn.Module):
3022
"""
3123
torch.div(x, y, rounding_mode=mode) with
@@ -44,11 +36,24 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
4436
return torch.div(x, y, rounding_mode=self.mode)
4537

4638

47-
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
48-
def test_div_tensor_mode_tosa_FP(mode):
39+
test_data = {
40+
"mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)),
41+
"mode_floor": lambda: (
42+
"floor",
43+
(torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3),
44+
),
45+
"mode_trunc": lambda: (
46+
"trunc",
47+
(torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3),
48+
),
49+
"int_denominator": lambda: (None, (torch.randn(4, 8), 2)),
50+
}
51+
4952

53+
@common.parametrize("data", test_data)
54+
def test_div_tensor_mode_tosa_FP(data):
55+
mode, inputs = data()
5056
model = DivTensorModeFloat(mode)
51-
inputs = make_float_div_inputs()
5257

5358
pipeline = TosaPipelineFP[input_tt](
5459
model,
@@ -61,11 +66,10 @@ def test_div_tensor_mode_tosa_FP(mode):
6166
pipeline.run()
6267

6368

64-
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
65-
def test_div_tensor_mode_tosa_INT(mode):
66-
69+
@common.parametrize("data", test_data)
70+
def test_div_tensor_mode_tosa_INT(data):
71+
mode, inputs = data()
6772
model = DivTensorModeFloat(mode)
68-
inputs = make_float_div_inputs()
6973

7074
pipeline = TosaPipelineINT[input_tt](
7175
model,
@@ -79,11 +83,12 @@ def test_div_tensor_mode_tosa_INT(mode):
7983

8084

8185
@common.XfailIfNoCorstone300
82-
@pytest.mark.parametrize("mode", [None, "floor"])
83-
def test_div_tensor_mode_u55_INT(mode):
84-
86+
@common.parametrize(
87+
"data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"}
88+
)
89+
def test_div_tensor_mode_u55_INT(data):
90+
mode, inputs = data()
8591
model = DivTensorModeFloat(mode)
86-
inputs = make_float_div_inputs()
8792

8893
pipeline = EthosU55PipelineINT[input_tt](
8994
model,
@@ -97,11 +102,10 @@ def test_div_tensor_mode_u55_INT(mode):
97102

98103

99104
@common.XfailIfNoCorstone320
100-
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
101-
def test_div_tensor_mode_u85_INT(mode):
102-
105+
@common.parametrize("data", test_data)
106+
def test_div_tensor_mode_u85_INT(data):
107+
mode, inputs = data()
103108
model = DivTensorModeFloat(mode)
104-
inputs = make_float_div_inputs()
105109

106110
pipeline = EthosU85PipelineINT[input_tt](
107111
model,
@@ -115,11 +119,10 @@ def test_div_tensor_mode_u85_INT(mode):
115119

116120

117121
@common.SkipIfNoModelConverter
118-
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
119-
def test_div_tensor_mode_vgf_INT(mode):
120-
122+
@common.parametrize("data", test_data)
123+
def test_div_tensor_mode_vgf_INT(data):
124+
mode, inputs = data()
121125
model = DivTensorModeFloat(mode)
122-
inputs = make_float_div_inputs()
123126

124127
pipeline = VgfPipeline[input_tt](
125128
model,
@@ -134,11 +137,10 @@ def test_div_tensor_mode_vgf_INT(mode):
134137

135138

136139
@common.SkipIfNoModelConverter
137-
@pytest.mark.parametrize("mode", [None, "floor", "trunc"])
138-
def test_div_tensor_mode_vgf_FP(mode):
139-
140+
@common.parametrize("data", test_data)
141+
def test_div_tensor_mode_vgf_FP(data):
142+
mode, inputs = data()
140143
model = DivTensorModeFloat(mode)
141-
inputs = make_float_div_inputs()
142144

143145
pipeline = VgfPipeline[input_tt](
144146
model,

0 commit comments

Comments
 (0)