Skip to content

Commit d1cd40b

Browse files
authored
Add EVEN_K flag: if K is divisible by BLOCK_SIZE_K, the K-dimension masking is disabled in the GEMM loads
1 parent a15f90f commit d1cd40b

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

python/perf-kernels/fused_moe/moe-gemm.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ def moe_gemm_kernel(
3939
EM: tl.constexpr,
4040
N: tl.constexpr,
4141
K: tl.constexpr,
42+
EVEN_K: tl.constexpr,
4243
MUL_ROUTED_WEIGHT: tl.constexpr,
4344
BLOCK_SIZE_M: tl.constexpr,
4445
BLOCK_SIZE_N: tl.constexpr,
@@ -97,8 +98,12 @@ def moe_gemm_kernel(
9798

9899
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
99100
# Masking ensures we don't load from invalid tokens or indices
100-
a = tl.load(a_ptrs, mask=(token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)), other=0.0)
101-
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), other=0.0)
101+
if EVEN_K:
102+
a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0)
103+
b = tl.load(b_ptrs)
104+
else:
105+
a = tl.load(a_ptrs, mask=(token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)), other=0.0)
106+
b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), other=0.0)
102107

103108
accumulator = tl.dot(a, b, acc=accumulator)
104109
a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -292,8 +297,10 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to
292297
_, N, K = b.shape
293298
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
294299

300+
EVEN_K = K % config["BLOCK_SIZE_K"] == 0
301+
295302
moe_gemm_kernel[grid](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1),
296-
c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N, K,
303+
c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N, K, EVEN_K,
297304
MUL_ROUTED_WEIGHT=topk_weights is not None, **config)
298305
return c
299306

@@ -327,6 +334,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool
327334
(64, 14336, 4096, 2, 8),
328335
(16, 14336, 1, 2, 4),
329336
(1, 14336, 128, 2, 4),
337+
(3, 14336, 128, 2, 4),
330338
(16, 14336, 128, 1, 4),
331339
(16, 14336, 128, 1, 1),
332340
(64, 7186, 128, 2, 8),

0 commit comments

Comments
 (0)