@@ -1337,7 +1337,9 @@ def torch_gather_rows(input, idx, y, block_y):
                     reason="TMA Gather not supported on Hopper")
 def test_tma_gather(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
     if BLOCK_X > X or y + BLOCK_Y > Y:
-        pytest.skip()
+        pytest.xfail()
+    if is_xpu():
+        pytest.skip("FIXME: issue #4267")

     torch.manual_seed(42)
     if dtype != torch.int8:
@@ -1389,6 +1391,8 @@ def tma_gather_dot_pipeline( #
 @pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 9,
                     reason="TMA Gather not supported on hopper")
 def test_tma_gather_dot_pipeline(BLOCK_M, BLOCK_N, BLOCK_K, K, device):
+    if is_xpu():
+        pytest.skip("FIXME: issue #4267")

     def alloc_fn(size: int, align: int, steam):
         return torch.empty(size, dtype=torch.int8, device=device)
@@ -1436,18 +1440,20 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.
 @pytest.mark.parametrize("y", [0, 32, 48])
 @pytest.mark.skipif(is_cuda() and torch.cuda.get_device_capability()[0] == 9,
                     reason="TMA Scatter not supported on hopper")
-def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y):
+def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
     if BLOCK_X > X or y + BLOCK_Y > Y:
-        pytest.skip()
+        pytest.xfail()
+    if is_xpu():
+        pytest.skip("FIXME: issue #4267")

     torch.manual_seed(42)
-    input = torch.arange(BLOCK_X * BLOCK_Y, dtype=dtype, device='cuda').reshape(BLOCK_X, BLOCK_Y)
-    output = torch.zeros((X, Y), dtype=dtype, device='cuda')
+    input = torch.arange(BLOCK_X * BLOCK_Y, dtype=dtype, device=device).reshape(BLOCK_X, BLOCK_Y)
+    output = torch.zeros((X, Y), dtype=dtype, device=device)

-    idx = torch.randperm(BLOCK_X, dtype=torch.int32, device='cuda')
+    idx = torch.randperm(BLOCK_X, dtype=torch.int32, device=device)

     def alloc_fn(size: int, align: int, steam):
-        return torch.empty(size, dtype=torch.int8, device='cuda')
+        return torch.empty(size, dtype=torch.int8, device=device)

     triton.set_allocator(alloc_fn)
