@@ -24,7 +24,12 @@ def _nonzero_float_tensor(*shape: int) -> torch.Tensor:
2424class Remainder (torch .nn .Module ):
2525 input_t = Tuple [torch .Tensor | float , torch .Tensor | float ]
2626
27- test_cases = {
27+ aten_op_tensor = "torch.ops.aten.remainder.Tensor"
28+ exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
29+ aten_op_scalar = "torch.ops.aten.remainder.Scalar"
30+ exir_op_scalar = "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
31+
32+ test_cases_tensor = {
2833 "rank2_tensors" : lambda : (
2934 torch .randn (2 , 3 ) * 7 ,
3035 _nonzero_float_tensor (2 , 3 ),
@@ -37,44 +42,49 @@ class Remainder(torch.nn.Module):
3742 torch .randn (4 , 5 , 1 ),
3843 _nonzero_float_tensor (1 , 5 , 6 ),
3944 ),
40- "scalar_rhs" : lambda : (
45+ }
46+
47+ test_cases_scalar = {
48+ "scalar_pos" : lambda : (
4149 torch .randn (1 , 2 , 3 , 4 ),
4250 0.25 ,
4351 ),
52+ "scalar_neg" : lambda : (
53+ torch .randn (3 , 4 ),
54+ - 0.25 ,
55+ ),
4456 }
4557
4658 def forward (self , x : torch .Tensor | float , y : torch .Tensor | float ) -> torch .Tensor :
4759 return torch .remainder (x , y )
4860
4961
50- def _get_aten_op (test_data : Remainder .input_t ):
51- if any (isinstance (x , float ) for x in test_data ):
52- return "torch.ops.aten.remainder.Scalar"
53- else :
54- return "torch.ops.aten.remainder.Tensor"
55-
56-
57- def _get_exir_op (test_data : Remainder .input_t ):
58- if isinstance (test_data [1 ], float ):
59- return "executorch_exir_dialects_edge__ops_aten_remainder_Scalar"
60- else :
61- return "executorch_exir_dialects_edge__ops_aten_remainder_Tensor"
62+ @common .parametrize ("test_data" , Remainder .test_cases_tensor )
63+ def test_remainder_tensor_tosa_FP (test_data ):
64+ data = test_data ()
65+ pipeline = TosaPipelineFP [Remainder .input_t ](
66+ Remainder (),
67+ data ,
68+ Remainder .aten_op_tensor ,
69+ Remainder .exir_op_tensor ,
70+ )
71+ pipeline .run ()
6272
6373
64- @common .parametrize ("test_data" , Remainder .test_cases )
65- def test_remainder_tosa_FP (test_data ):
74+ @common .parametrize ("test_data" , Remainder .test_cases_scalar )
75+ def test_remainder_scalar_tosa_FP (test_data ):
6676 data = test_data ()
6777 pipeline = TosaPipelineFP [Remainder .input_t ](
6878 Remainder (),
6979 data ,
70- _get_aten_op ( data ) ,
71- _get_exir_op ( data ) ,
80+ Remainder . aten_op_scalar ,
81+ Remainder . exir_op_scalar ,
7282 )
7383 pipeline .run ()
7484
7585
76- @common .parametrize ("test_data" , Remainder .test_cases )
77- def test_remainder_tosa_INT (test_data ):
86+ @common .parametrize ("test_data" , Remainder .test_cases_tensor )
87+ def test_remainder_tensor_tosa_INT (test_data ):
7888 pipeline = TosaPipelineINT [Remainder .input_t ](
7989 Remainder (),
8090 test_data (),
@@ -83,9 +93,30 @@ def test_remainder_tosa_INT(test_data):
8393 pipeline .run ()
8494
8595
86- @common .parametrize ("test_data" , Remainder .test_cases )
96+ @common .parametrize ("test_data" , Remainder .test_cases_scalar )
97+ def test_remainder_scalar_tosa_INT (test_data ):
98+ pipeline = TosaPipelineINT [Remainder .input_t ](
99+ Remainder (),
100+ test_data (),
101+ [],
102+ )
103+ pipeline .run ()
104+
105+
106+ @common .parametrize ("test_data" , Remainder .test_cases_tensor )
107+ @common .XfailIfNoCorstone300
108+ def test_remainder_tensor_u55_INT (test_data ):
109+ pipeline = EthosU55PipelineINT [Remainder .input_t ](
110+ Remainder (),
111+ test_data (),
112+ [],
113+ )
114+ pipeline .run ()
115+
116+
117+ @common .parametrize ("test_data" , Remainder .test_cases_scalar )
87118@common .XfailIfNoCorstone300
88- def test_remainder_u55_INT (test_data ):
119+ def test_remainder_scalar_u55_INT (test_data ):
89120 pipeline = EthosU55PipelineINT [Remainder .input_t ](
90121 Remainder (),
91122 test_data (),
@@ -94,9 +125,9 @@ def test_remainder_u55_INT(test_data):
94125 pipeline .run ()
95126
96127
97- @common .parametrize ("test_data" , Remainder .test_cases )
128+ @common .parametrize ("test_data" , Remainder .test_cases_tensor )
98129@common .XfailIfNoCorstone320
99- def test_remainder_u85_INT (test_data ):
130+ def test_remainder_tensor_u85_INT (test_data ):
100131 pipeline = EthosU85PipelineINT [Remainder .input_t ](
101132 Remainder (),
102133 test_data (),
@@ -105,23 +136,60 @@ def test_remainder_u85_INT(test_data):
105136 pipeline .run ()
106137
107138
108- @common .parametrize ("test_data" , Remainder .test_cases )
139+ @common .parametrize ("test_data" , Remainder .test_cases_scalar )
140+ @common .XfailIfNoCorstone320
141+ def test_remainder_scalar_u85_INT (test_data ):
142+ pipeline = EthosU85PipelineINT [Remainder .input_t ](
143+ Remainder (),
144+ test_data (),
145+ [],
146+ )
147+ pipeline .run ()
148+
149+
150+ @common .parametrize ("test_data" , Remainder .test_cases_tensor )
109151@common .SkipIfNoModelConverter
110- def test_remainder_vgf_FP (test_data ):
152+ def test_remainder_tensor_vgf_FP (test_data ):
111153 data = test_data ()
112154 pipeline = VgfPipeline [Remainder .input_t ](
113155 Remainder (),
114156 data ,
115- _get_aten_op ( data ) ,
116- _get_exir_op ( data ) ,
157+ Remainder . aten_op_tensor ,
158+ Remainder . exir_op_tensor ,
117159 tosa_version = "TOSA-1.0+FP" ,
118160 )
119161 pipeline .run ()
120162
121163
122- @common .parametrize ("test_data" , Remainder .test_cases )
164+ @common .parametrize ("test_data" , Remainder .test_cases_scalar )
165+ @common .SkipIfNoModelConverter
166+ def test_remainder_scalar_vgf_FP (test_data ):
167+ data = test_data ()
168+ pipeline = VgfPipeline [Remainder .input_t ](
169+ Remainder (),
170+ data ,
171+ Remainder .aten_op_scalar ,
172+ Remainder .exir_op_scalar ,
173+ tosa_version = "TOSA-1.0+FP" ,
174+ )
175+ pipeline .run ()
176+
177+
178+ @common .parametrize ("test_data" , Remainder .test_cases_tensor )
179+ @common .SkipIfNoModelConverter
180+ def test_remainder_tensor_vgf_INT (test_data ):
181+ pipeline = VgfPipeline [Remainder .input_t ](
182+ Remainder (),
183+ test_data (),
184+ [],
185+ tosa_version = "TOSA-1.0+INT" ,
186+ )
187+ pipeline .run ()
188+
189+
190+ @common .parametrize ("test_data" , Remainder .test_cases_scalar )
123191@common .SkipIfNoModelConverter
124- def test_remainder_vgf_INT (test_data ):
192+ def test_remainder_scalar_vgf_INT (test_data ):
125193 pipeline = VgfPipeline [Remainder .input_t ](
126194 Remainder (),
127195 test_data (),
0 commit comments