Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def fused_mlp_moe_kernel(
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
Expand Down Expand Up @@ -84,6 +85,10 @@ def fused_mlp_moe_kernel(
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return

offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
# Bounds check: EM might not be a multiple of BLOCK_SIZE_M,
# so offs_token_id can exceed EM - 1. Load with a mask to avoid out-of-bounds access.
Expand Down Expand Up @@ -270,6 +275,7 @@ def _grid(META):
topk_weights if topk_weights is not None else C,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.size(1),
B.size(2),
EM,
Expand Down