import triton
import triton.language as tl

- from triton._internal_testing import is_hip
+ from triton._internal_testing import is_hip, is_hopper, is_blackwell
from triton.tools.tensor_descriptor import TensorDescriptor

- if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10:
+ if not is_hip() and torch.cuda.is_available() and torch.cuda.get_device_capability()[0] in [9, 10]:
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None


+ def is_hopper_or_blackwell():
+     return is_hopper() or is_blackwell()
+
+
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
- @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+ @pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):
    ir = """
tt.func @kernel(%arg0: !tt.ptr<i32>) {
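Note: is_hopper and is_blackwell are imported from triton._internal_testing and are not defined in this patch. A hypothetical sketch of the gate they presumably implement, mirroring the "in [9, 10]" compute-capability check above (illustration only, not the library's actual code):

    def is_hopper():
        # Assumption: Hopper GPUs report compute capability 9.x.
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9

    def is_blackwell():
        # Assumption: Blackwell GPUs report compute capability 10.x.
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10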
@@ -51,7 +55,7 @@ def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):


@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
- @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+ @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_warp_specialize_tmem_ir(tmp_path: pathlib.Path):
    ir = """
#blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -119,7 +123,7 @@ def test_warp_specialize_tmem_ir(tmp_path: pathlib.Path):


@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
- @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+ @pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
def test_warpgroup_reduction(tmp_path: pathlib.Path):

    def template(i, num_warps, in_ptr, out_ptr):
@@ -242,11 +246,11 @@ def exceeds_smem_capacity(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, use_fp8):
@pytest.mark.parametrize("BLOCK_SIZE_M", [128])
@pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
@pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
- @pytest.mark.parametrize("num_stages", [2, 3, 4])
+ @pytest.mark.parametrize("num_stages", [2, 3])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
- @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+ @pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
def test_warp_specialize_tma_matmul(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps, use_fp8):
    if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8=use_fp8):
        pytest.skip("uses too much shared memory")
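Note: exceeds_smem_capacity is defined earlier in this file and is unchanged by this patch. A hypothetical sketch of the kind of bound such a helper enforces (the name, formula, and capacity constant below are assumptions, not the test's actual code):

    def exceeds_smem_capacity_sketch(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, use_fp8):
        # Shared memory needed to multi-buffer the A and B operand tiles across stages.
        elem_bytes = 1 if use_fp8 else 2  # fp8 vs. fp16 operands
        per_stage = (BLOCK_M * BLOCK_K + BLOCK_N * BLOCK_K) * elem_bytes
        # ~228 KB per CTA is an assumed budget for Hopper/Blackwell-class GPUs.
        return num_stages * per_stage > 228 * 1024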
@@ -270,8 +274,16 @@ def alloc_fn(size, align, stream):
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps=num_warps,
        USE_FP8=use_fp8)
    ttgir = kernel.asm["ttgir"]
-     assert "ttng.tc_gen5_mma" in ttgir
-     assert "ttg.warp_specialize" in ttgir
+     if is_blackwell():
+         assert "ttng.tc_gen5_mma" in ttgir
+         assert "ttng.async_tma_copy_global_to_local" in ttgir
+     else:
+         assert "ttng.warp_group_dot" in ttgir
+         assert "ttng.async_tma_copy_global_to_local" in ttgir
+     if is_hopper() and num_warps == 8:
+         assert "ttg.warp_specialize" not in ttgir
+     else:
+         assert "ttg.warp_specialize" in ttgir

    ref_out = torch.empty((M, N), dtype=dtype, device=device)
    cublas.matmul(A, B, ref_out)
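Note: the new checks branch on the target architecture: Blackwell lowers the dot to ttng.tc_gen5_mma while Hopper uses ttng.warp_group_dot, both load operands via ttng.async_tma_copy_global_to_local, and Hopper with num_warps == 8 is expected not to be warp-specialized. An equivalent, more table-driven way to write the same assertions (illustration only, not part of the patch):

    expected = ["ttng.tc_gen5_mma"] if is_blackwell() else ["ttng.warp_group_dot"]
    expected.append("ttng.async_tma_copy_global_to_local")
    for op in expected:
        assert op in ttgir
    # Warp specialization is expected everywhere except Hopper with 8 warps.
    assert ("ttg.warp_specialize" in ttgir) == (not (is_hopper() and num_warps == 8))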
@@ -326,11 +338,11 @@ def matmul_tma_persistent_ws_kernel( #
@pytest.mark.parametrize("BLOCK_SIZE_M", [128])
@pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
@pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
- @pytest.mark.parametrize("num_stages", [2, 3, 4])
+ @pytest.mark.parametrize("num_stages", [2, 3])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
- @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+ @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_warp_specialize_tma_matmul_persistent(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps,
                                               use_fp8):
    if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8):
@@ -413,12 +425,12 @@ def attention_inner_loop_kernel( #
@pytest.mark.parametrize("M, N", [(8192, 8192), (1024, 1024)])
@pytest.mark.parametrize("BLOCK_M", [64, 128])
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
- @pytest.mark.parametrize("num_stages", [2, 3, 4])
+ @pytest.mark.parametrize("num_stages", [2, 3])
@pytest.mark.parametrize("disable_acc_multibuf", [False, True])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
- @pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+ @pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_warp_specialize_attention_forward(M, N, BLOCK_M, HEAD_DIM, num_stages, disable_acc_multibuf, num_warps,
                                           use_fp8):
    if BLOCK_M == 128 and HEAD_DIM == 128 and not use_fp8: