@@ -39,6 +39,7 @@ def moe_gemm_kernel(
3939 EM : tl .constexpr ,
4040 N : tl .constexpr ,
4141 K : tl .constexpr ,
42+ EVEN_K : tl .constexpr ,
4243 MUL_ROUTED_WEIGHT : tl .constexpr ,
4344 BLOCK_SIZE_M : tl .constexpr ,
4445 BLOCK_SIZE_N : tl .constexpr ,
@@ -97,8 +98,12 @@ def moe_gemm_kernel(
9798
9899 for k in range (0 , tl .cdiv (K , BLOCK_SIZE_K )):
99100 # Masking ensures we don't load from invalid tokens or indices
100- a = tl .load (a_ptrs , mask = (token_mask [:, None ] & (offs_k [None , :] < K - k * BLOCK_SIZE_K )), other = 0.0 )
101- b = tl .load (b_ptrs , mask = (offs_k [:, None ] < K - k * BLOCK_SIZE_K ), other = 0.0 )
101+ if EVEN_K :
102+ a = tl .load (a_ptrs , mask = (token_mask [:, None ]), other = 0.0 )
103+ b = tl .load (b_ptrs )
104+ else :
105+ a = tl .load (a_ptrs , mask = (token_mask [:, None ] & (offs_k [None , :] < K - k * BLOCK_SIZE_K )), other = 0.0 )
106+ b = tl .load (b_ptrs , mask = (offs_k [:, None ] < K - k * BLOCK_SIZE_K ), other = 0.0 )
102107
103108 accumulator = tl .dot (a , b , acc = accumulator )
104109 a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -292,8 +297,10 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to
292297 _ , N , K = b .shape
293298 grid = lambda META : (triton .cdiv (EM , META ['BLOCK_SIZE_M' ]) * triton .cdiv (N , META ['BLOCK_SIZE_N' ]), )
294299
300+ EVEN_K = K % config ["BLOCK_SIZE_K" ] == 0
301+
295302 moe_gemm_kernel [grid ](a , b , c , a .stride (0 ), a .stride (1 ), b .stride (0 ), b .stride (1 ), b .stride (2 ), c .stride (1 ),
296- c .stride (2 ), top_k , topk_weights , sorted_token_ids , expert_ids , EM , N , K ,
303+ c .stride (2 ), top_k , topk_weights , sorted_token_ids , expert_ids , EM , N , K , EVEN_K ,
297304 MUL_ROUTED_WEIGHT = topk_weights is not None , ** config )
298305 return c
299306
@@ -327,6 +334,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
327334 (64 , 14336 , 4096 , 2 , 8 ),
328335 (16 , 14336 , 1 , 2 , 4 ),
329336 (1 , 14336 , 128 , 2 , 4 ),
337+ (3 , 14336 , 128 , 2 , 4 ),
330338 (16 , 14336 , 128 , 1 , 4 ),
331339 (16 , 14336 , 128 , 1 , 1 ),
332340 (64 , 7186 , 128 , 2 , 8 ),
0 commit comments