@@ -61,17 +61,19 @@ def save_config(cls, N: int, K: int, out_dtype: str, config_json: Dict[int, Dict
6161
6262
@triton.jit
def grouped_launch(pid, m_block_num, n_block_num, group_m: tl.constexpr):
    """Map a linear program id to a ``(pid_m, pid_n)`` block coordinate.

    Program ids are split into consecutive groups of
    ``group_m * n_block_num`` ids. Each group covers up to ``group_m`` block
    rows across all ``n_block_num`` block columns; the last group is clipped
    to the rows that remain. Inside a group the ids advance one block column
    at a time, and the row direction flips on every other column (zigzag).

    Args:
        pid: Linear program id.
        m_block_num: Number of blocks along M.
        n_block_num: Number of blocks along N.
        group_m: Maximum number of block rows per group (compile-time
            constant).

    Returns:
        Tuple ``(pid_m, pid_n)`` of block indices along M and N.
    """
    programs_per_group = group_m * n_block_num
    group_idx = pid // programs_per_group
    offset_in_group = pid % programs_per_group

    row_base = group_idx * group_m
    # The trailing group may hold fewer than group_m rows.
    rows_in_group = tl.minimum(m_block_num - row_base, group_m)

    pid_n = offset_in_group // rows_in_group
    row_step = offset_in_group % rows_in_group

    # Swizzle pattern: zigzag traversal.
    # `reverse` is 0 on even columns of the group and 1 on odd ones;
    # `direction` is the matching +1 / -1 sign applied to the row step.
    reverse = pid_n % 2
    direction = 1 - 2 * reverse
    pid_m = row_base + reverse * (rows_in_group - 1) + direction * row_step

    return pid_m, pid_n
7779
@@ -89,6 +91,8 @@ def _scaled_mm_per_token(
8991 M ,
9092 N ,
9193 K ,
94+ m_block_num ,
95+ n_block_num ,
9296 stride_am ,
9397 stride_ak ,
9498 stride_bk ,
@@ -105,7 +109,7 @@ def _scaled_mm_per_token(
105109 GROUP_M : tl .constexpr ,
106110):
107111 pid = tl .program_id (0 )
108- pid_m , pid_n = grouped_launch (pid , M , N , BLOCK_M , BLOCK_N , GROUP_M )
112+ pid_m , pid_n = grouped_launch (pid , m_block_num , n_block_num , GROUP_M )
109113
110114 start_m = pid_m * BLOCK_M
111115 start_n = pid_n * BLOCK_N
@@ -289,6 +293,8 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
289293 M = M ,
290294 N = N ,
291295 K = K ,
296+ m_block_num = triton .cdiv (M , BLOCK_M ),
297+ n_block_num = triton .cdiv (N , BLOCK_N ),
292298 stride_am = A .stride (0 ),
293299 stride_ak = A .stride (1 ),
294300 stride_bk = B .stride (0 ),
0 commit comments