@@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
     int thread_idx = ThreadingPolicy::offset();
     int local_token_idx = ThreadingPolicy::token_idx();
 
-    if (local_token_idx >= local_num_tokens)
+    if (local_num_tokens == 0)
     {
-        return;
-    }
-
-    // Prepare per-policy shared-memory tiles for this token
-    extern __shared__ int smem[];
-    int* smem_topk_target_ranks;
-    int* smem_topk_send_indices;
-    int warps_per_block = blockDim.x / warpSize;
-    if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
-    {
-        int lane_id = threadIdx.x / warpSize;
-        smem_topk_target_ranks = smem + lane_id * TOP_K;
-        smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
+        // Special case: if local_num_tokens == 0, keep the threads where
+        // local_token_idx == 0 alive so they can participate in the
+        // synchronization; all other threads should return.
+        if (local_token_idx > 0)
+            return;
     }
     else
     {
-        smem_topk_target_ranks = smem;
-        smem_topk_send_indices = smem + TOP_K;
-    }
-
-    uint64_t already_copied = 0;
-    for (int k = 0; k < TOP_K; k++)
-    {
-        int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
-        // Use contiguous partitioning to determine target rank
-        int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
+        // Threads that do not have a token to process should return.
+        if (local_token_idx >= local_num_tokens)
+            return;
+
+        // Prepare per-policy shared-memory tiles for this token
+        extern __shared__ int smem[];
+        int* smem_topk_target_ranks;
+        int* smem_topk_send_indices;
+        int warps_per_block = blockDim.x / warpSize;
+        if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
+        {
+            int lane_id = threadIdx.x / warpSize;
+            smem_topk_target_ranks = smem + lane_id * TOP_K;
+            smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
+        }
+        else
+        {
+            smem_topk_target_ranks = smem;
+            smem_topk_send_indices = smem + TOP_K;
+        }
 
-        if (already_copied & (1ULL << target_rank))
+        uint64_t already_copied = 0;
+        for (int k = 0; k < TOP_K; k++)
         {
+            int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
+            // Use contiguous partitioning to determine target rank
+            int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
+
+            if (already_copied & (1ULL << target_rank))
+            {
+                if (thread_idx == 0)
+                {
+                    ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
+                    ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
+                    // Mirror to shared memory immediately
+                    smem_topk_target_ranks[k] = -1;
+                    smem_topk_send_indices[k] = -1;
+                }
+                continue;
+            }
+
+            // Only one thread per warp should increment the counter
+            int dst_token_idx;
             if (thread_idx == 0)
             {
-                ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
-                ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
+                dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
+
+                ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
+                ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
                 // Mirror to shared memory immediately
-                smem_topk_target_ranks[k] = -1;
-                smem_topk_send_indices[k] = -1;
+                smem_topk_target_ranks[k] = target_rank;
+                smem_topk_send_indices[k] = dst_token_idx;
             }
-            continue;
+            already_copied |= 1ULL << target_rank;
         }
+        // Sync before dispatching data
+        ThreadingPolicy::sync();
 
-        // Only one thread per warp should increment the counter
-        int dst_token_idx;
-        if (thread_idx == 0)
+        // Read staged routing once into registers per thread
+        int topk_target_ranks[TOP_K];
+        int topk_send_indices[TOP_K];
+#pragma unroll
+        for (int k = 0; k < TOP_K; ++k)
         {
-            dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
-
-            ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
-            ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
-            // Mirror to shared memory immediately
-            smem_topk_target_ranks[k] = target_rank;
-            smem_topk_send_indices[k] = dst_token_idx;
+            topk_target_ranks[k] = smem_topk_target_ranks[k];
+            topk_send_indices[k] = smem_topk_send_indices[k];
         }
-        already_copied |= 1ULL << target_rank;
-    }
-    // Sync before dispatching data
-    ThreadingPolicy::sync();
 
-    // Read staged routing once into registers per thread
-    int topk_target_ranks[TOP_K];
-    int topk_send_indices[TOP_K];
-#pragma unroll
-    for (int k = 0; k < TOP_K; ++k)
-    {
-        topk_target_ranks[k] = smem_topk_target_ranks[k];
-        topk_send_indices[k] = smem_topk_send_indices[k];
-    }
+        // Perform a single source load and TOP_K fanout per payload
+        for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
+        {
+            uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
+            int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
+            uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
 
-    // Perform a single source load and TOP_K fanout per payload
-    for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
-    {
-        uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
-        int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
-        uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
+            vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
+                payload_idx, ptrs, topk_target_ranks, topk_send_indices);
+        }
 
-        vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
-            ptrs, topk_target_ranks, topk_send_indices);
+        ThreadingPolicy::sync();
     }
 
-    ThreadingPolicy::sync();
-
     bool is_first_warp = threadIdx.x / warpSize == 0;
     if (is_first_warp)
     {
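The new guard follows a keep-alive pattern: when a rank has no tokens, the threads of one token slot must still survive, because the dispatch ends with a cross-rank synchronization that every rank has to join. Below is a minimal sketch of the pattern in isolation; the kernel name and the one-block-per-token index mapping are illustrative stand-ins, not the file's actual code.

```cuda
// Minimal sketch of the zero-token keep-alive guard (illustrative names;
// the real kernel derives local_token_idx from its ThreadingPolicy).
__global__ void keepAliveGuard(int local_num_tokens)
{
    int local_token_idx = blockIdx.x; // one block per token, as under BlockPolicy

    if (local_num_tokens == 0)
    {
        // No work to do, but token slot 0 must stay alive so this rank
        // still reaches the synchronization at the end of the kernel.
        if (local_token_idx > 0)
            return;
    }
    else
    {
        // Normal path: slots beyond the token count simply exit.
        if (local_token_idx >= local_num_tokens)
            return;

        // ... per-token work lives here, so the keep-alive thread never
        // touches token data that does not exist ...
    }

    // ... synchronization that every rank must participate in ...
}
```

Wrapping the per-token body in the `else` branch, as the diff does, keeps the surviving threads from ever indexing `token_selected_experts` or the payload buffers with a token that does not exist.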
@@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
         bool is_last_token = false;
         if (lane_id == 0)
         {
-            int cnt = atomicAdd(ptrs.local_token_counter, 1);
-            is_last_token = cnt + 1 == local_num_tokens;
+            if (local_num_tokens != 0)
+            {
+                int cnt = atomicAdd(ptrs.local_token_counter, 1);
+                is_last_token = cnt + 1 == local_num_tokens;
+            }
+            else
+            {
+                is_last_token = true;
+            }
         }
         is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
 
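This hunk closes the loop on the zero-token case: lane 0 normally counts completed tokens, and the token that brings the count to `local_num_tokens` signals completion, so with zero tokens nobody would ever signal. The sketch below restates the handshake as a standalone helper; the function name is invented, and `local_token_counter` is assumed to be zeroed before each dispatch.

```cuda
// Hypothetical helper restating the last-token detection above: lane 0
// counts finished tokens and, with zero tokens, declares itself last so
// the completion signal still fires exactly once per rank.
__device__ bool isLastToken(int* local_token_counter, int local_num_tokens)
{
    bool is_last_token = false;
    if (threadIdx.x % warpSize == 0) // lane 0 of the warp
    {
        if (local_num_tokens != 0)
        {
            int cnt = atomicAdd(local_token_counter, 1);
            is_last_token = cnt + 1 == local_num_tokens;
        }
        else
        {
            is_last_token = true; // nothing to count; signal immediately
        }
    }
    // Broadcast lane 0's verdict to every lane in the warp.
    return __shfl_sync(0xffffffff, is_last_token, 0);
}
```

Broadcasting through `__shfl_sync` lets the whole first warp act on lane 0's verdict without a shared-memory round trip.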
@@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     // Validate parameters
     TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
     TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
-    TLLM_CHECK(params.local_num_tokens > 0);
+    TLLM_CHECK(params.local_num_tokens >= 0);
     TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);
 
     // Prepare kernel pointers struct
@@ -568,6 +585,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     if (params.one_block_per_token)
     {
         int grid_size = params.local_num_tokens;
+        // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
+        if (grid_size == 0)
+        {
+            grid_size = 1;
+        }
         int shared_bytes = 2 * params.top_k * (int) sizeof(int);
         SWITCH_TOP_K(params.top_k, TOP_K,
             moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     else
     {
         int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
+        // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
+        if (grid_size == 0)
+        {
+            grid_size = 1;
+        }
         int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
         SWITCH_TOP_K(params.top_k, TOP_K,
             moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
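Both launch paths now clamp an empty grid to a single block. A condensed host-side sketch of the sizing logic follows; the helper name is hypothetical, and `ceilDiv` is assumed to be the usual rounded-up integer division.

```cuda
#include <algorithm>

// Condensed sketch of the grid sizing above (hypothetical helper). A grid
// dimension of 0 is an invalid CUDA launch configuration, so the kernel
// would never run and this rank would miss the synchronization; clamping
// to one block keeps it in the handshake with no tokens to process.
inline int dispatchGridSize(int local_num_tokens, int warps_per_block, bool one_block_per_token)
{
    int grid_size = one_block_per_token
        ? local_num_tokens
        : (local_num_tokens + warps_per_block - 1) / warps_per_block; // ceilDiv
    return std::max(grid_size, 1);
}
```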
@@ -897,9 +924,19 @@ __global__ void moeA2ACombineKernel(
     int local_token_idx = ThreadingPolicy::token_idx();
     int const size_per_token = elements_per_token * sizeof(T);
 
-    if (local_token_idx >= local_num_tokens)
+    if (local_num_tokens == 0)
     {
-        return;
+        // Special case: if local_num_tokens == 0, keep the threads where
+        // local_token_idx == 0 alive so they can participate in the
+        // synchronization; all other threads should return.
+        if (local_token_idx > 0)
+            return;
+    }
+    else
+    {
+        // Threads that do not have a token to process should return.
+        if (local_token_idx >= local_num_tokens)
+            return;
     }
 
 #if !DISABLE_SYNC_FOR_PROFILING
@@ -951,6 +988,9 @@ __global__ void moeA2ACombineKernel(
     __syncthreads();
 #endif
 
+    if (local_num_tokens == 0)
+        return;
+
     // Get output location for this token (using src_data_ptrs[0] as output)
     T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;
 
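The position of this early return is the subtle part: it comes after the synchronization block, so the keep-alive thread still participates in it, and before the first per-token access, so an empty rank never computes `token_output` for a token that does not exist. A reduced ordering sketch, with a block-level barrier standing in for the kernel's actual cross-rank synchronization:

```cuda
// Reduced ordering sketch: pass every barrier first, then let the
// keep-alive thread exit. Returning before a barrier that other
// threads still reach would hang the kernel.
__global__ void syncThenBail(int local_num_tokens)
{
    __syncthreads(); // 1. all surviving threads reach the barrier

    if (local_num_tokens == 0)
        return;      // 2. keep-alive thread exits before touching token data

    // 3. per-token reads and writes happen only when a token exists
}
```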
@@ -1003,14 +1043,23 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
     // Validate parameters
     TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
     TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
-    TLLM_CHECK(params.local_num_tokens > 0);
+    TLLM_CHECK(params.local_num_tokens >= 0);
     TLLM_CHECK(params.elements_per_token > 0);
 
     // Configure kernel launch
     int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
     int const kWarpsPerBlock = kBlockSize / 32; // warpSize
     int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
     int grid_size_block = params.local_num_tokens;
+    // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
+    if (grid_size_warp == 0)
+    {
+        grid_size_warp = 1;
+    }
+    if (grid_size_block == 0)
+    {
+        grid_size_block = 1;
+    }
 
     // Prepare kernel pointers struct for combine
     CombineKernelPointers kernel_ptrs = {}; // Zero-initialize