@@ -332,6 +332,7 @@ def grouped_matmul_kernel(
332332 GROUP_SIZE_M : tl .constexpr ,
333333 MUL_ROUTED_WEIGHT : tl .constexpr = False ,
334334 NEED_K_MASK : tl .constexpr = True ,
335+ NEED_TRANS : tl .constexpr = False ,
335336):
336337 pid = tl .program_id (0 )
337338
@@ -387,7 +388,7 @@ def grouped_matmul_kernel(
387388 b_scale = tl .load (weight_scale_ptr + expert_id , eviction_policy = "evict_last" )
388389 ab_scale = a_scale * b_scale
389390
390- if use_fp8_w8a8 :
391+ if NEED_TRANS :
391392 a_ptrs = token_ptr + (a_m_index // topk_num )[None , :] * token_stride_0 + offs_k [:, None ]
392393 b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k [None , :] + offs_bn [:, None ] * weight_stride_1
393394 accumulator = tl .zeros ((BLOCK_SIZE_N , BLOCK_SIZE_M ), dtype = tl .float32 )
@@ -401,7 +402,7 @@ def grouped_matmul_kernel(
401402 # tl.multiple_of(a_ptrs, [16, 16])
402403 # tl.multiple_of(b_ptrs, [16, 16])
403404
404- if use_fp8_w8a8 :
405+ if NEED_TRANS :
405406 if NEED_K_MASK :
406407 a = tl .load (a_ptrs , mask = (token_mask [None , :]) & (offs_k [:, None ] < k ), other = 0.0 )
407408 b = tl .load (b_ptrs , mask = (offs_k [None , :] < k ), other = 0.0 )
@@ -421,21 +422,27 @@ def grouped_matmul_kernel(
421422 offs_ks = step_k * BLOCK_SIZE_K // block_size_k
422423 a_scale = tl .load (a_scale_ptrs + offs_ks , mask = token_mask , other = 0.0 )
423424 b_scale = tl .load (b_scale_ptrs + offs_ks * weight_scale_stride2 )
424- accumulator += tl .dot (b , a ) * b_scale [:, None ] * a_scale [None , :]
425+ if NEED_TRANS :
426+ accumulator += tl .dot (b , a ) * b_scale [:, None ] * a_scale [None , :]
427+ else :
428+ accumulator += tl .dot (a , b ) * a_scale [:, None ] * b_scale [None , :]
425429 else :
426- accumulator = tl .dot (b , a , acc = accumulator )
430+ if NEED_TRANS :
431+ accumulator = tl .dot (b , a , acc = accumulator )
432+ else :
433+ accumulator = tl .dot (a , b , acc = accumulator )
427434 else :
428435 accumulator += tl .dot (a , b )
429436
430437 a_ptrs += BLOCK_SIZE_K
431438 b_ptrs += BLOCK_SIZE_K
432439 offs_k += BLOCK_SIZE_K
433440
441+ if NEED_TRANS :
442+ accumulator = accumulator .T
443+
434444 if use_fp8_w8a8 :
435- if block_size_k > 0 and block_size_n > 0 :
436- accumulator = accumulator .T
437- else :
438- accumulator = accumulator .T
445+ if not (block_size_k > 0 and block_size_n > 0 ):
439446 accumulator *= ab_scale
440447
441448 if MUL_ROUTED_WEIGHT :
@@ -478,13 +485,15 @@ def _get_grouped_matmul_configs():
478485 "GROUP_SIZE_M" : gm ,
479486 "num_warps" : nw ,
480487 "num_stages" : ns ,
488+ "need_trans" : need_trans ,
481489 }
482490 for ns in [1 , 2 , 3 , 4 , 5 ]
483491 for gm in [1 , 2 , 4 , 8 ]
484492 for nw in [2 , 4 , 8 ]
485493 for bm in [16 , 32 , 64 , 128 ]
486494 for bn in [16 , 32 , 64 , 128 ]
487495 for bk in [16 , 32 , 64 , 128 ]
496+ for need_trans in [True , False ]
488497 ]
489498
490499
@@ -559,6 +568,7 @@ def grouped_matmul(
559568 GROUP_SIZE_M = run_config ["GROUP_SIZE_M" ]
560569 num_warps = run_config ["num_warps" ]
561570 num_stages = run_config ["num_stages" ]
561571 NEED_TRANS = run_config .get ("need_trans" , False )
562572
563573 if block_size_k != 0 :
564574 # 如果使用了 block wise 量化,分块大小不能超过 block size
@@ -638,6 +648,7 @@ def grouped_matmul(
638648 GROUP_SIZE_M = GROUP_SIZE_M ,
639649 MUL_ROUTED_WEIGHT = mul_routed_weight ,
640650 NEED_K_MASK = NEED_K_MASK ,
651+ NEED_TRANS = NEED_TRANS ,
641652 num_warps = num_warps ,
642653 num_stages = num_stages ,
643654 )
0 commit comments