Skip to content

Commit 09b6b43

Browse files
committed
Arm backend: Add support for scalar inputs for floor_divide.default
Signed-off-by: Agrima Khare <[email protected]> Change-Id: Id32fd6df6d7f11801880f7a734d9b43b4fb577dd
1 parent 530e451 commit 09b6b43

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

backends/arm/_passes/decompose_floor_divide_pass.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,15 @@ def get_floor_divide_decomposition(op) -> tuple:
3434
"""
3535

3636
if op in edge_floor_divide_ops:
37-
return (exir_ops.edge.aten.div.Tensor_mode,)
37+
return (
38+
exir_ops.edge.aten.div.Tensor_mode,
39+
exir_ops.edge.aten.full_like.default,
40+
)
3841
if op in aten_floor_divide_ops:
39-
return (torch.ops.aten.div.Tensor_mode,)
42+
return (
43+
torch.ops.aten.div.Tensor_mode,
44+
torch.ops.aten.full_like.default,
45+
)
4046

4147
raise RuntimeError(f"Can't get floor_div decomposition for op {op}")
4248

@@ -52,11 +58,16 @@ def call_operator(self, op, args, kwargs, meta):
5258
if op not in (edge_floor_divide_ops + aten_floor_divide_ops):
5359
return super().call_operator(op, args, kwargs, meta, updated=False)
5460

55-
(div_op,) = get_floor_divide_decomposition(op)
61+
(div_op, full_op) = get_floor_divide_decomposition(op)
5662

5763
input = args[0]
5864
other = args[1]
5965

66+
if isinstance(other, int):
67+
other = super().call_operator(
68+
full_op, (input, other), {}, meta, updated=False
69+
)
70+
6071
div_node = super().call_operator(
6172
div_op, (input, other), {"rounding_mode": "floor"}, meta, updated=True
6273
)

backends/arm/test/ops/test_floor_div.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@
3636
torch.ones(5, 10, 25, 20),
3737
(-1) * torch.ones(5, 10, 25, 20),
3838
),
39-
"op_floor_div_rank4_large_rand": lambda: (
40-
200 * torch.rand(5, 10, 25, 20),
41-
torch.rand(5, 10, 25, 20),
42-
),
4339
"op_floor_div_rank4_randn_mutltiple_broadcasts": lambda: (
4440
torch.randn(1, 4, 4, 1),
4541
torch.randn(1, 1, 4, 4),
@@ -48,6 +44,10 @@
4844
torch.randn(1, 4, 4, 1),
4945
2,
5046
),
47+
"op_floor_div_rank4_large_rand": lambda: (
48+
200 * torch.rand(5, 10, 25, 20),
49+
torch.rand(5, 10, 25, 20),
50+
),
5151
}
5252

5353

0 commit comments

Comments
 (0)