diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index 5ad348806e3..b5352475d51 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -22,8 +22,6 @@ "full": exir_ops.edge.aten.full.default, "lt": exir_ops.edge.aten.lt.Tensor, "where": exir_ops.edge.aten.where.self, - "mul": exir_ops.edge.aten.mul.Tensor, - "sub": exir_ops.edge.aten.sub.Tensor, } aten_unary = { @@ -33,8 +31,6 @@ "full": torch.ops.aten.full.default, "lt": torch.ops.aten.lt.Tensor, "where": torch.ops.aten.where.self, - "mul": torch.ops.aten.mul.Tensor, - "sub": torch.ops.aten.sub.Tensor, } @@ -74,57 +70,13 @@ def call_operator(self, op, args, kwargs, meta): return q if rounding_mode == "floor": - q_raw = q - - # trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw)) - q_floor = super().call_operator(opset["floor"], (q_raw,), {}, meta) - q_ceil = super().call_operator(opset["ceil"], (q_raw,), {}, meta) - - # a zero tensor with the right shape - out_shape = (1,) * len(meta["val"].size()) - zero = super().call_operator( - opset["full"], - args=(out_shape, 0.0), - kwargs={}, - meta=meta, - ) - - is_neg = super().call_operator(opset["lt"], (q_raw, zero), {}, meta) - q_trunc = super().call_operator( - opset["where"], (is_neg, q_ceil, q_floor), {}, meta - ) - - # r = a - q_trunc * b (true remainder under truncation) - q_times_b = super().call_operator(opset["mul"], (q_trunc, b), {}, meta) - r = super().call_operator(opset["sub"], (a, q_times_b), {}, meta) - - # Decide if we need to subtract 1: - # for b > 0, adjust if r < 0; for b < 0, adjust if r > 0. - b_pos = super().call_operator(opset["lt"], (zero, b), {}, meta) # b > 0 - r_lt0 = super().call_operator(opset["lt"], (r, zero), {}, meta) # r < 0 - r_gt0 = super().call_operator(opset["lt"], (zero, r), {}, meta) # r > 0 - - adjust_if = super().call_operator( - opset["where"], (b_pos, r_lt0, r_gt0), {}, meta - ) - - one = super().call_operator( - opset["full"], - args=(out_shape, 1.0), - kwargs={}, - meta=meta, - ) - q_minus_1 = super().call_operator(opset["sub"], (q_trunc, one), {}, meta) - - return super().call_operator( - opset["where"], (adjust_if, q_minus_1, q_trunc), {}, meta - ) + return super().call_operator(opset["floor"], (q,), {}, meta) if rounding_mode == "trunc": zero = super().call_operator( opset["full"], args=((1,) * len(meta["val"].size()), 0.0), - kwargs={}, + kwargs={"dtype": torch.float32}, meta=meta, ) lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta) diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py index 9057be343f1..e1f6036a487 100644 --- a/backends/arm/test/ops/test_div_tensor_mode.py +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -36,14 +36,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.div(x, y, rounding_mode=self.mode) -def _rank4_large_randn_case(): - torch.manual_seed(0) - x = 200 * torch.randn(5, 10, 25, 20) + 1 - torch.manual_seed(1) - y = torch.rand(5, 10, 25, 20) + 1 - return x, y - - test_data = { "mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)), "mode_floor": lambda: ( @@ -55,13 +47,6 @@ def _rank4_large_randn_case(): (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3), ), "int_denominator": lambda: (None, (torch.randn(4, 8), 2)), - "op_floor_div_rank4_large_randn": lambda: ( - "floor", - ( - 200 * torch.randn(5, 10, 25, 20) + 1, - torch.rand(5, 10, 25, 20) + 1, - ), - ), } @@ -99,13 +84,7 @@ def test_div_tensor_mode_tosa_INT(data): @common.XfailIfNoCorstone300 @common.parametrize( - "data", - test_data, - xfails={ - "mode_trunc": "CPU op missing in unittests", - "mode_floor": "Not supported", - "op_floor_div_rank4_large_randn": "Not supported", - }, + "data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"} ) def test_div_tensor_mode_u55_INT(data): mode, inputs = data() @@ -115,10 +94,9 @@ def test_div_tensor_mode_u55_INT(data): model, inputs, aten_ops=model.aten_ops_int, + exir_ops=[], use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check_not.exir") - pipeline.pop_stage("check_count.exir") pipeline.run()