55
66from typing import Tuple
77
8- import pytest
98import torch
109from executorch .backends .arm .test import common
1110
1615 TosaPipelineMI ,
1716)
1817
19- aten_op = "torch.ops.aten.eq.Tensor"
20- exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
2118
2219input_t = Tuple [torch .Tensor ]
2320
2421
2522class Equal (torch .nn .Module ):
23+ aten_op_BI = "torch.ops.aten.eq.Tensor"
24+ aten_op_MI = "torch.ops.aten.eq.Scalar"
25+ exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
26+
2627 def __init__ (self , input , other ):
2728 super ().__init__ ()
2829 self .input_ = input
@@ -31,106 +32,119 @@ def __init__(self, input, other):
3132 def forward (
3233 self ,
3334 input_ : torch .Tensor ,
34- other_ : torch .Tensor ,
35+ other_ : torch .Tensor | int | float ,
3536 ):
3637 return input_ == other_
3738
3839 def get_inputs (self ):
3940 return (self .input_ , self .other_ )
4041
4142
42- op_eq_rank1_ones = Equal (
43+ op_eq_tensor_rank1_ones = Equal (
4344 torch .ones (5 ),
4445 torch .ones (5 ),
4546)
46- op_eq_rank2_rand = Equal (
47+ op_eq_tensor_rank2_rand = Equal (
4748 torch .rand (4 , 5 ),
4849 torch .rand (1 , 5 ),
4950)
50- op_eq_rank3_randn = Equal (
51+ op_eq_tensor_rank3_randn = Equal (
5152 torch .randn (10 , 5 , 2 ),
5253 torch .randn (10 , 5 , 2 ),
5354)
54- op_eq_rank4_randn = Equal (
55+ op_eq_tensor_rank4_randn = Equal (
5556 torch .randn (3 , 2 , 2 , 2 ),
5657 torch .randn (3 , 2 , 2 , 2 ),
5758)
5859
59- test_data_common = {
60- "eq_rank1_ones" : op_eq_rank1_ones ,
61- "eq_rank2_rand" : op_eq_rank2_rand ,
62- "eq_rank3_randn" : op_eq_rank3_randn ,
63- "eq_rank4_randn" : op_eq_rank4_randn ,
60+ op_eq_scalar_rank1_ones = Equal (torch .ones (5 ), 1.0 )
61+ op_eq_scalar_rank2_rand = Equal (torch .rand (4 , 5 ), 0.2 )
62+ op_eq_scalar_rank3_randn = Equal (torch .randn (10 , 5 , 2 ), - 0.1 )
63+ op_eq_scalar_rank4_randn = Equal (torch .randn (3 , 2 , 2 , 2 ), 0.3 )
64+
65+ test_data_tensor = {
66+ "eq_tensor_rank1_ones" : op_eq_tensor_rank1_ones ,
67+ "eq_tensor_rank2_rand" : op_eq_tensor_rank2_rand ,
68+ "eq_tensor_rank3_randn" : op_eq_tensor_rank3_randn ,
69+ "eq_tensor_rank4_randn" : op_eq_tensor_rank4_randn ,
6470}
6571
72+ test_data_scalar = {
73+ "eq_scalar_rank1_ones" : op_eq_scalar_rank1_ones ,
74+ "eq_scalar_rank2_rand" : op_eq_scalar_rank2_rand ,
75+ "eq_scalar_rank3_randn" : op_eq_scalar_rank3_randn ,
76+ "eq_scalar_rank4_randn" : op_eq_scalar_rank4_randn ,
77+ }
78+
79+
80+ @common .parametrize ("test_module" , test_data_tensor )
81+ def test_eq_tensor_tosa_MI (test_module ):
82+ pipeline = TosaPipelineMI [input_t ](
83+ test_module , test_module .get_inputs (), Equal .aten_op_BI , Equal .exir_op
84+ )
85+ pipeline .run ()
6686
67- @common .parametrize ("test_module" , test_data_common )
68- def test_eq_tosa_MI (test_module ):
87+
88+ @common .parametrize ("test_module" , test_data_scalar )
89+ def test_eq_scalar_tosa_MI (test_module ):
6990 pipeline = TosaPipelineMI [input_t ](
70- test_module , test_module .get_inputs (), aten_op , exir_op
91+ test_module ,
92+ test_module .get_inputs (),
93+ Equal .aten_op_MI ,
94+ Equal .exir_op ,
7195 )
7296 pipeline .run ()
7397
7498
75- @common .parametrize ("test_module" , test_data_common )
99+ @common .parametrize ("test_module" , test_data_tensor | test_data_scalar )
76100def test_eq_tosa_BI (test_module ):
77101 pipeline = TosaPipelineBI [input_t ](
78- test_module , test_module .get_inputs (), aten_op , exir_op
102+ test_module , test_module .get_inputs (), Equal . aten_op_BI , Equal . exir_op
79103 )
80104 pipeline .run ()
81105
82106
83- @common .parametrize ("test_module" , test_data_common )
84- def test_eq_u55_BI (test_module ):
107+ @common .parametrize ("test_module" , test_data_tensor )
108+ @common .XfailIfNoCorstone300
109+ def test_eq_tensor_u55_BI (test_module ):
85110 # EQUAL is not supported on U55.
86111 pipeline = OpNotSupportedPipeline [input_t ](
87112 test_module ,
88113 test_module .get_inputs (),
89114 "TOSA-0.80+BI+u55" ,
90- {exir_op : 1 },
91- )
92- pipeline .run ()
93-
94-
95- @common .parametrize ("test_module" , test_data_common )
96- def test_eq_u85_BI (test_module ):
97- pipeline = EthosU85PipelineBI [input_t ](
98- test_module ,
99- test_module .get_inputs (),
100- aten_op ,
101- exir_op ,
102- run_on_fvp = False ,
103- use_to_edge_transform_and_lower = True ,
115+ {Equal .exir_op : 1 },
104116 )
105117 pipeline .run ()
106118
107119
108- @common .parametrize ("test_module" , test_data_common )
109- @pytest . mark . skip ( reason = "The same as test_eq_u55_BI" )
110- def test_eq_u55_BI_on_fvp (test_module ):
120+ @common .parametrize ("test_module" , test_data_scalar )
121+ @common . XfailIfNoCorstone300
122+ def test_eq_scalar_u55_BI (test_module ):
111123 # EQUAL is not supported on U55.
112124 pipeline = OpNotSupportedPipeline [input_t ](
113125 test_module ,
114126 test_module .get_inputs (),
115127 "TOSA-0.80+BI+u55" ,
116- {exir_op : 1 },
128+ {Equal .exir_op : 1 },
129+ n_expected_delegates = 1 ,
117130 )
118131 pipeline .run ()
119132
120133
121134@common .parametrize (
122135 "test_module" ,
123- test_data_common ,
124- xfails = {"eq_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" },
136+ test_data_tensor | test_data_scalar ,
137+ xfails = {
138+ "eq_tensor_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" ,
139+ },
125140)
126- @common .SkipIfNoCorstone320
127- def test_eq_u85_BI_on_fvp (test_module ):
141+ @common .XfailIfNoCorstone320
142+ def test_eq_u85_BI (test_module ):
128143 pipeline = EthosU85PipelineBI [input_t ](
129144 test_module ,
130145 test_module .get_inputs (),
131- aten_op ,
132- exir_op ,
146+ Equal . aten_op_BI ,
147+ Equal . exir_op ,
133148 run_on_fvp = True ,
134- use_to_edge_transform_and_lower = True ,
135149 )
136150 pipeline .run ()
0 commit comments