55
66from typing import Tuple
77
8- import pytest
98import torch
109from executorch .backends .arm .test import common
1110
1615 TosaPipelineMI ,
1716)
1817
19- aten_op = "torch.ops.aten.ge.Tensor"
20- exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor"
21-
2218input_t = Tuple [torch .Tensor ]
2319
2420
2521class GreaterEqual (torch .nn .Module ):
22+ aten_op_tensor = "torch.ops.aten.ge.Tensor"
23+ aten_op_scalar = "torch.ops.aten.ge.Scalar"
24+ exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor"
25+
2626 def __init__ (self , input , other ):
2727 super ().__init__ ()
2828 self .input_ = input
@@ -31,106 +31,151 @@ def __init__(self, input, other):
3131 def forward (
3232 self ,
3333 input_ : torch .Tensor ,
34- other_ : torch .Tensor ,
34+ other_ : torch .Tensor | int | float ,
3535 ):
3636 return input_ >= other_
3737
3838 def get_inputs (self ):
3939 return (self .input_ , self .other_ )
4040
4141
42- op_ge_rank1_ones = GreaterEqual (
42+ op_ge_tensor_rank1_ones = GreaterEqual (
4343 torch .ones (5 ),
4444 torch .ones (5 ),
4545)
46- op_ge_rank2_rand = GreaterEqual (
46+ op_ge_tensor_rank2_rand = GreaterEqual (
4747 torch .rand (4 , 5 ),
4848 torch .rand (1 , 5 ),
4949)
50- op_ge_rank3_randn = GreaterEqual (
50+ op_ge_tensor_rank3_randn = GreaterEqual (
5151 torch .randn (10 , 5 , 2 ),
5252 torch .randn (10 , 5 , 2 ),
5353)
54- op_ge_rank4_randn = GreaterEqual (
54+ op_ge_tensor_rank4_randn = GreaterEqual (
5555 torch .randn (3 , 2 , 2 , 2 ),
5656 torch .randn (3 , 2 , 2 , 2 ),
5757)
5858
59- test_data_common = {
60- "ge_rank1_ones" : op_ge_rank1_ones ,
61- "ge_rank2_rand" : op_ge_rank2_rand ,
62- "ge_rank3_randn" : op_ge_rank3_randn ,
63- "ge_rank4_randn" : op_ge_rank4_randn ,
59+ op_ge_scalar_rank1_ones = GreaterEqual (torch .ones (5 ), 1.0 )
60+ op_ge_scalar_rank2_rand = GreaterEqual (torch .rand (4 , 5 ), 0.2 )
61+ op_ge_scalar_rank3_randn = GreaterEqual (torch .randn (10 , 5 , 2 ), - 0.1 )
62+ op_ge_scalar_rank4_randn = GreaterEqual (torch .randn (3 , 2 , 2 , 2 ), 0.3 )
63+
64+ test_data_tensor = {
65+ "ge_tensor_rank1_ones" : op_ge_tensor_rank1_ones ,
66+ "ge_tensor_rank2_rand" : op_ge_tensor_rank2_rand ,
67+ "ge_tensor_rank3_randn" : op_ge_tensor_rank3_randn ,
68+ "ge_tensor_rank4_randn" : op_ge_tensor_rank4_randn ,
69+ }
70+
71+ test_data_scalar = {
72+ "ge_scalar_rank1_ones" : op_ge_scalar_rank1_ones ,
73+ "ge_scalar_rank2_rand" : op_ge_scalar_rank2_rand ,
74+ "ge_scalar_rank3_randn" : op_ge_scalar_rank3_randn ,
75+ "ge_scalar_rank4_randn" : op_ge_scalar_rank4_randn ,
6476}
6577
6678
67- @common .parametrize ("test_module" , test_data_common )
68- def test_ge_tosa_MI (test_module ):
79+ @common .parametrize ("test_module" , test_data_tensor )
80+ def test_ge_tensor_tosa_MI (test_module ):
81+ pipeline = TosaPipelineMI [input_t ](
82+ test_module ,
83+ test_module .get_inputs (),
84+ GreaterEqual .aten_op_tensor ,
85+ GreaterEqual .exir_op ,
86+ )
87+ pipeline .run ()
88+
89+
90+ @common .parametrize ("test_module" , test_data_scalar )
91+ def test_ge_scalar_tosa_MI (test_module ):
6992 pipeline = TosaPipelineMI [input_t ](
70- test_module , test_module .get_inputs (), aten_op , exir_op
93+ test_module ,
94+ test_module .get_inputs (),
95+ GreaterEqual .aten_op_scalar ,
96+ GreaterEqual .exir_op ,
7197 )
7298 pipeline .run ()
7399
74100
75- @common .parametrize ("test_module" , test_data_common )
76- def test_ge_tosa_BI (test_module ):
101+ @common .parametrize ("test_module" , test_data_tensor )
102+ def test_ge_tensor_tosa_BI (test_module ):
77103 pipeline = TosaPipelineBI [input_t ](
78- test_module , test_module .get_inputs (), aten_op , exir_op
104+ test_module ,
105+ test_module .get_inputs (),
106+ GreaterEqual .aten_op_tensor ,
107+ GreaterEqual .exir_op ,
79108 )
80109 pipeline .run ()
81110
82111
83- @common .parametrize ("test_module" , test_data_common )
84- def test_ge_u55_BI (test_module ):
85- # GREATER_EQUAL is not supported on U55.
86- pipeline = OpNotSupportedPipeline [input_t ](
112+ @common .parametrize ("test_module" , test_data_scalar )
113+ def test_ge_scalar_tosa_BI (test_module ):
114+ pipeline = TosaPipelineBI [input_t ](
87115 test_module ,
88116 test_module .get_inputs (),
89- "TOSA-0.80+BI+u55" ,
90- { exir_op : 1 } ,
117+ GreaterEqual . aten_op_tensor ,
118+ GreaterEqual . exir_op ,
91119 )
92120 pipeline .run ()
93121
94122
95- @common .parametrize ("test_module" , test_data_common )
96- def test_ge_u85_BI (test_module ):
97- pipeline = EthosU85PipelineBI [input_t ](
123+ @common .parametrize ("test_module" , test_data_tensor )
124+ @common .XfailIfNoCorstone300
125+ def test_ge_tensor_u55_BI (test_module ):
126+ # GREATER_EQUAL is not supported on U55.
127+ pipeline = OpNotSupportedPipeline [input_t ](
98128 test_module ,
99129 test_module .get_inputs (),
100- aten_op ,
101- exir_op ,
102- run_on_fvp = False ,
103- use_to_edge_transform_and_lower = True ,
130+ "TOSA-0.80+BI+u55" ,
131+ {GreaterEqual .exir_op : 1 },
104132 )
105133 pipeline .run ()
106134
107135
108- @common .parametrize ("test_module" , test_data_common )
109- @pytest . mark . skip ( reason = "The same as test_ge_u55_BI" )
110- def test_ge_u55_BI_on_fvp (test_module ):
136+ @common .parametrize ("test_module" , test_data_scalar )
137+ @common . XfailIfNoCorstone300
138+ def test_ge_scalar_u55_BI (test_module ):
111139 # GREATER_EQUAL is not supported on U55.
112140 pipeline = OpNotSupportedPipeline [input_t ](
113141 test_module ,
114142 test_module .get_inputs (),
115143 "TOSA-0.80+BI+u55" ,
116- {exir_op : 1 },
144+ {GreaterEqual .exir_op : 1 },
145+ n_expected_delegates = 1 ,
146+ )
147+ pipeline .run ()
148+
149+
150+ @common .parametrize (
151+ "test_module" ,
152+ test_data_tensor ,
153+ xfails = {"ge_tensor_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" },
154+ )
155+ @common .XfailIfNoCorstone320
156+ def test_ge_tensor_u85_BI (test_module ):
157+ pipeline = EthosU85PipelineBI [input_t ](
158+ test_module ,
159+ test_module .get_inputs (),
160+ GreaterEqual .aten_op_tensor ,
161+ GreaterEqual .exir_op ,
162+ run_on_fvp = True ,
117163 )
118164 pipeline .run ()
119165
120166
121167@common .parametrize (
122168 "test_module" ,
123- test_data_common ,
124- xfails = {"ge_rank4_randn " : "4D fails because boolean Tensors can't be subtracted " },
169+ test_data_scalar ,
170+ xfails = {"ge_scalar_rank4_randn " : "MLETORCH-847: Boolean eq result unstable on U85 " },
125171)
126- @common .SkipIfNoCorstone320
127- def test_ge_u85_BI_on_fvp (test_module ):
172+ @common .XfailIfNoCorstone320
173+ def test_ge_scalar_u85_BI (test_module ):
128174 pipeline = EthosU85PipelineBI [input_t ](
129175 test_module ,
130176 test_module .get_inputs (),
131- aten_op ,
132- exir_op ,
177+ GreaterEqual . aten_op_tensor ,
178+ GreaterEqual . exir_op ,
133179 run_on_fvp = True ,
134- use_to_edge_transform_and_lower = True ,
135180 )
136181 pipeline .run ()
0 commit comments