@@ -304,6 +304,7 @@ def matmul_tma_persistent_ws_kernel(  #
         GROUP_SIZE_M: tl.constexpr,  #
         NUM_SMS: tl.constexpr,  #
         USE_FP8: tl.constexpr,  #
+        FLATTEN: tl.constexpr,  #
 ):
     a_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, K], strides=[a_stride0, a_stride1],
                                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K])
@@ -318,7 +319,8 @@ def matmul_tma_persistent_ws_kernel(  #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n
 
-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True, num_stages=num_stages):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=True,
+                            num_stages=num_stages):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_n, num_pid_m, GROUP_SIZE_M)
 
         off_am = pid_m * BLOCK_SIZE_M
@@ -342,7 +344,7 @@ def matmul_tma_persistent_ws_kernel(  #
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
 def test_warp_specialize_tma_matmul_persistent(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps,
                                                use_fp8):
     if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8):
@@ -371,10 +373,18 @@ def grid(META):
 
     kernel = matmul_tma_persistent_ws_kernel[grid](A, B, C, *A.stride(), *B.stride(), *C.stride(), M, N, K, num_stages,
                                                    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, NUM_SMS,
-                                                   num_warps=num_warps, USE_FP8=use_fp8)
+                                                   num_warps=num_warps, USE_FP8=use_fp8, FLATTEN=is_blackwell())
     ttgir = kernel.asm["ttgir"]
-    assert "ttng.tc_gen5_mma" in ttgir
-    assert "ttg.warp_specialize" in ttgir
+    if is_blackwell():
+        assert "ttng.tc_gen5_mma" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    else:
+        assert "ttng.warp_group_dot" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    if is_hopper() and num_warps == 8:
+        assert "ttg.warp_specialize" not in ttgir
+    else:
+        assert "ttg.warp_specialize" in ttgir
 
     ref_out = torch.empty((M, N), dtype=dtype, device=device)
     cublas.matmul(A, B, ref_out)
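
Note on the architecture gates used above: is_hip, is_hopper, is_blackwell, and is_hopper_or_blackwell are helpers defined elsewhere in the test suite. As a rough illustration only, a minimal sketch based on CUDA compute capability could look like the following; the repository's actual definitions may differ (for example in how consumer Blackwell parts or ROCm devices are detected):

import torch

def _cuda_capability_major():
    # Major compute capability of the current CUDA device (0 if CUDA is unavailable).
    return torch.cuda.get_device_capability()[0] if torch.cuda.is_available() else 0

def is_hopper():
    # Hopper GPUs (e.g. H100) report compute capability 9.x.
    return _cuda_capability_major() == 9

def is_blackwell():
    # Data-center Blackwell GPUs report compute capability 10.x (assumption: this
    # sketch ignores consumer parts that report a different major version).
    return _cuda_capability_major() == 10

def is_hopper_or_blackwell():
    return is_hopper() or is_blackwell()

With helpers along these lines, the launch above passes FLATTEN=is_blackwell(), so the persistent loop keeps flatten=True on Blackwell, where the TTGIR is expected to contain ttng.tc_gen5_mma, and runs without flattening on Hopper, where the checks look for ttng.warp_group_dot instead.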