2222
2323
2424class Split (torch .nn .Module ):
25-
2625 test_data = {
2726 "split_1d_2_size_0_dim" : lambda : (torch .rand (10 ), 2 , 0 ),
2827 "split_2d_3_size_1_dim" : lambda : (torch .rand (10 , 10 ), 3 , 1 ),
@@ -60,12 +59,24 @@ def forward(
6059 return x .split (split_size = split_size_or_sections , dim = dim )[1 :3 ]
6160
6261
62+ class SplitCopy (torch .nn .Module ):
63+ aten_op = "torch.ops.aten.split_copy.Tensor"
64+ exir_op = "executorch_exir_dialects_edge__ops_aten_split_copy_Tensor"
65+
66+ def forward (
67+ self ,
68+ x : torch .Tensor ,
69+ split_size : int ,
70+ dim : int ,
71+ ):
72+ return torch .split_copy (x , split_size = split_size , dim = dim )
73+
74+
6375@common .parametrize (
6476 "test_data" ,
6577 (Split .test_data | Split .test_data_list ),
6678)
6779def test_split_with_sizes_tosa_FP (test_data : input_t1 ):
68-
6980 pipeline = TosaPipelineFP [input_t1 ](
7081 Split (),
7182 test_data (),
@@ -77,7 +88,6 @@ def test_split_with_sizes_tosa_FP(test_data: input_t1):
7788
7889@common .parametrize ("test_data" , Split .test_data_list )
7990def test_split_with_sizes_tosa_FP_2 (test_data : input_t1 ):
80-
8191 pipeline = TosaPipelineFP [input_t1 ](
8292 SplitWithSizes (),
8393 test_data (),
@@ -92,7 +102,6 @@ def test_split_with_sizes_tosa_FP_2(test_data: input_t1):
92102 (Split .test_data | Split .test_data_list ),
93103)
94104def test_split_with_sizes_tosa_FP_one_out (test_data : input_t1 ):
95-
96105 pipeline = TosaPipelineFP [input_t1 ](
97106 SplitSingleOut (),
98107 test_data (),
@@ -107,7 +116,6 @@ def test_split_with_sizes_tosa_FP_one_out(test_data: input_t1):
107116 (Split .test_data | Split .test_data_list ),
108117)
109118def test_split_with_sizes_tosa_FP_two_out (test_data : input_t1 ):
110-
111119 pipeline = TosaPipelineFP [input_t1 ](
112120 SplitTwoOut (),
113121 test_data (),
@@ -122,7 +130,6 @@ def test_split_with_sizes_tosa_FP_two_out(test_data: input_t1):
122130 (Split .test_data | Split .test_data_list ),
123131)
124132def test_split_with_sizes_tosa_INT (test_data : input_t1 ):
125-
126133 pipeline = TosaPipelineINT [input_t1 ](
127134 Split (),
128135 test_data (),
@@ -152,7 +159,6 @@ def test_split_with_sizes_u55_INT(test_data: input_t1):
152159 (Split .test_data | Split .test_data_list ),
153160)
154161def test_split_with_sizes_u85_INT (test_data : input_t1 ):
155-
156162 pipeline = EthosU85PipelineINT [input_t1 ](
157163 Split (),
158164 test_data (),
@@ -182,7 +188,6 @@ def test_split_with_sizes_vgf_FP(test_data: input_t1):
182188@common .parametrize ("test_data" , Split .test_data_list )
183189@common .SkipIfNoModelConverter
184190def test_split_with_sizes_vgf_FP_2 (test_data : input_t1 ):
185-
186191 pipeline = VgfPipeline [input_t1 ](
187192 SplitWithSizes (),
188193 test_data (),
@@ -199,7 +204,6 @@ def test_split_with_sizes_vgf_FP_2(test_data: input_t1):
199204)
200205@common .SkipIfNoModelConverter
201206def test_split_with_sizes_vgf_FP_one_out (test_data : input_t1 ):
202-
203207 pipeline = VgfPipeline [input_t1 ](
204208 SplitSingleOut (),
205209 test_data (),
@@ -216,7 +220,6 @@ def test_split_with_sizes_vgf_FP_one_out(test_data: input_t1):
216220)
217221@common .SkipIfNoModelConverter
218222def test_split_with_sizes_vgf_FP_two_out (test_data : input_t1 ):
219-
220223 pipeline = VgfPipeline [input_t1 ](
221224 SplitTwoOut (),
222225 test_data (),
@@ -233,7 +236,6 @@ def test_split_with_sizes_vgf_FP_two_out(test_data: input_t1):
233236)
234237@common .SkipIfNoModelConverter
235238def test_split_with_sizes_vgf_INT (test_data : input_t1 ):
236-
237239 pipeline = VgfPipeline [input_t1 ](
238240 Split (),
239241 test_data (),
@@ -242,3 +244,75 @@ def test_split_with_sizes_vgf_INT(test_data: input_t1):
242244 tosa_version = "TOSA-1.0+INT" ,
243245 )
244246 pipeline .run ()
247+
248+
249+ @common .parametrize ("test_data" , Split .test_data )
250+ def test_split_tensor_tosa_FP (test_data : Tuple ):
251+ pipeline = TosaPipelineFP [input_t1 ](
252+ SplitCopy (),
253+ test_data (),
254+ aten_op = SplitCopy .aten_op ,
255+ exir_op = SplitCopy .exir_op ,
256+ )
257+ pipeline .run ()
258+
259+
260+ @common .parametrize ("test_data" , Split .test_data )
261+ def test_split_tensor_tosa_INT (test_data : Tuple ):
262+ pipeline = TosaPipelineINT [input_t1 ](
263+ SplitCopy (),
264+ test_data (),
265+ aten_op = SplitCopy .aten_op ,
266+ exir_op = SplitCopy .exir_op ,
267+ )
268+ pipeline .run ()
269+
270+
271+ @common .XfailIfNoCorstone300
272+ @common .parametrize ("test_data" , Split .test_data )
273+ def test_split_tensor_u55_INT (test_data : Tuple ):
274+ pipeline = EthosU55PipelineINT [input_t1 ](
275+ SplitCopy (),
276+ test_data (),
277+ aten_ops = SplitCopy .aten_op ,
278+ exir_ops = SplitCopy .exir_op ,
279+ )
280+ pipeline .run ()
281+
282+
283+ @common .XfailIfNoCorstone320
284+ @common .parametrize ("test_data" , Split .test_data )
285+ def test_split_tensor_u85_INT (test_data : Tuple ):
286+ pipeline = EthosU85PipelineINT [input_t1 ](
287+ SplitCopy (),
288+ test_data (),
289+ aten_ops = SplitCopy .aten_op ,
290+ exir_ops = SplitCopy .exir_op ,
291+ )
292+ pipeline .run ()
293+
294+
295+ @common .parametrize ("test_data" , Split .test_data )
296+ @common .SkipIfNoModelConverter
297+ def test_split_tensor_vgf_FP (test_data : Tuple ):
298+ pipeline = VgfPipeline [input_t1 ](
299+ SplitCopy (),
300+ test_data (),
301+ aten_op = SplitCopy .aten_op ,
302+ exir_op = SplitCopy .exir_op ,
303+ tosa_version = "TOSA-1.0+FP" ,
304+ )
305+ pipeline .run ()
306+
307+
308+ @common .parametrize ("test_data" , Split .test_data )
309+ @common .SkipIfNoModelConverter
310+ def test_split_tensor_vgf_INT (test_data : Tuple ):
311+ pipeline = VgfPipeline [input_t1 ](
312+ SplitCopy (),
313+ test_data (),
314+ aten_op = SplitCopy .aten_op ,
315+ exir_op = SplitCopy .exir_op ,
316+ tosa_version = "TOSA-1.0+INT" ,
317+ )
318+ pipeline .run ()
0 commit comments