import triton
import triton.language as tl

-from triton._internal_testing import is_hip
+from triton._internal_testing import is_hip, is_hopper, is_blackwell
from triton.tools.tensor_descriptor import TensorDescriptor

-if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10:
+if not is_hip() and torch.cuda.is_available() and torch.cuda.get_device_capability()[0] in [9, 10]:
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None


+def is_hopper_or_blackwell():
+    return is_hopper() or is_blackwell()
+
+
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):
    ir = """
    tt.func @kernel(%arg0: !tt.ptr<i32>) {
@@ -51,7 +55,7 @@ def test_warp_specialize_basic_ir(tmp_path: pathlib.Path):


@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_warp_specialize_tmem_ir(tmp_path: pathlib.Path):
    ir = """
    #blocked = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
@@ -119,7 +123,7 @@ def test_warp_specialize_tmem_ir(tmp_path: pathlib.Path):


@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
def test_warpgroup_reduction(tmp_path: pathlib.Path):

    def template(i, num_warps, in_ptr, out_ptr):
@@ -242,11 +246,11 @@ def exceeds_smem_capacity(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, use_fp8):
@pytest.mark.parametrize("BLOCK_SIZE_M", [128])
@pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
@pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
-@pytest.mark.parametrize("num_stages", [2, 3, 4])
+@pytest.mark.parametrize("num_stages", [2, 3])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
def test_warp_specialize_tma_matmul(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps, use_fp8):
    if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8=use_fp8):
        pytest.skip("uses too much shared memory")
@@ -270,8 +274,16 @@ def alloc_fn(size, align, stream):
                 BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, num_warps=num_warps,
                 USE_FP8=use_fp8)
    ttgir = kernel.asm["ttgir"]
-    assert "ttng.tc_gen5_mma" in ttgir
-    assert "ttg.warp_specialize" in ttgir
+    if is_blackwell():
+        assert "ttng.tc_gen5_mma" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    else:
+        assert "ttng.warp_group_dot" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    if is_hopper() and num_warps == 8:
+        assert "ttg.warp_specialize" not in ttgir
+    else:
+        assert "ttg.warp_specialize" in ttgir

    ref_out = torch.empty((M, N), dtype=dtype, device=device)
    cublas.matmul(A, B, ref_out)
@@ -326,11 +338,11 @@ def matmul_tma_persistent_ws_kernel( #
@pytest.mark.parametrize("BLOCK_SIZE_M", [128])
@pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
@pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
-@pytest.mark.parametrize("num_stages", [2, 3, 4])
+@pytest.mark.parametrize("num_stages", [2, 3])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_warp_specialize_tma_matmul_persistent(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps,
                                               use_fp8):
    if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8):
@@ -413,12 +425,12 @@ def attention_inner_loop_kernel( #
@pytest.mark.parametrize("M, N", [(8192, 8192), (1024, 1024)])
@pytest.mark.parametrize("BLOCK_M", [64, 128])
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
-@pytest.mark.parametrize("num_stages", [2, 3, 4])
+@pytest.mark.parametrize("num_stages", [2, 3])
@pytest.mark.parametrize("disable_acc_multibuf", [False, True])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] != 10, reason="Requires compute capability == 10")
+@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_warp_specialize_attention_forward(M, N, BLOCK_M, HEAD_DIM, num_stages, disable_acc_multibuf, num_warps,
                                           use_fp8):
    if BLOCK_M == 128 and HEAD_DIM == 128 and not use_fp8:
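
For reference, the is_hopper() and is_blackwell() helpers are imported from triton._internal_testing, whose implementation is not shown in this diff. A minimal sketch of equivalent checks, assuming they reduce to the device's major compute capability (Hopper reports 9.x, Blackwell 10.x); the _sketch names below are hypothetical and not the library's API:

    import torch

    def _cc_major_sketch():
        # Major compute capability of the current CUDA device, or None when no GPU is present.
        return torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else None

    def is_hopper_sketch():
        # Hopper (e.g. H100) reports compute capability 9.x.
        return _cc_major_sketch() == 9

    def is_blackwell_sketch():
        # Blackwell reports compute capability 10.x.
        return _cc_major_sketch() == 10

Under that assumption, a guard like "not is_hopper_or_blackwell()" skips any device whose major capability is neither 9 nor 10, which lines up with the module-level "in [9, 10]" check used before initializing cuBLAS.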