
Commit 991152f

[Hopper][Warp Spec] Enable some WS tests for Hopper (#7623)
1 parent cf9b4ea commit 991152f

1 file changed: +25 -13 lines


python/test/unit/language/test_warp_specialization.py

Lines changed: 25 additions & 13 deletions
@@ -4,19 +4,23 @@
 import triton
 import triton.language as tl
 
-from triton._internal_testing import is_hip
+from triton._internal_testing import is_hip, is_hopper, is_blackwell
 from triton.tools.tensor_descriptor import TensorDescriptor
 
-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10:
+if not is_hip() and torch.cuda.is_available() and torch.cuda.get_device_capability()[0] in [9, 10]:
     from triton._C.libtriton import nvidia
     cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
     cublas = nvidia.cublas.CublasLt(cublas_workspace)
 else:
     cublas = None
 
 
+def is_hopper_or_blackwell():
+    return is_hopper() or is_blackwell()
+
+
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
 def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):
     ir = """
     tt.func @kernel(%arg0: !tt.ptr<i32>) {
@@ -51,7 +55,7 @@ def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):
 
 
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_warp_specialize_tmem_ir(tmp_path: pathlib.Path):
     ir = """
     #blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -119,7 +123,7 @@ def test_warp_specialize_tmem_ir(tmp_path: pathlib.Path):
 
 
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
 def test_warpgroup_reduction(tmp_path: pathlib.Path):
 
     def template(i, num_warps, in_ptr, out_ptr):
@@ -242,11 +246,11 @@ def exceeds_smem_capacity(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, use_fp8):
 @pytest.mark.parametrize("BLOCK_SIZE_M", [128])
 @pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
 @pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
-@pytest.mark.parametrize("num_stages", [2, 3, 4])
+@pytest.mark.parametrize("num_stages", [2, 3])
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
 def test_warp_specialize_tma_matmul(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps, use_fp8):
     if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8=use_fp8):
         pytest.skip("uses too much shared memory")
@@ -270,8 +274,16 @@ def alloc_fn(size, align, stream):
                  BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps=num_warps,
                  USE_FP8=use_fp8)
     ttgir = kernel.asm["ttgir"]
-    assert "ttng.tc_gen5_mma" in ttgir
-    assert "ttg.warp_specialize" in ttgir
+    if is_blackwell():
+        assert "ttng.tc_gen5_mma" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    else:
+        assert "ttng.warp_group_dot" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    if is_hopper() and num_warps == 8:
+        assert "ttg.warp_specialize" not in ttgir
+    else:
+        assert "ttg.warp_specialize" in ttgir
 
     ref_out = torch.empty((M, N), dtype=dtype, device=device)
     cublas.matmul(A, B, ref_out)
@@ -326,11 +338,11 @@ def matmul_tma_persistent_ws_kernel( #
 @pytest.mark.parametrize("BLOCK_SIZE_M", [128])
 @pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
 @pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
-@pytest.mark.parametrize("num_stages", [2, 3, 4])
+@pytest.mark.parametrize("num_stages", [2, 3])
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_warp_specialize_tma_matmul_persistent(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps,
                                                use_fp8):
     if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8):
@@ -413,12 +425,12 @@ def attention_inner_loop_kernel( #
 @pytest.mark.parametrize("M, N", [(8192, 8192), (1024, 1024)])
 @pytest.mark.parametrize("BLOCK_M", [64, 128])
 @pytest.mark.parametrize("HEAD_DIM", [64, 128])
-@pytest.mark.parametrize("num_stages", [2, 3, 4])
+@pytest.mark.parametrize("num_stages", [2, 3])
 @pytest.mark.parametrize("disable_acc_multibuf", [False, True])
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
 def test_warp_specialize_attention_forward(M, N, BLOCK_M, HEAD_DIM, num_stages, disable_acc_multibuf, num_warps,
                                            use_fp8):
     if BLOCK_M == 128 and HEAD_DIM == 128 and not use_fp8:
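
The new skip conditions delegate to is_hopper() and is_blackwell() from triton._internal_testing instead of comparing torch.cuda.get_device_capability()[0] to 10 directly. As a rough mental model only (the real helpers may be implemented differently), the gating amounts to a check on the compute-capability major version: 9 for Hopper, 10 for Blackwell, which also matches the widened "in [9, 10]" cuBLAS gate at the top of the diff. A minimal, hypothetical sketch of that mapping and of a reusable skip marker:

# Hypothetical sketch only -- not the triton._internal_testing implementation.
# It assumes Hopper maps to compute capability 9.x and Blackwell to 10.x.
import pytest
import torch


def _cc_major() -> int:
    # Returns 0 when no CUDA device is present so the skip conditions are always safe to evaluate.
    return torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 0


def is_hopper_sketch() -> bool:
    return _cc_major() == 9


def is_blackwell_sketch() -> bool:
    return _cc_major() == 10


# Reusable marker mirroring the "Requires Hopper or Blackwell" skipif used in the diff.
requires_hopper_or_blackwell = pytest.mark.skipif(
    not (is_hopper_sketch() or is_blackwell_sketch()),
    reason="Requires Hopper or Blackwell",
)

Tests that still depend on Blackwell-only features (TMEM, ttng.tc_gen5_mma) keep a plain is_blackwell() gate, as in test_warp_specialize_tmem_ir above.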
