Skip to content

Commit 083663b

Browse files
Arm backend: Add support for le.Scalar (#12107)
le.Scalar nodes are converted into le.Tensor nodes. Add tests for scalar le. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 1b42512 commit 083663b

File tree

5 files changed

+108
-22
lines changed

5 files changed

+108
-22
lines changed

backends/arm/_passes/match_arg_ranks_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self, exported_program):
5151
exir_ops.edge.aten.gt.Tensor,
5252
exir_ops.edge.aten.ge.Tensor,
5353
exir_ops.edge.aten.lt.Tensor,
54+
exir_ops.edge.aten.le.Tensor,
5455
exir_ops.edge.aten.pow.Tensor_Tensor,
5556
exir_ops.edge.aten.where.self,
5657
]

backends/arm/_passes/replace_scalar_with_tensor_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
3333
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
3434
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
35+
exir_ops.edge.aten.le.Scalar: exir_ops.edge.aten.le.Tensor,
3536
exir_ops.edge.aten.ne.Scalar: exir_ops.edge.aten.ne.Tensor,
3637
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3738
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
@@ -43,6 +44,7 @@
4344
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
4445
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
4546
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
47+
torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor,
4648
torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor,
4749
}
4850

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class EthosU55NotSupported(OperatorSupportBase):
138138
exir_ops.edge.aten.gt.Tensor,
139139
exir_ops.edge.aten.gt.Scalar,
140140
exir_ops.edge.aten.le.Tensor,
141+
exir_ops.edge.aten.le.Scalar,
141142
exir_ops.edge.aten.lt.Tensor,
142143
exir_ops.edge.aten.lt.Scalar,
143144
exir_ops.edge.aten.ne.Tensor,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def is_node_supported(
189189
exir_ops.edge.aten.gt.Tensor,
190190
exir_ops.edge.aten.gt.Scalar,
191191
exir_ops.edge.aten.le.Tensor,
192+
exir_ops.edge.aten.le.Scalar,
192193
exir_ops.edge.aten.lt.Tensor,
193194
exir_ops.edge.aten.lt.Scalar,
194195
exir_ops.edge.aten.mul.Tensor,

backends/arm/test/ops/test_le.py

Lines changed: 103 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
TosaPipelineMI,
1616
)
1717

18-
aten_op = "torch.ops.aten.le.Tensor"
19-
exir_op = "executorch_exir_dialects_edge__ops_aten_le_Tensor"
2018

2119
input_t = Tuple[torch.Tensor]
2220

2321

24-
class GreaterEqual(torch.nn.Module):
22+
class LessEqual(torch.nn.Module):
23+
aten_op_tensor = "torch.ops.aten.le.Tensor"
24+
aten_op_scalar = "torch.ops.aten.le.Scalar"
25+
exir_op = "executorch_exir_dialects_edge__ops_aten_le_Tensor"
26+
2527
def __init__(self, input, other):
2628
super().__init__()
2729
self.input_ = input
@@ -38,72 +40,151 @@ def get_inputs(self):
3840
return (self.input_, self.other_)
3941

4042

41-
op_le_rank1_ones = GreaterEqual(
43+
op_le_tensor_rank1_ones = LessEqual(
4244
torch.ones(5),
4345
torch.ones(5),
4446
)
45-
op_le_rank2_rand = GreaterEqual(
47+
op_le_tensor_rank2_rand = LessEqual(
4648
torch.rand(4, 5),
4749
torch.rand(1, 5),
4850
)
49-
op_le_rank3_randn = GreaterEqual(
51+
op_le_tensor_rank3_randn = LessEqual(
5052
torch.randn(10, 5, 2),
5153
torch.randn(10, 5, 2),
5254
)
53-
op_le_rank4_randn = GreaterEqual(
55+
op_le_tensor_rank4_randn = LessEqual(
5456
torch.randn(3, 2, 2, 2),
5557
torch.randn(3, 2, 2, 2),
5658
)
5759

58-
test_data_common = {
59-
"le_rank1_ones": lambda: op_le_rank1_ones,
60-
"le_rank2_rand": lambda: op_le_rank2_rand,
61-
"le_rank3_randn": lambda: op_le_rank3_randn,
62-
"le_rank4_randn": lambda: op_le_rank4_randn,
60+
op_le_scalar_rank1_ones = LessEqual(torch.ones(5), 1.0)
61+
op_le_scalar_rank2_rand = LessEqual(torch.rand(4, 5), 0.2)
62+
op_le_scalar_rank3_randn = LessEqual(torch.randn(10, 5, 2), -0.1)
63+
op_le_scalar_rank4_randn = LessEqual(torch.randn(3, 2, 2, 2), 0.3)
64+
65+
test_data_tensor = {
66+
"le_tensor_rank1_ones": lambda: op_le_tensor_rank1_ones,
67+
"le_tensor_rank2_rand": lambda: op_le_tensor_rank2_rand,
68+
"le_tensor_rank3_randn": lambda: op_le_tensor_rank3_randn,
69+
"le_tensor_rank4_randn": lambda: op_le_tensor_rank4_randn,
70+
}
71+
72+
test_data_scalar = {
73+
"le_scalar_rank1_ones": lambda: op_le_scalar_rank1_ones,
74+
"le_scalar_rank2_rand": lambda: op_le_scalar_rank2_rand,
75+
"le_scalar_rank3_randn": lambda: op_le_scalar_rank3_randn,
76+
"le_scalar_rank4_randn": lambda: op_le_scalar_rank4_randn,
6377
}
6478

6579

66-
@common.parametrize("test_module", test_data_common)
80+
@common.parametrize("test_module", test_data_tensor)
6781
def test_le_tensor_tosa_MI(test_module):
6882
pipeline = TosaPipelineMI[input_t](
69-
test_module(), test_module().get_inputs(), aten_op, exir_op
83+
test_module(),
84+
test_module().get_inputs(),
85+
LessEqual.aten_op_tensor,
86+
LessEqual.exir_op,
7087
)
7188
pipeline.run()
7289

7390

74-
@common.parametrize("test_module", test_data_common)
91+
@common.parametrize("test_module", test_data_scalar)
92+
def test_le_scalar_tosa_MI(test_module):
93+
pipeline = TosaPipelineMI[input_t](
94+
test_module(),
95+
test_module().get_inputs(),
96+
LessEqual.aten_op_scalar,
97+
LessEqual.exir_op,
98+
)
99+
pipeline.run()
100+
101+
102+
@common.parametrize("test_module", test_data_tensor)
75103
def test_le_tensor_tosa_BI(test_module):
76104
pipeline = TosaPipelineBI[input_t](
77-
test_module(), test_module().get_inputs(), aten_op, exir_op
105+
test_module(),
106+
test_module().get_inputs(),
107+
LessEqual.aten_op_tensor,
108+
LessEqual.exir_op,
78109
)
79110
pipeline.run()
80111

81112

82-
@common.parametrize("test_module", test_data_common)
113+
@common.parametrize("test_module", test_data_scalar)
114+
def test_le_scalar_tosa_BI(test_module):
115+
pipeline = TosaPipelineBI[input_t](
116+
test_module(),
117+
test_module().get_inputs(),
118+
LessEqual.aten_op_tensor,
119+
LessEqual.exir_op,
120+
)
121+
pipeline.run()
122+
123+
124+
@common.parametrize("test_module", test_data_tensor)
125+
@common.XfailIfNoCorstone300
83126
def test_le_tensor_u55_BI_not_delegated(test_module):
84127
# GREATER_EQUAL is not supported on U55. LE uses the GREATER_EQUAL Tosa operator.
85128
pipeline = OpNotSupportedPipeline[input_t](
86129
test_module(),
87130
test_module().get_inputs(),
88-
{exir_op: 1},
131+
{LessEqual.exir_op: 1},
89132
quantize=True,
90133
u55_subset=True,
91134
)
92135
pipeline.run()
93136

94137

138+
@common.parametrize("test_module", test_data_scalar)
139+
@common.XfailIfNoCorstone300
140+
def test_le_scalar_u55_BI_not_delegated(test_module):
141+
# GREATER_EQUAL is not supported on U55. LE uses the GREATER_EQUAL Tosa operator.
142+
pipeline = OpNotSupportedPipeline[input_t](
143+
test_module(),
144+
test_module().get_inputs(),
145+
{LessEqual.exir_op: 1},
146+
n_expected_delegates=1,
147+
quantize=True,
148+
u55_subset=True,
149+
)
150+
pipeline.dump_operator_distribution("export")
151+
pipeline.run()
152+
153+
95154
@common.parametrize(
96155
"test_module",
97-
test_data_common,
98-
xfails={"le_rank4_randn": "4D fails because boolean Tensors can't be subtracted"},
156+
test_data_tensor,
157+
xfails={
158+
"le_tensor_rank4_randn": "4D fails because boolean Tensors can't be subtracted"
159+
},
99160
)
100161
@common.XfailIfNoCorstone320
101162
def test_le_tensor_u85_BI(test_module):
102163
pipeline = EthosU85PipelineBI[input_t](
103164
test_module(),
104165
test_module().get_inputs(),
105-
aten_op,
106-
exir_op,
166+
LessEqual.aten_op_tensor,
167+
LessEqual.exir_op,
168+
run_on_fvp=True,
169+
use_to_edge_transform_and_lower=True,
170+
)
171+
pipeline.run()
172+
173+
174+
@common.parametrize(
175+
"test_module",
176+
test_data_scalar,
177+
xfails={
178+
"le_scalar_rank4_randn": "4D fails because boolean Tensors can't be subtracted"
179+
},
180+
)
181+
@common.XfailIfNoCorstone320
182+
def test_le_scalar_u85_BI(test_module):
183+
pipeline = EthosU85PipelineBI[input_t](
184+
test_module(),
185+
test_module().get_inputs(),
186+
LessEqual.aten_op_tensor,
187+
LessEqual.exir_op,
107188
run_on_fvp=True,
108189
use_to_edge_transform_and_lower=True,
109190
)

0 commit comments

Comments
 (0)