File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -41,6 +41,7 @@ def fused_mlp_moe_kernel(
4141 topk_weights_ptr ,
4242 sorted_token_ids_ptr ,
4343 expert_ids_ptr ,
44+ num_tokens_post_padded_ptr ,
4445 # Matrix dimensions
4546 N ,
4647 K ,
@@ -84,6 +85,10 @@ def fused_mlp_moe_kernel(
8485 pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
8586 pid_n = (pid % num_pid_in_group ) // group_size_m
8687
88+ num_tokens_post_padded = tl .load (num_tokens_post_padded_ptr )
89+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded :
90+ return
91+
8792 offs_token_id = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M ).to (tl .int64 )
8893 # Bounds check: EM might not be a multiple of BLOCK_SIZE_M
8994 # so offs_token_id can exceed EM-1. Load with mask to avoid out-of-bounds.
@@ -270,6 +275,7 @@ def _grid(META):
270275 topk_weights if topk_weights is not None else C ,
271276 sorted_token_ids ,
272277 expert_ids ,
278+ num_tokens_post_padded ,
273279 B .size (1 ),
274280 B .size (2 ),
275281 EM ,
You can’t perform that action at this time.
0 commit comments