Skip to content

Commit 239f36f

Browse files
committed
update
1 parent 94a8c38 commit 239f36f

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -368,13 +368,6 @@ def grouped_matmul_kernel(
368368
mask=token_mask,
369369
other=0,
370370
)
371-
if MUL_ROUTED_WEIGHT:
372-
a_m_scale = tl.load(
373-
expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
374-
mask=token_mask,
375-
other=0.0,
376-
)
377-
378371
offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
379372
offs_k = tl.arange(0, BLOCK_SIZE_K)
380373

@@ -404,14 +397,18 @@ def grouped_matmul_kernel(
404397

405398
if NEED_TRANS:
406399
if NEED_K_MASK:
407-
a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k), other=0.0)
400+
a = tl.load(
401+
a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - step_k * BLOCK_SIZE_K), other=0.0
402+
)
408403
b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0)
409404
else:
410405
a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0)
411406
b = tl.load(b_ptrs)
412407
else:
413408
if NEED_K_MASK:
414-
a = tl.load(a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k), other=0.0)
409+
a = tl.load(
410+
a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - step_k * BLOCK_SIZE_K), other=0.0
411+
)
415412
b = tl.load(b_ptrs, mask=(offs_k[:, None] < k), other=0.0)
416413
else:
417414
a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0)
@@ -436,7 +433,6 @@ def grouped_matmul_kernel(
436433

437434
a_ptrs += BLOCK_SIZE_K
438435
b_ptrs += BLOCK_SIZE_K
439-
offs_k += BLOCK_SIZE_K
440436

441437
if NEED_TRANS:
442438
accumulator = accumulator.T
@@ -446,6 +442,11 @@ def grouped_matmul_kernel(
446442
accumulator *= ab_scale
447443

448444
if MUL_ROUTED_WEIGHT:
445+
a_m_scale = tl.load(
446+
expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
447+
mask=token_mask,
448+
other=0.0,
449+
)
449450
accumulator *= a_m_scale[:, None]
450451

451452
c = accumulator.to(compute_type)

lightllm/common/fused_moe/moe_kernel_configs.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ def try_to_get_best_config(
4646
"BLOCK_SIZE_N": 32,
4747
"BLOCK_SIZE_K": 64,
4848
"GROUP_SIZE_M": 1,
49+
"NEED_TRANS": False,
4950
"num_warps": 4,
5051
"num_stages": 1,
5152
}
@@ -55,6 +56,7 @@ def try_to_get_best_config(
5556
"BLOCK_SIZE_N": 64,
5657
"BLOCK_SIZE_K": 32,
5758
"GROUP_SIZE_M": 8,
59+
"NEED_TRANS": False,
5860
"num_warps": 4,
5961
"num_stages": 1,
6062
}

0 commit comments

Comments
 (0)