@@ -476,11 +476,9 @@ def round_x(x, idx):
 @pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
-def test_small_batch_matmul(m, n, k):
+def test_small_batch_matmul(m, n, k, device):
     if is_hip():
         pytest.skip("Not fully tested on AMD")
-    if is_xpu():
-        pytest.xfail("Enable: https://github.com/intel/intel-xpu-backend-for-triton/issues/5092")
 
     if m * n * k > 16384:
         pytest.skip()
@@ -490,7 +488,7 @@ def test_small_batch_matmul(m, n, k):
     def _make_tensor(shape, dtype, trans):
         if trans:
             shape = (shape[0], shape[2], shape[1])
-        t = alloc_rand(shape, "cuda", dtype)
+        t = alloc_rand(shape, device, dtype)
         return t.transpose(1, 2) if trans else t
 
     for x_transpose, w_transpose, bias, dtype in itertools.product(
@@ -499,7 +497,7 @@ def _make_tensor(shape, dtype, trans):
         (False, True),
         (torch.float16, torch.bfloat16, torch.float8_e5m2),
     ):
-        if (
+        if device == "cuda" and (
             torch.cuda.get_device_capability()[0] < 10
             and dtype is torch.float8_e5m2
             and (not w_transpose)
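
The new `device` argument is injected by pytest as a fixture rather than a `parametrize` value, which is what lets the same test body run on CUDA and other backends. A minimal sketch of the conftest.py plumbing such a fixture typically relies on; the `--device` option name and its default here are assumptions, not taken from this diff:

```python
# conftest.py -- hypothetical sketch of a pytest "device" fixture.
# The fixture name matches the new test parameter; the "--device"
# command-line option and its "cuda" default are assumptions.
import pytest


def pytest_addoption(parser):
    # Register a command-line option selecting the backend under test.
    parser.addoption("--device", action="store", default="cuda",
                     help="device to run tests on, e.g. cuda or xpu")


@pytest.fixture
def device(request):
    # Injected into any test that declares a `device` parameter.
    return request.config.getoption("--device")
```

With plumbing like this, the test runs unmodified on other backends (e.g. `pytest --device xpu`), which is consistent with the rest of the diff: the XPU-specific `xfail`, the hard-coded `"cuda"` passed to `alloc_rand`, and the unguarded `torch.cuda.get_device_capability()` check are all replaced by device-agnostic code.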