# isort: off
# fmt: off
from dataclasses import dataclass, fields, replace
+import itertools
import pytest
import torch
from typing import Union

# testing utilities
from triton_kernels.testing import assert_close, compute_actual_scale
# target-specific utilities
-from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4
+from triton_kernels.target_info import is_hip, is_xpu, is_hip_cdna3, is_cuda, is_hip_cdna4

# ---------------
# initialize data
@@ -471,14 +472,68 @@ def round_x(x, idx):
            tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"


+# Test that we don't use unsupported block sizes.
+@pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
+def test_small_batch_matmul(m, n, k):
+    if is_hip():
+        pytest.skip("Not fully tested on AMD")
+    if is_xpu():
+        pytest.xfail("Enable: https://github.com/intel/intel-xpu-backend-for-triton/issues/5092")
+
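+    # Only small problems are interesting here: the test checks that small shapes
+    # do not select unsupported block sizes.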
+    if m * n * k > 16384:
+        pytest.skip()
+
+    BATCH_SIZE = 10000
+
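+    # Build a random operand; for trans=True the buffer is allocated with its last
+    # two dims swapped and returned as a transposed view, so the logical shape is
+    # unchanged but the strides are non-contiguous.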
+    def _make_tensor(shape, dtype, trans):
+        if trans:
+            shape = (shape[0], shape[2], shape[1])
+        t = alloc_rand(shape, "cuda", dtype)
+        return t.transpose(1, 2) if trans else t
+
+    for x_transpose, w_transpose, bias, dtype in itertools.product(
+        (False, True),
+        (False, True),
+        (False, True),
+        (torch.float16, torch.bfloat16, torch.float8_e5m2),
+    ):
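+        # fp8e5m2 with a non-transposed W operand is not supported before
+        # compute capability 10, so skip that combination there.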
+        if (
+            torch.cuda.get_device_capability()[0] < 10
+            and dtype is torch.float8_e5m2
+            and (not w_transpose)
+        ):
+            continue  # Not supported
+
+        x = _make_tensor((BATCH_SIZE, m, k), dtype, x_transpose)
+        w = _make_tensor((BATCH_SIZE, k, n), dtype, w_transpose)
+        bias = _make_tensor((BATCH_SIZE, n), torch.float32, False) if bias else None
+        tri_y = matmul_ogs(x, w, bias)
+
+        # ref_y = matmul_ogs_torch(x.float(), w.float(), bias)
+
+        # This is faster than matmul_ogs_torch.
+        ref_y = torch.bmm(x.float(), w.float())
+        if bias is not None:
+            ref_y += bias[:, None, :]
+
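+        # fp8e5m2 is very low precision, so loosen the tolerances for those cases.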
+        assert_close(
+            ref_y,
+            tri_y,
+            maxtol=4e-1 if dtype is torch.float8_e5m2 else None,
+            rmstol=4e-2 if dtype is torch.float8_e5m2 else None,
+        )
+
+
def test_set_idle_sms():
    if not is_cuda():
        pytest.skip("Only supported on CUDA")
    from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
    num_idle_sms = 24
    matmul_ogs_set_idle_sms(num_idle_sms)
    flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1024, 1024, 1024, None, True, False, 1)
+                           1, 1024, 1024, 1024, None, True, False, 1)
    assert flags.idle_sms == num_idle_sms