Skip to content

Commit ec93b9a

Browse files
authored
Revert "Arm backend: Add correction for floor mode (#14776)"
This reverts commit 7d2b8c6.
1 parent 3591604 commit ec93b9a

File tree

2 files changed

+4
-74
lines changed

2 files changed

+4
-74
lines changed

backends/arm/_passes/decompose_div_tensor_mode.py

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
"full": exir_ops.edge.aten.full.default,
2323
"lt": exir_ops.edge.aten.lt.Tensor,
2424
"where": exir_ops.edge.aten.where.self,
25-
"mul": exir_ops.edge.aten.mul.Tensor,
26-
"sub": exir_ops.edge.aten.sub.Tensor,
2725
}
2826

2927
aten_unary = {
@@ -33,8 +31,6 @@
3331
"full": torch.ops.aten.full.default,
3432
"lt": torch.ops.aten.lt.Tensor,
3533
"where": torch.ops.aten.where.self,
36-
"mul": torch.ops.aten.mul.Tensor,
37-
"sub": torch.ops.aten.sub.Tensor,
3834
}
3935

4036

@@ -74,57 +70,13 @@ def call_operator(self, op, args, kwargs, meta):
7470
return q
7571

7672
if rounding_mode == "floor":
77-
q_raw = q
78-
79-
# trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw))
80-
q_floor = super().call_operator(opset["floor"], (q_raw,), {}, meta)
81-
q_ceil = super().call_operator(opset["ceil"], (q_raw,), {}, meta)
82-
83-
# a zero tensor with the right shape
84-
out_shape = (1,) * len(meta["val"].size())
85-
zero = super().call_operator(
86-
opset["full"],
87-
args=(out_shape, 0.0),
88-
kwargs={},
89-
meta=meta,
90-
)
91-
92-
is_neg = super().call_operator(opset["lt"], (q_raw, zero), {}, meta)
93-
q_trunc = super().call_operator(
94-
opset["where"], (is_neg, q_ceil, q_floor), {}, meta
95-
)
96-
97-
# r = a - q_trunc * b (true remainder under truncation)
98-
q_times_b = super().call_operator(opset["mul"], (q_trunc, b), {}, meta)
99-
r = super().call_operator(opset["sub"], (a, q_times_b), {}, meta)
100-
101-
# Decide if we need to subtract 1:
102-
# for b > 0, adjust if r < 0; for b < 0, adjust if r > 0.
103-
b_pos = super().call_operator(opset["lt"], (zero, b), {}, meta) # b > 0
104-
r_lt0 = super().call_operator(opset["lt"], (r, zero), {}, meta) # r < 0
105-
r_gt0 = super().call_operator(opset["lt"], (zero, r), {}, meta) # r > 0
106-
107-
adjust_if = super().call_operator(
108-
opset["where"], (b_pos, r_lt0, r_gt0), {}, meta
109-
)
110-
111-
one = super().call_operator(
112-
opset["full"],
113-
args=(out_shape, 1.0),
114-
kwargs={},
115-
meta=meta,
116-
)
117-
q_minus_1 = super().call_operator(opset["sub"], (q_trunc, one), {}, meta)
118-
119-
return super().call_operator(
120-
opset["where"], (adjust_if, q_minus_1, q_trunc), {}, meta
121-
)
73+
return super().call_operator(opset["floor"], (q,), {}, meta)
12274

12375
if rounding_mode == "trunc":
12476
zero = super().call_operator(
12577
opset["full"],
12678
args=((1,) * len(meta["val"].size()), 0.0),
127-
kwargs={},
79+
kwargs={"dtype": torch.float32},
12880
meta=meta,
12981
)
13082
lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta)

backends/arm/test/ops/test_div_tensor_mode.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
3636
return torch.div(x, y, rounding_mode=self.mode)
3737

3838

39-
def _rank4_large_randn_case():
40-
torch.manual_seed(0)
41-
x = 200 * torch.randn(5, 10, 25, 20) + 1
42-
torch.manual_seed(1)
43-
y = torch.rand(5, 10, 25, 20) + 1
44-
return x, y
45-
46-
4739
test_data = {
4840
"mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)),
4941
"mode_floor": lambda: (
@@ -55,13 +47,6 @@ def _rank4_large_randn_case():
5547
(torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3),
5648
),
5749
"int_denominator": lambda: (None, (torch.randn(4, 8), 2)),
58-
"op_floor_div_rank4_large_randn": lambda: (
59-
"floor",
60-
(
61-
200 * torch.randn(5, 10, 25, 20) + 1,
62-
torch.rand(5, 10, 25, 20) + 1,
63-
),
64-
),
6550
}
6651

6752

@@ -99,13 +84,7 @@ def test_div_tensor_mode_tosa_INT(data):
9984

10085
@common.XfailIfNoCorstone300
10186
@common.parametrize(
102-
"data",
103-
test_data,
104-
xfails={
105-
"mode_trunc": "CPU op missing in unittests",
106-
"mode_floor": "Not supported",
107-
"op_floor_div_rank4_large_randn": "Not supported",
108-
},
87+
"data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"}
10988
)
11089
def test_div_tensor_mode_u55_INT(data):
11190
mode, inputs = data()
@@ -115,10 +94,9 @@ def test_div_tensor_mode_u55_INT(data):
11594
model,
11695
inputs,
11796
aten_ops=model.aten_ops_int,
97+
exir_ops=[],
11898
use_to_edge_transform_and_lower=True,
11999
)
120-
pipeline.pop_stage("check_not.exir")
121-
pipeline.pop_stage("check_count.exir")
122100
pipeline.run()
123101

124102

0 commit comments

Comments
 (0)