
Commit 94a8c38: add NEED_TRANS
1 parent 237ae00

File tree: 1 file changed, +19 -8 lines


lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 19 additions & 8 deletions
@@ -332,6 +332,7 @@ def grouped_matmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr = False,
     NEED_K_MASK: tl.constexpr = True,
+    NEED_TRANS: tl.constexpr = False,
 ):
     pid = tl.program_id(0)

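NEED_TRANS is declared as tl.constexpr, so Triton bakes its value in at compile time: each setting produces a separately specialized kernel, and the untaken branch is eliminated rather than tested at runtime. A minimal standalone sketch of that pattern (hypothetical _demo_kernel, not code from this repo):

import torch
import triton
import triton.language as tl

@triton.jit
def _demo_kernel(x_ptr, y_ptr, n, NEED_TRANS: tl.constexpr, BLOCK: tl.constexpr):
    # NEED_TRANS is a compile-time constant: Triton compiles one binary
    # per value and removes the dead branch entirely.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if NEED_TRANS:
        x = x * 2.0
    tl.store(y_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.empty_like(x)
_demo_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), NEED_TRANS=False, BLOCK=256)

Because each value of the flag triggers a fresh compilation, NEED_TRANS can be searched as an autotuning axis, which is exactly what the config change further down does.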
@@ -387,7 +388,7 @@ def grouped_matmul_kernel(
         b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last")
         ab_scale = a_scale * b_scale

-    if use_fp8_w8a8:
+    if NEED_TRANS:
         a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
         b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
         accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
@@ -401,7 +402,7 @@ def grouped_matmul_kernel(
         # tl.multiple_of(a_ptrs, [16, 16])
         # tl.multiple_of(b_ptrs, [16, 16])

-        if use_fp8_w8a8:
+        if NEED_TRANS:
             if NEED_K_MASK:
                 a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k), other=0.0)
                 b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0)
@@ -421,21 +422,27 @@ def grouped_matmul_kernel(
                 offs_ks = step_k * BLOCK_SIZE_K // block_size_k
                 a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0)
                 b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
-                accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                if NEED_TRANS:
+                    accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
             else:
-                accumulator = tl.dot(b, a, acc=accumulator)
+                if NEED_TRANS:
+                    accumulator = tl.dot(b, a, acc=accumulator)
+                else:
+                    accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)

         a_ptrs += BLOCK_SIZE_K
         b_ptrs += BLOCK_SIZE_K
         offs_k += BLOCK_SIZE_K

+    if NEED_TRANS:
+        accumulator = accumulator.T
+
     if use_fp8_w8a8:
-        if block_size_k > 0 and block_size_n > 0:
-            accumulator = accumulator.T
-        else:
-            accumulator = accumulator.T
+        if not (block_size_k > 0 and block_size_n > 0):
             accumulator *= ab_scale

     if MUL_ROUTED_WEIGHT:
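Why the two orientations agree: in the transposed path a is loaded as a (K, M) tile and b as an (N, K) tile, so tl.dot(b, a) accumulates an (N, M) block; transposing it once after the K loop reproduces the (M, N) block that tl.dot(a, b) yields on the untransposed layout, with the scale broadcasts swapping sides accordingly. A quick NumPy check of the identity (illustrative only, not from the repo):

import numpy as np

# (B @ A).T == A.T @ B.T: the NEED_TRANS accumulator, transposed once
# after the K loop, matches the direct-orientation result.
K, M, N = 64, 16, 32
a_km = np.random.randn(K, M).astype(np.float32)  # A tile in transposed layout
b_nk = np.random.randn(N, K).astype(np.float32)  # B tile in transposed layout

acc_trans = (b_nk @ a_km).T    # NEED_TRANS path: dot(b, a), then final .T
acc_direct = a_km.T @ b_nk.T   # NEED_TRANS=False path: dot(a, b)
assert np.allclose(acc_trans, acc_direct, atol=1e-4)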
@@ -478,13 +485,15 @@ def _get_grouped_matmul_configs():
             "GROUP_SIZE_M": gm,
             "num_warps": nw,
             "num_stages": ns,
+            "need_trans": need_trans,
         }
         for ns in [1, 2, 3, 4, 5]
         for gm in [1, 2, 4, 8]
         for nw in [2, 4, 8]
         for bm in [16, 32, 64, 128]
         for bn in [16, 32, 64, 128]
         for bk in [16, 32, 64, 128]
+        for need_trans in [True, False]
     ]
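Adding need_trans as a boolean axis doubles the autotuning grid. Assuming the comprehension above also maps bm/bn/bk to the BLOCK_SIZE_* keys (those entries sit above the hunk), the candidate count works out as follows (a back-of-the-envelope sketch, not repo code):

from itertools import product

ns_opts = [1, 2, 3, 4, 5]
gm_opts = [1, 2, 4, 8]
nw_opts = [2, 4, 8]
block_opts = [16, 32, 64, 128]
trans_opts = [True, False]

# 5 * 4 * 3 * 4**3 * 2 = 7680 candidate configs, up from 3840
n_configs = sum(1 for _ in product(ns_opts, gm_opts, nw_opts,
                                   block_opts, block_opts, block_opts,
                                   trans_opts))
print(n_configs)  # 7680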
@@ -559,6 +568,7 @@ def grouped_matmul(
     GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
     num_warps = run_config["num_warps"]
     num_stages = run_config["num_stages"]
+    NEED_TRANS = run_config.get("NEED_TRANS", False)

     if block_size_k != 0:
         # with block-wise quantization, the tile size must not exceed the block size
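Reading the flag with run_config.get(...) keeps it optional: a config dict saved by an older tuning run has no such key and silently falls back to False, i.e. the non-transposed dot path. Note that the generator above stores the lowercase key "need_trans"; if no key normalization happens before this lookup, those configs would also take the default. A hypothetical illustration:

# Hypothetical older config with no NEED_TRANS key: the lookup above
# defaults to False, selecting the non-transposed tl.dot orientation.
old_cfg = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "num_warps": 4, "num_stages": 3}
assert old_cfg.get("NEED_TRANS", False) is False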
@@ -638,6 +648,7 @@ def grouped_matmul(
         GROUP_SIZE_M=GROUP_SIZE_M,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         NEED_K_MASK=NEED_K_MASK,
+        NEED_TRANS=NEED_TRANS,
         num_warps=num_warps,
         num_stages=num_stages,
     )
