@@ -1337,7 +1337,9 @@ def torch_gather_rows(input, idx, y, block_y):
1337
1337
reason = "TMA Gather not supported on Hopper" )
1338
1338
def test_tma_gather (X , Y , BLOCK_X , BLOCK_Y , dtype , y , device ):
1339
1339
if BLOCK_X > X or y + BLOCK_Y > Y :
1340
- pytest .skip ()
1340
+ pytest .xfail ()
1341
+ if is_xpu ():
1342
+ pytest .skip ("FIXME: issue #4267" )
1341
1343
1342
1344
torch .manual_seed (42 )
1343
1345
if dtype != torch .int8 :
@@ -1389,6 +1391,8 @@ def tma_gather_dot_pipeline( #
1389
1391
@pytest .mark .skipif (is_cuda () and torch .cuda .get_device_capability ()[0 ] == 9 ,
1390
1392
reason = "TMA Gather not supported on hopper" )
1391
1393
def test_tma_gather_dot_pipeline (BLOCK_M , BLOCK_N , BLOCK_K , K , device ):
1394
+ if is_xpu ():
1395
+ pytest .skip ("FIXME: issue #4267" )
1392
1396
1393
1397
def alloc_fn (size : int , align : int , steam ):
1394
1398
return torch .empty (size , dtype = torch .int8 , device = device )
@@ -1436,18 +1440,20 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.
1436
1440
@pytest .mark .parametrize ("y" , [0 , 32 , 48 ])
1437
1441
@pytest .mark .skipif (is_cuda () and torch .cuda .get_device_capability ()[0 ] == 9 ,
1438
1442
reason = "TMA Scatter not supported on hopper" )
1439
- def test_tma_scatter (X , Y , BLOCK_X , BLOCK_Y , dtype , y ):
1443
+ def test_tma_scatter (X , Y , BLOCK_X , BLOCK_Y , dtype , y , device ):
1440
1444
if BLOCK_X > X or y + BLOCK_Y > Y :
1441
- pytest .skip ()
1445
+ pytest .xfail ()
1446
+ if is_xpu ():
1447
+ pytest .skip ("FIXME: issue #4267" )
1442
1448
1443
1449
torch .manual_seed (42 )
1444
- input = torch .arange (BLOCK_X * BLOCK_Y , dtype = dtype , device = 'cuda' ).reshape (BLOCK_X , BLOCK_Y )
1445
- output = torch .zeros ((X , Y ), dtype = dtype , device = 'cuda' )
1450
+ input = torch .arange (BLOCK_X * BLOCK_Y , dtype = dtype , device = device ).reshape (BLOCK_X , BLOCK_Y )
1451
+ output = torch .zeros ((X , Y ), dtype = dtype , device = device )
1446
1452
1447
- idx = torch .randperm (BLOCK_X , dtype = torch .int32 , device = 'cuda' )
1453
+ idx = torch .randperm (BLOCK_X , dtype = torch .int32 , device = device )
1448
1454
1449
1455
def alloc_fn (size : int , align : int , steam ):
1450
- return torch .empty (size , dtype = torch .int8 , device = 'cuda' )
1456
+ return torch .empty (size , dtype = torch .int8 , device = device )
1451
1457
1452
1458
triton .set_allocator (alloc_fn )
1453
1459
0 commit comments