@@ -343,11 +343,13 @@ def grouped_matmul_kernel(
343343 group_id = pid // num_pid_in_group
344344 first_pid_m = group_id * GROUP_SIZE_M
345345 group_size_m = min (num_pid_m - first_pid_m , GROUP_SIZE_M )
346- in_group_index = pid % num_pid_in_group
347- back_mark = (in_group_index // group_size_m ) % 2
348- back_mark1 = - 1 * (2 * back_mark - 1 )
349- pid_m = first_pid_m + back_mark * (group_size_m - 1 ) + back_mark1 * (in_group_index % group_size_m )
346+ pid_m = first_pid_m + pid % num_pid_in_group % group_size_m
350347 pid_n = (pid % num_pid_in_group ) // group_size_m
348+ # in_group_index = pid % num_pid_in_group
349+ # back_mark = (in_group_index // group_size_m) % 2
350+ # back_mark1 = -1 * (2 * back_mark - 1)
351+ # pid_m = first_pid_m + back_mark * (group_size_m - 1) + back_mark1 * (in_group_index % group_size_m)
352+ # pid_n = (pid % num_pid_in_group) // group_size_m
351353
352354 expert_id = tl .load (mblocks_to_expert_id + pid_m )
353355
@@ -488,12 +490,12 @@ def _get_grouped_matmul_configs():
488490 "num_stages" : ns ,
489491 "NEED_TRANS" : need_trans ,
490492 }
491- for ns in [1 , 2 , 3 , 4 , 5 ]
492- for gm in [1 , 2 , 4 , 8 ]
493- for nw in [2 , 4 , 8 ]
493+ for ns in [2 , 3 , 4 , 5 ]
494+ for gm in [1 , 16 , 32 , 64 ]
495+ for nw in [4 , 8 ]
494496 for bm in [16 , 32 , 64 , 128 ]
495497 for bn in [16 , 32 , 64 , 128 ]
496- for bk in [16 , 32 , 64 , 128 ]
498+ for bk in [32 , 64 , 128 ]
497499 for need_trans in [True , False ]
498500 ]
499501
0 commit comments