Skip to content

Commit b81de0e

Browse files
author
wangzaijun
committed
fix kernel
1 parent bb1b3f2 commit b81de0e

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -481,14 +481,20 @@ def grouped_matmul_kernel(
481481

482482
if use_fp8_w8a8:
483483
if block_size_k > 0 and block_size_n > 0:
484+
assert BLOCK_SIZE_K <= block_size_k
484485
token_scale_stride0 = token_stride_0 // block_size_k
485486
if TOKEN_INPUT_USE_TMA:
486487
assert MUL_ROUTED_WEIGHT is True
487488
a_scale_ptrs = token_scale_ptr + (token_start_index + tl.arange(0, BLOCK_SIZE_M)) * token_scale_stride0
488489
else:
489490
a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * token_scale_stride0
490491

491-
offs_bsn = offs_bn // block_size_n
492+
if BLOCK_SIZE_N > block_size_n:
493+
offs_bsn = offs_bn // block_size_n
494+
else:
495+
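# single b scale: BLOCK_SIZE_N <= block_size_n, so one weight-scale index is derived from the tile's first column (assumes the N tile does not straddle two scale blocks -- TODO confirm block_size_n % BLOCK_SIZE_N == 0)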
496+
offs_bsn = (tile_n_idx * BLOCK_SIZE_N) // block_size_n
497+
492498
b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1
493499
else:
494500
a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last")
@@ -556,9 +562,16 @@ def grouped_matmul_kernel(
556562
a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0)
557563
b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
558564
if NEED_TRANS:
559-
accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
565+
if BLOCK_SIZE_N > block_size_n:
566+
accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
567+
else:
568+
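# single b scale: b_scale is applied un-indexed here, i.e. as one weight scale for the whole tile, combined with the per-token a_scale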
569+
accumulator += tl.dot(b, a) * (a_scale[None, :] * b_scale)
560570
else:
561-
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
571+
if BLOCK_SIZE_N > block_size_n:
572+
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
573+
else:
574+
accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
562575
else:
563576
if NEED_TRANS:
564577
accumulator = tl.dot(b, a, acc=accumulator)

0 commit comments

Comments
 (0)