Skip to content

Commit cb2300c

Browse files
suyoggupta and dominicshanshan
authored and committed
[None][feat] add skip condition in AutoDeploy's triton fused moe kernel (NVIDIA#8632)
Signed-off-by: Suyog Gupta <41447211+suyoggupta@users.noreply.github.com>
1 parent b0f3a44 commit cb2300c

File tree

1 file changed

+6
-0
lines changed
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe

1 file changed

+6
-0
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_moe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def fused_mlp_moe_kernel(
4141
topk_weights_ptr,
4242
sorted_token_ids_ptr,
4343
expert_ids_ptr,
44+
num_tokens_post_padded_ptr,
4445
# Matrix dimensions
4546
N,
4647
K,
@@ -84,6 +85,10 @@ def fused_mlp_moe_kernel(
8485
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
8586
pid_n = (pid % num_pid_in_group) // group_size_m
8687

88+
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
89+
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
90+
return
91+
8792
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
8893
# Bounds check: EM might not be a multiple of BLOCK_SIZE_M
8994
# so offs_token_id can exceed EM-1. Load with mask to avoid out-of-bounds.
@@ -270,6 +275,7 @@ def _grid(META):
270275
topk_weights if topk_weights is not None else C,
271276
sorted_token_ids,
272277
expert_ids,
278+
num_tokens_post_padded,
273279
B.size(1),
274280
B.size(2),
275281
EM,

0 commit comments

Comments (0)