1515 TosaPipelineMI ,
1616)
1717
18- aten_op = "torch.ops.aten.le.Tensor"
19- exir_op = "executorch_exir_dialects_edge__ops_aten_le_Tensor"
2018
2119input_t = Tuple [torch .Tensor ]
2220
2321
24- class GreaterEqual (torch .nn .Module ):
22+ class LessEqual (torch .nn .Module ):
23+ aten_op_tensor = "torch.ops.aten.le.Tensor"
24+ aten_op_scalar = "torch.ops.aten.le.Scalar"
25+ exir_op = "executorch_exir_dialects_edge__ops_aten_le_Tensor"
26+
2527 def __init__ (self , input , other ):
2628 super ().__init__ ()
2729 self .input_ = input
@@ -38,72 +40,151 @@ def get_inputs(self):
3840 return (self .input_ , self .other_ )
3941
4042
41- op_le_rank1_ones = GreaterEqual (
43+ op_le_tensor_rank1_ones = LessEqual (
4244 torch .ones (5 ),
4345 torch .ones (5 ),
4446)
45- op_le_rank2_rand = GreaterEqual (
47+ op_le_tensor_rank2_rand = LessEqual (
4648 torch .rand (4 , 5 ),
4749 torch .rand (1 , 5 ),
4850)
49- op_le_rank3_randn = GreaterEqual (
51+ op_le_tensor_rank3_randn = LessEqual (
5052 torch .randn (10 , 5 , 2 ),
5153 torch .randn (10 , 5 , 2 ),
5254)
53- op_le_rank4_randn = GreaterEqual (
55+ op_le_tensor_rank4_randn = LessEqual (
5456 torch .randn (3 , 2 , 2 , 2 ),
5557 torch .randn (3 , 2 , 2 , 2 ),
5658)
5759
58- test_data_common = {
59- "le_rank1_ones" : lambda : op_le_rank1_ones ,
60- "le_rank2_rand" : lambda : op_le_rank2_rand ,
61- "le_rank3_randn" : lambda : op_le_rank3_randn ,
62- "le_rank4_randn" : lambda : op_le_rank4_randn ,
60+ op_le_scalar_rank1_ones = LessEqual (torch .ones (5 ), 1.0 )
61+ op_le_scalar_rank2_rand = LessEqual (torch .rand (4 , 5 ), 0.2 )
62+ op_le_scalar_rank3_randn = LessEqual (torch .randn (10 , 5 , 2 ), - 0.1 )
63+ op_le_scalar_rank4_randn = LessEqual (torch .randn (3 , 2 , 2 , 2 ), 0.3 )
64+
65+ test_data_tensor = {
66+ "le_tensor_rank1_ones" : lambda : op_le_tensor_rank1_ones ,
67+ "le_tensor_rank2_rand" : lambda : op_le_tensor_rank2_rand ,
68+ "le_tensor_rank3_randn" : lambda : op_le_tensor_rank3_randn ,
69+ "le_tensor_rank4_randn" : lambda : op_le_tensor_rank4_randn ,
70+ }
71+
72+ test_data_scalar = {
73+ "le_scalar_rank1_ones" : lambda : op_le_scalar_rank1_ones ,
74+ "le_scalar_rank2_rand" : lambda : op_le_scalar_rank2_rand ,
75+ "le_scalar_rank3_randn" : lambda : op_le_scalar_rank3_randn ,
76+ "le_scalar_rank4_randn" : lambda : op_le_scalar_rank4_randn ,
6377}
6478
6579
66- @common .parametrize ("test_module" , test_data_common )
80+ @common .parametrize ("test_module" , test_data_tensor )
6781def test_le_tensor_tosa_MI (test_module ):
6882 pipeline = TosaPipelineMI [input_t ](
69- test_module (), test_module ().get_inputs (), aten_op , exir_op
83+ test_module (),
84+ test_module ().get_inputs (),
85+ LessEqual .aten_op_tensor ,
86+ LessEqual .exir_op ,
7087 )
7188 pipeline .run ()
7289
7390
74- @common .parametrize ("test_module" , test_data_common )
91+ @common .parametrize ("test_module" , test_data_scalar )
92+ def test_le_scalar_tosa_MI (test_module ):
93+ pipeline = TosaPipelineMI [input_t ](
94+ test_module (),
95+ test_module ().get_inputs (),
96+ LessEqual .aten_op_scalar ,
97+ LessEqual .exir_op ,
98+ )
99+ pipeline .run ()
100+
101+
102+ @common .parametrize ("test_module" , test_data_tensor )
75103def test_le_tensor_tosa_BI (test_module ):
76104 pipeline = TosaPipelineBI [input_t ](
77- test_module (), test_module ().get_inputs (), aten_op , exir_op
105+ test_module (),
106+ test_module ().get_inputs (),
107+ LessEqual .aten_op_tensor ,
108+ LessEqual .exir_op ,
78109 )
79110 pipeline .run ()
80111
81112
82- @common .parametrize ("test_module" , test_data_common )
113+ @common .parametrize ("test_module" , test_data_scalar )
114+ def test_le_scalar_tosa_BI (test_module ):
115+ pipeline = TosaPipelineBI [input_t ](
116+ test_module (),
117+ test_module ().get_inputs (),
118+ LessEqual .aten_op_tensor ,
119+ LessEqual .exir_op ,
120+ )
121+ pipeline .run ()
122+
123+
124+ @common .parametrize ("test_module" , test_data_tensor )
125+ @common .XfailIfNoCorstone300
83126def test_le_tensor_u55_BI_not_delegated (test_module ):
84127 # GREATER_EQUAL is not supported on U55. LE uses the GREATER_EQUAL Tosa operator.
85128 pipeline = OpNotSupportedPipeline [input_t ](
86129 test_module (),
87130 test_module ().get_inputs (),
88- {exir_op : 1 },
131+ {LessEqual . exir_op : 1 },
89132 quantize = True ,
90133 u55_subset = True ,
91134 )
92135 pipeline .run ()
93136
94137
138+ @common .parametrize ("test_module" , test_data_scalar )
139+ @common .XfailIfNoCorstone300
140+ def test_le_scalar_u55_BI_not_delegated (test_module ):
141+ # GREATER_EQUAL is not supported on U55. LE uses the GREATER_EQUAL Tosa operator.
142+ pipeline = OpNotSupportedPipeline [input_t ](
143+ test_module (),
144+ test_module ().get_inputs (),
145+ {LessEqual .exir_op : 1 },
146+ n_expected_delegates = 1 ,
147+ quantize = True ,
148+ u55_subset = True ,
149+ )
150+ pipeline .dump_operator_distribution ("export" )
151+ pipeline .run ()
152+
153+
95154@common .parametrize (
96155 "test_module" ,
97- test_data_common ,
98- xfails = {"le_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" },
156+ test_data_tensor ,
157+ xfails = {
158+ "le_tensor_rank4_randn" : "4D fails because boolean Tensors can't be subtracted"
159+ },
99160)
100161@common .XfailIfNoCorstone320
101162def test_le_tensor_u85_BI (test_module ):
102163 pipeline = EthosU85PipelineBI [input_t ](
103164 test_module (),
104165 test_module ().get_inputs (),
105- aten_op ,
106- exir_op ,
166+ LessEqual .aten_op_tensor ,
167+ LessEqual .exir_op ,
168+ run_on_fvp = True ,
169+ use_to_edge_transform_and_lower = True ,
170+ )
171+ pipeline .run ()
172+
173+
174+ @common .parametrize (
175+ "test_module" ,
176+ test_data_scalar ,
177+ xfails = {
178+ "le_scalar_rank4_randn" : "4D fails because boolean Tensors can't be subtracted"
179+ },
180+ )
181+ @common .XfailIfNoCorstone320
182+ def test_le_scalar_u85_BI (test_module ):
183+ pipeline = EthosU85PipelineBI [input_t ](
184+ test_module (),
185+ test_module ().get_inputs (),
186+ LessEqual .aten_op_tensor ,
187+ LessEqual .exir_op ,
107188 run_on_fvp = True ,
108189 use_to_edge_transform_and_lower = True ,
109190 )
0 commit comments