@@ -304,6 +304,7 @@ def matmul_tma_persistent_ws_kernel(  #
         GROUP_SIZE_M: tl.constexpr,  #
         NUM_SMS: tl.constexpr,  #
         USE_FP8: tl.constexpr,  #
+        FLATTEN: tl.constexpr,  #
 ):
     a_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, K], strides=[a_stride0, a_stride1],
                                        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K])
@@ -318,7 +319,8 @@ def matmul_tma_persistent_ws_kernel(  #
     k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
     num_tiles = num_pid_m * num_pid_n

-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=True, num_stages=num_stages):
+    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=FLATTEN, warp_specialize=True,
+                            num_stages=num_stages):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_n, num_pid_m, GROUP_SIZE_M)

         off_am = pid_m * BLOCK_SIZE_M
@@ -342,7 +344,7 @@ def matmul_tma_persistent_ws_kernel(  #
 @pytest.mark.parametrize("num_warps", [4, 8])
 @pytest.mark.parametrize("use_fp8", [False, True])
 @pytest.mark.skipif(is_hip(), reason="warp specialization is not supported on hip devices")
-@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
+@pytest.mark.skipif(not is_hopper_or_blackwell(), reason="Requires Hopper or Blackwell")
 def test_warp_specialize_tma_matmul_persistent(M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps,
                                                use_fp8):
     if exceeds_smem_capacity(num_stages, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, use_fp8):
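The relaxed skip condition above relies on an is_hopper_or_blackwell() helper alongside the existing is_hip()/is_blackwell() gates. A minimal sketch of what such a gate could look like, based on compute-capability checks; the helper bodies below are assumptions for illustration and are not part of this change:

import torch

def is_hopper():
    # assumption: Hopper (sm_90) reports compute capability major == 9
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9

def is_blackwell():
    # assumption: Blackwell (sm_100) reports compute capability major == 10
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 10

def is_hopper_or_blackwell():
    return is_hopper() or is_blackwell()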
@@ -371,10 +373,18 @@ def grid(META):

     kernel = matmul_tma_persistent_ws_kernel[grid](A, B, C, *A.stride(), *B.stride(), *C.stride(), M, N, K, num_stages,
                                                    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M, NUM_SMS,
-                                                   num_warps=num_warps, USE_FP8=use_fp8)
+                                                   num_warps=num_warps, USE_FP8=use_fp8, FLATTEN=is_blackwell())
     ttgir = kernel.asm["ttgir"]
-    assert "ttng.tc_gen5_mma" in ttgir
-    assert "ttg.warp_specialize" in ttgir
+    if is_blackwell():
+        assert "ttng.tc_gen5_mma" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    else:
+        assert "ttng.warp_group_dot" in ttgir
+        assert "ttng.async_tma_copy_global_to_local" in ttgir
+    if is_hopper() and num_warps == 8:
+        assert "ttg.warp_specialize" not in ttgir
+    else:
+        assert "ttg.warp_specialize" in ttgir

     ref_out = torch.empty((M, N), dtype=dtype, device=device)
     cublas.matmul(A, B, ref_out)
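For context, the grid callable named in the last hunk header is the usual persistent-kernel launch: one program per SM, capped by the number of output tiles, so each program walks several tiles of C. A sketch of that setup, assuming the surrounding test computes NUM_SMS and the grid roughly this way (not shown in the diff; names here are illustrative):

import torch
import triton

# assumption: the test sizes the persistent launch by the SM count of device 0
NUM_SMS = torch.cuda.get_device_properties(0).multi_processor_count

def make_persistent_grid(M, N, num_sms):
    def grid(META):
        # cap the launch at the number of output tiles so small problems do not over-subscribe
        tiles = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])
        return (min(num_sms, tiles), )
    return grid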