
Commit eab348f

dev-tomekanmyachev and Anatoly Myachev authored

Enable small batch matmul test (#5154)

This PR enables the small batch matmul test on XPU. Reported in #5092.

Co-authored-by: Anatoly Myachev <[email protected]>

1 parent 22725d3 · commit eab348f
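
The enablement replaces the hard-coded "cuda" device string with pytest's `device` fixture, so the same test body can allocate its tensors on XPU. As a rough sketch of how such a fixture is typically wired up (the option name and fixture below are illustrative assumptions, not the repository's actual conftest):

# conftest.py -- illustrative sketch only; the real triton_kernels conftest may differ.
import pytest


def pytest_addoption(parser):
    # Hypothetical command-line option selecting the backend under test.
    parser.addoption("--device", action="store", default="cuda",
                     help="device string handed to tests, e.g. 'cuda' or 'xpu'")


@pytest.fixture
def device(request):
    # Any test that declares a `device` parameter, such as
    # test_small_batch_matmul(m, n, k, device), receives this string.
    return request.config.getoption("--device")

With something along these lines in place, the small batch matmul test picks up whatever backend the suite was pointed at instead of always allocating on CUDA.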

File tree

1 file changed (+4, -6)


python/triton_kernels/tests/test_matmul.py

Lines changed: 4 additions & 6 deletions
@@ -21,7 +21,7 @@
 # testing utilities
 from triton_kernels.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
 
 # ---------------
 # initialize data
@@ -507,11 +507,9 @@ def round_x(x, idx):
 @pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
-def test_small_batch_matmul(m, n, k):
+def test_small_batch_matmul(m, n, k, device):
     if is_hip():
         pytest.skip("Not fully tested on AMD")
-    if is_xpu():
-        pytest.xfail("Enable: https://github.com/intel/intel-xpu-backend-for-triton/issues/5092")
 
     if m * n * k > 16384:
         pytest.skip()
@@ -521,7 +519,7 @@ def test_small_batch_matmul(m, n, k):
     def _make_tensor(shape, dtype, trans):
         if trans:
             shape = (shape[0], shape[2], shape[1])
-        t = alloc_rand(shape, "cuda", dtype)
+        t = alloc_rand(shape, device, dtype)
         return t.transpose(1, 2) if trans else t
 
     for x_transpose, w_transpose, bias, dtype in itertools.product(
@@ -530,7 +528,7 @@ def _make_tensor(shape, dtype, trans):
         (False, True),
         (torch.float16, torch.bfloat16, torch.float8_e5m2),
     ):
-        if (
+        if is_cuda() and (
             torch.cuda.get_device_capability()[0] < 10
             and dtype is torch.float8_e5m2
             and (not w_transpose)
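
One detail worth calling out in the last hunk: `if (` becomes `if is_cuda() and (`. `torch.cuda.get_device_capability()` fails when no CUDA device is present, so the fp8 capability condition must only be evaluated on CUDA; the guard lets the same loop run unchanged on XPU. A minimal sketch of that condition pulled out as a standalone helper (the function name is illustrative, not part of the change):

import torch
from triton_kernels.target_info import is_cuda  # same helper the test imports


def fp8_capability_condition(dtype, w_transpose):
    # Short-circuiting on is_cuda() means get_device_capability() is never
    # reached on XPU or other non-CUDA backends, where it would fail.
    return is_cuda() and (
        torch.cuda.get_device_capability()[0] < 10
        and dtype is torch.float8_e5m2
        and (not w_transpose)
    )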
