99import torch
1010from executorch .backends .arm .test import common
1111from executorch .backends .arm .test .tester .test_pipeline import (
12+ EthosU55PipelineINT ,
13+ EthosU85PipelineINT ,
1214 TosaPipelineFP ,
1315 TosaPipelineINT ,
1416 VgfPipeline ,
@@ -30,8 +32,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
3032 return torch .unflatten (x , self .dim , self .sizes )
3133
3234 test_data : dict [str , test_data_t ] = {
33- "randn_4d" : (lambda : (Unflatten (1 , (2 , 2 )), (torch .randn (3 , 4 , 5 , 1 ),))),
34- "rand_3d" : (lambda : (Unflatten (1 , (- 1 , 2 )), (torch .rand (3 , 4 , 4 ),))),
35+ "rand_3d_batch3" : (lambda : (Unflatten (1 , (- 1 , 2 )), (torch .rand (3 , 4 , 4 ),))),
36+ "rand_3d_batch1" : (lambda : (Unflatten (1 , (- 1 , 2 )), (torch .rand (1 , 4 , 4 ),))),
37+ "randn_4d_dim1" : (lambda : (Unflatten (1 , (2 , 2 )), (torch .randn (3 , 4 , 5 , 1 ),))),
38+ "randn_4d_dim3" : (lambda : (Unflatten (3 , (2 , 2 )), (torch .randn (1 , 1 , 5 , 4 ),))),
3539 }
3640
3741
@@ -49,7 +53,33 @@ def test_unflatten_int_tosa_FP(test_data: test_data_t):
4953@common .parametrize ("test_data" , Unflatten .test_data )
5054def test_unflatten_int_tosa_INT (test_data : test_data_t ):
5155 module , inputs = test_data ()
52- pipeline = TosaPipelineINT [input_t ](
56+ pipeline = TosaPipelineINT [input_t ](module , inputs , Unflatten .aten_op )
57+ pipeline .run ()
58+
59+
60+ xfails = {
61+ "rand_3d_batch3" : "Batch size > 1 currently not supported for FVP tests" ,
62+ "randn_4d_dim1" : "Batch size > 1 currently not supported for FVP tests" ,
63+ }
64+
65+
66+ @common .parametrize ("test_data" , Unflatten .test_data , xfails = xfails , strict = False )
67+ @common .XfailIfNoCorstone300
68+ def test_unflatten_int_u55_INT (test_data : test_data_t ):
69+ module , inputs = test_data ()
70+ pipeline = EthosU55PipelineINT [input_t ](
71+ module ,
72+ inputs ,
73+ Unflatten .aten_op ,
74+ )
75+ pipeline .run ()
76+
77+
78+ @common .parametrize ("test_data" , Unflatten .test_data , xfails = xfails , strict = False )
79+ @common .XfailIfNoCorstone320
80+ def test_unflatten_int_u85_INT (test_data : test_data_t ):
81+ module , inputs = test_data ()
82+ pipeline = EthosU85PipelineINT [input_t ](
5383 module ,
5484 inputs ,
5585 Unflatten .aten_op ,
0 commit comments