Commit 3f3a23c
Author: wangzaijun
Parent: 692bfbd

fix kernel, remove dead code

1 file changed, 23 insertions(+), 15 deletions(-)

lightllm/common/fused_moe/grouped_fused_moe.py

@@ -457,22 +457,23 @@ def grouped_matmul_kernel(
     tile_m_idx = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 1)
     tile_n_idx = pid_n
 
-    if OUT_SORTED or TOKEN_INPUT_USE_TMA:
-        assert OUT_SORTED and TOKEN_INPUT_USE_TMA is False
-        # get token start index in inputs
-        token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2)
+    # get token start index in inputs
+    token_start_index = tl.load(mblocks_to_tuple_info + pid_m * mblocks_to_tuple_info_stride_0 + 2)
 
     # get the gemm size of the current problem
     cur_m = tl.load(expert_to_token_num + expert_id)
 
     # do regular gemm here
     offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     token_mask = offs_am < cur_m
-    a_m_index = tl.load(
-        expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
-        mask=token_mask,
-        other=0,
-    )
+
+    if not OUT_SORTED or not TOKEN_INPUT_USE_TMA:
+        a_m_index = tl.load(
+            expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
+            mask=token_mask,
+            other=0,
+        )
+
     offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
     offs_k = tl.arange(0, BLOCK_SIZE_K)

@@ -493,12 +494,17 @@ def grouped_matmul_kernel(
     ab_scale = a_scale * b_scale
 
     if NEED_TRANS:
-        a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
-        b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
+        if not TOKEN_INPUT_USE_TMA:
+            a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
+        if not WEIGHT_USE_TMA:
+            b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
         accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
     else:
-        a_ptrs = token_ptr + (a_m_index // topk_num)[:, None] * token_stride_0 + offs_k[None, :]
-        b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1
+        if not TOKEN_INPUT_USE_TMA:
+            a_ptrs = token_ptr + (a_m_index // topk_num)[:, None] * token_stride_0 + offs_k[None, :]
+        if not WEIGHT_USE_TMA:
+            b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[:, None] + offs_bn[None, :] * weight_stride_1
+
         accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
     for k_start in range(0, k, BLOCK_SIZE_K):

@@ -559,8 +565,10 @@ def grouped_matmul_kernel(
         else:
             accumulator += tl.dot(a, b)
 
-        a_ptrs += BLOCK_SIZE_K
-        b_ptrs += BLOCK_SIZE_K
+        if not TOKEN_INPUT_USE_TMA:
+            a_ptrs += BLOCK_SIZE_K
+        if not WEIGHT_USE_TMA:
+            b_ptrs += BLOCK_SIZE_K
 
     if NEED_TRANS:
         accumulator = accumulator.T
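
The pattern running through this change is compile-time specialization: flags such as TOKEN_INPUT_USE_TMA and WEIGHT_USE_TMA are presumably tl.constexpr parameters of the kernel, so wrapping the manual pointer setup and the per-iteration pointer advance in "if not ...:" prunes that code entirely from the kernel variants that use TMA descriptors, rather than adding a runtime branch. Below is a minimal, self-contained sketch of that mechanism; the toy _copy_kernel and its USE_MASK flag are illustrative stand-ins and are not part of lightllm.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(
    src_ptr,
    dst_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    USE_MASK: tl.constexpr,  # hypothetical flag, standing in for TOKEN_INPUT_USE_TMA / WEIGHT_USE_TMA
):
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if USE_MASK:
        # compiled only into specializations built with USE_MASK=True
        mask = offs < n_elements
        x = tl.load(src_ptr + offs, mask=mask, other=0.0)
        tl.store(dst_ptr + offs, x, mask=mask)
    else:
        # dead code in the USE_MASK=True specialization: never emitted
        x = tl.load(src_ptr + offs)
        tl.store(dst_ptr + offs, x)


src = torch.arange(1000, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
grid = (triton.cdiv(src.numel(), 256),)
_copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=256, USE_MASK=True)
assert torch.equal(src, dst)

Because USE_MASK is a constexpr, Triton compiles a separate specialization per flag value and the untaken branch never appears in the generated code, which is why the guarded a_ptrs/b_ptrs arithmetic in the diff above costs nothing in the TMA-enabled variants.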
