Skip to content

Commit 421072e

Browse files
committed
enable test_small_batch_matmul
1 parent 566e6ee commit 421072e

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,9 @@ def round_x(x, idx):
476476
@pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
477477
@pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
478478
@pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
479-
def test_small_batch_matmul(m, n, k):
479+
def test_small_batch_matmul(m, n, k, device):
480480
if is_hip():
481481
pytest.skip("Not fully tested on AMD")
482-
if is_xpu():
483-
pytest.xfail("Enable: https://github.com/intel/intel-xpu-backend-for-triton/issues/5092")
484482

485483
if m * n * k > 16384:
486484
pytest.skip()
@@ -490,7 +488,7 @@ def test_small_batch_matmul(m, n, k):
490488
def _make_tensor(shape, dtype, trans):
491489
if trans:
492490
shape = (shape[0], shape[2], shape[1])
493-
t = alloc_rand(shape, "cuda", dtype)
491+
t = alloc_rand(shape, device, dtype)
494492
return t.transpose(1, 2) if trans else t
495493

496494
for x_transpose, w_transpose, bias, dtype in itertools.product(
@@ -499,7 +497,7 @@ def _make_tensor(shape, dtype, trans):
499497
(False, True),
500498
(torch.float16, torch.bfloat16, torch.float8_e5m2),
501499
):
502-
if (
500+
if device == "cuda" and (
503501
torch.cuda.get_device_capability()[0] < 10
504502
and dtype is torch.float8_e5m2
505503
and (not w_transpose)

0 commit comments

Comments (0)