@@ -368,13 +368,6 @@ def grouped_matmul_kernel(
368368 mask = token_mask ,
369369 other = 0 ,
370370 )
371- if MUL_ROUTED_WEIGHT :
372- a_m_scale = tl .load (
373- expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am ,
374- mask = token_mask ,
375- other = 0.0 ,
376- )
377-
378371 offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )) % n
379372 offs_k = tl .arange (0 , BLOCK_SIZE_K )
380373
@@ -404,14 +397,18 @@ def grouped_matmul_kernel(
404397
405398 if NEED_TRANS :
406399 if NEED_K_MASK :
407- a = tl .load (a_ptrs , mask = (token_mask [None , :]) & (offs_k [:, None ] < k ), other = 0.0 )
400+ a = tl .load (
401+ a_ptrs , mask = (token_mask [None , :]) & (offs_k [:, None ] < k - step_k * BLOCK_SIZE_K ), other = 0.0
402+ )
408403 b = tl .load (b_ptrs , mask = (offs_k [None , :] < k ), other = 0.0 )
409404 else :
410405 a = tl .load (a_ptrs , mask = (token_mask [None , :]), other = 0.0 )
411406 b = tl .load (b_ptrs )
412407 else :
413408 if NEED_K_MASK :
414- a = tl .load (a_ptrs , mask = (token_mask [:, None ]) & (offs_k [None , :] < k ), other = 0.0 )
409+ a = tl .load (
410+ a_ptrs , mask = (token_mask [:, None ]) & (offs_k [None , :] < k - step_k * BLOCK_SIZE_K ), other = 0.0
411+ )
415412 b = tl .load (b_ptrs , mask = (offs_k [:, None ] < k ), other = 0.0 )
416413 else :
417414 a = tl .load (a_ptrs , mask = (token_mask [:, None ]), other = 0.0 )
@@ -436,7 +433,6 @@ def grouped_matmul_kernel(
436433
437434 a_ptrs += BLOCK_SIZE_K
438435 b_ptrs += BLOCK_SIZE_K
439- offs_k += BLOCK_SIZE_K
440436
441437 if NEED_TRANS :
442438 accumulator = accumulator .T
@@ -446,6 +442,11 @@ def grouped_matmul_kernel(
446442 accumulator *= ab_scale
447443
448444 if MUL_ROUTED_WEIGHT :
445+ a_m_scale = tl .load (
446+ expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am ,
447+ mask = token_mask ,
448+ other = 0.0 ,
449+ )
449450 accumulator *= a_m_scale [:, None ]
450451
451452 c = accumulator .to (compute_type )
0 commit comments