diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py
index b5352475d51..5ad348806e3 100644
--- a/backends/arm/_passes/decompose_div_tensor_mode.py
+++ b/backends/arm/_passes/decompose_div_tensor_mode.py
@@ -22,6 +22,8 @@
     "full": exir_ops.edge.aten.full.default,
     "lt": exir_ops.edge.aten.lt.Tensor,
     "where": exir_ops.edge.aten.where.self,
+    "mul": exir_ops.edge.aten.mul.Tensor,
+    "sub": exir_ops.edge.aten.sub.Tensor,
 }
 
 aten_unary = {
@@ -31,6 +33,8 @@
     "full": torch.ops.aten.full.default,
     "lt": torch.ops.aten.lt.Tensor,
     "where": torch.ops.aten.where.self,
+    "mul": torch.ops.aten.mul.Tensor,
+    "sub": torch.ops.aten.sub.Tensor,
 }
 
 
@@ -70,13 +74,57 @@ def call_operator(self, op, args, kwargs, meta):
             return q
 
         if rounding_mode == "floor":
-            return super().call_operator(opset["floor"], (q,), {}, meta)
+            q_raw = q
+
+            # trunc(q_raw) = where(q_raw < 0, ceil(q_raw), floor(q_raw))
+            q_floor = super().call_operator(opset["floor"], (q_raw,), {}, meta)
+            q_ceil = super().call_operator(opset["ceil"], (q_raw,), {}, meta)
+
+            # a broadcastable zero tensor matching the output rank
+            out_shape = (1,) * len(meta["val"].size())
+            zero = super().call_operator(
+                opset["full"],
+                args=(out_shape, 0.0),
+                kwargs={},
+                meta=meta,
+            )
+
+            is_neg = super().call_operator(opset["lt"], (q_raw, zero), {}, meta)
+            q_trunc = super().call_operator(
+                opset["where"], (is_neg, q_ceil, q_floor), {}, meta
+            )
+
+            # r = a - q_trunc * b (remainder under truncating division)
+            q_times_b = super().call_operator(opset["mul"], (q_trunc, b), {}, meta)
+            r = super().call_operator(opset["sub"], (a, q_times_b), {}, meta)
+
+            # Subtract 1 where trunc rounded toward zero past floor:
+            # for b > 0, adjust if r < 0; for b < 0, adjust if r > 0.
+            b_pos = super().call_operator(opset["lt"], (zero, b), {}, meta)  # b > 0
+            r_lt0 = super().call_operator(opset["lt"], (r, zero), {}, meta)  # r < 0
+            r_gt0 = super().call_operator(opset["lt"], (zero, r), {}, meta)  # r > 0
+
+            adjust_if = super().call_operator(
+                opset["where"], (b_pos, r_lt0, r_gt0), {}, meta
+            )
+
+            one = super().call_operator(
+                opset["full"],
+                args=(out_shape, 1.0),
+                kwargs={},
+                meta=meta,
+            )
+            q_minus_1 = super().call_operator(opset["sub"], (q_trunc, one), {}, meta)
+
+            return super().call_operator(
+                opset["where"], (adjust_if, q_minus_1, q_trunc), {}, meta
+            )
 
         if rounding_mode == "trunc":
             zero = super().call_operator(
                 opset["full"],
                 args=((1,) * len(meta["val"].size()), 0.0),
-                kwargs={"dtype": torch.float32},
+                kwargs={},
                 meta=meta,
             )
             lt0 = self.call_operator(opset["lt"], (q, zero), {}, meta)
diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py
index e1f6036a487..9057be343f1 100644
--- a/backends/arm/test/ops/test_div_tensor_mode.py
+++ b/backends/arm/test/ops/test_div_tensor_mode.py
@@ -36,6 +36,14 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.div(x, y, rounding_mode=self.mode)
 
 
+def _rank4_large_randn_case():
+    torch.manual_seed(0)
+    x = 200 * torch.randn(5, 10, 25, 20) + 1
+    torch.manual_seed(1)
+    y = torch.rand(5, 10, 25, 20) + 1
+    return x, y
+
+
 test_data = {
     "mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)),
     "mode_floor": lambda: (
@@ -47,6 +55,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3),
     ),
     "int_denominator": lambda: (None, (torch.randn(4, 8), 2)),
+    "op_floor_div_rank4_large_randn": lambda: (
+        "floor",
+        _rank4_large_randn_case(),
+    ),
 }
 
 
@@ -84,7 +96,13 @@ def test_div_tensor_mode_tosa_INT(data):
 
 @common.XfailIfNoCorstone300
 @common.parametrize(
-    "data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"}
+    "data",
+    test_data,
+    xfails={
+        "mode_trunc": "CPU op missing in unittests",
+        "mode_floor": "Floor rounding mode not supported on U55",
+        "op_floor_div_rank4_large_randn": "Floor rounding mode not supported on U55",
+    },
 )
 def test_div_tensor_mode_u55_INT(data):
     mode, inputs = data()
@@ -94,9 +112,10 @@ def test_div_tensor_mode_u55_INT(data):
         model,
         inputs,
         aten_ops=model.aten_ops_int,
-        exir_ops=[],
         use_to_edge_transform_and_lower=True,
     )
+    pipeline.pop_stage("check_not.exir")
+    pipeline.pop_stage("check_count.exir")
    pipeline.run()
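For reviewers who want to sanity-check the math in the new floor branch: the decomposition builds floor division from truncation, using the sign of the remainder to decide when to subtract 1. Below is a minimal eager-PyTorch sketch of that identity, not the pass itself; the helper name `floor_div_via_trunc` is illustrative, and the integer-valued float inputs are chosen so every intermediate is exact and the result can be compared elementwise against `torch.div(..., rounding_mode="floor")`.

```python
import torch


def floor_div_via_trunc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    q = a / b
    # Mirrors the decomposition: trunc(q) = where(q < 0, ceil(q), floor(q))
    q_trunc = torch.where(q < 0, torch.ceil(q), torch.floor(q))
    # Remainder under truncating division
    r = a - q_trunc * b
    # floor(q) = trunc(q) - 1 exactly when r is nonzero and opposite in sign
    # to b: for b > 0 adjust if r < 0, for b < 0 adjust if r > 0
    adjust = torch.where(b > 0, r < 0, r > 0)
    return torch.where(adjust, q_trunc - 1.0, q_trunc)


# Integer-valued floats keep all intermediates exact, so elementwise equality
# with the reference holds (no rounding ambiguity at the floor/trunc boundary).
a = torch.randint(-100, 100, (1000,)).float()
b = torch.randint(1, 10, (1000,)).float() * (torch.randint(0, 2, (1000,)).float() * 2 - 1)
assert torch.equal(floor_div_via_trunc(a, b), torch.div(a, b, rounding_mode="floor"))
```

One detail the sketch makes visible: when `r == 0`, both comparisons are false and no adjustment happens, so exact multiples keep `floor == trunc`, as they should.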