@@ -1777,20 +1777,6 @@ def forward(self, x):
17771777 (torch .rand (size = [1 , 5 , 2 , 3 ]),),
17781778 )
17791779
1780- def test_vulkan_backend_high_dim_tensors_fail (self ):
1781- class UnsqueezeHigherDim (torch .nn .Module ):
1782- def __init__ (self ):
1783- super ().__init__ ()
1784-
1785- def forward (self , x ):
1786- return torch .unsqueeze (x , 2 )
1787-
1788- self .lower_module_and_test_output (
1789- UnsqueezeHigherDim (),
1790- (torch .ones (size = [5 , 4 , 1 , 2 , 6 ]),),
1791- expect_no_delegates = True ,
1792- )
1793-
17941780 def test_vulkan_backend_large_linear_layer (self ):
17951781 class LinearModel (torch .nn .Module ):
17961782 def __init__ (self , large_out_channels : int ) -> None :
@@ -2298,6 +2284,28 @@ def forward(self, x1, x2, x3, x4, x5, x6):
22982284 test_inputs = test_inputs ,
22992285 )
23002286
2287+ def test_vulkan_backend_high_dimensional_tensors (self ):
2288+ class HighDimTensorModule (torch .nn .Module ):
2289+ def __init__ (self ):
2290+ super ().__init__ ()
2291+
2292+ def forward (self , x , y ):
2293+ # Unsqueeze inputs twice to create 5-dim tensors
2294+ x_5d = torch .unsqueeze (torch .unsqueeze (x , 0 ), 0 )
2295+ y_5d = torch .unsqueeze (torch .unsqueeze (y , 0 ), 0 )
2296+ # Add tensors together
2297+ result = x_5d + y_5d
2298+ return result
2299+
2300+ high_dim_module = HighDimTensorModule ()
2301+ # Create 2 4-dim inputs
2302+ sample_inputs = (
2303+ torch .rand (size = (2 , 3 , 4 , 5 ), dtype = torch .float32 ),
2304+ torch .rand (size = (2 , 3 , 4 , 5 ), dtype = torch .float32 ),
2305+ )
2306+
2307+ self .lower_module_and_test_output (high_dim_module , sample_inputs )
2308+
23012309 def test_vulkan_backend_torchao_wo_quantized_linear (self ):
23022310 in_features = 1024
23032311 out_features = 512
0 commit comments