@@ -481,14 +481,20 @@ def grouped_matmul_kernel(
481481
482482 if use_fp8_w8a8 :
483483 if block_size_k > 0 and block_size_n > 0 :
484+ assert BLOCK_SIZE_K <= block_size_k
484485 token_scale_stride0 = token_stride_0 // block_size_k
485486 if TOKEN_INPUT_USE_TMA :
486487 assert MUL_ROUTED_WEIGHT is True
487488 a_scale_ptrs = token_scale_ptr + (token_start_index + tl .arange (0 , BLOCK_SIZE_M )) * token_scale_stride0
488489 else :
489490 a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num ) * token_scale_stride0
490491
491- offs_bsn = offs_bn // block_size_n
492+ if BLOCK_SIZE_N > block_size_n :
493+ offs_bsn = offs_bn // block_size_n
494+ else :
495+ # single b scale
496+ offs_bsn = (tile_n_idx * BLOCK_SIZE_N ) // block_size_n
497+
492498 b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1
493499 else :
494500 a_scale = tl .load (token_scale_ptr , eviction_policy = "evict_last" )
@@ -556,9 +562,16 @@ def grouped_matmul_kernel(
556562 a_scale = tl .load (a_scale_ptrs + offs_ks , mask = token_mask , other = 0.0 )
557563 b_scale = tl .load (b_scale_ptrs + offs_ks * weight_scale_stride2 )
558564 if NEED_TRANS :
559- accumulator += tl .dot (b , a ) * b_scale [:, None ] * a_scale [None , :]
565+ if BLOCK_SIZE_N > block_size_n :
566+ accumulator += tl .dot (b , a ) * b_scale [:, None ] * a_scale [None , :]
567+ else :
568+ # single b scale
569+ accumulator += tl .dot (b , a ) * (a_scale [None , :] * b_scale )
560570 else :
561- accumulator += tl .dot (a , b ) * a_scale [:, None ] * b_scale [None , :]
571+ if BLOCK_SIZE_N > block_size_n :
572+ accumulator += tl .dot (a , b ) * a_scale [:, None ] * b_scale [None , :]
573+ else :
574+ accumulator += tl .dot (a , b ) * (a_scale [:, None ] * b_scale )
562575 else :
563576 if NEED_TRANS :
564577 accumulator = tl .dot (b , a , acc = accumulator )
0 commit comments