
Commit 472fe49

[None][chore] NVLinkOneSided AlltoAll Support zero local_num_tokens. (#9822)
Signed-off-by: Bo Li <[email protected]>
1 parent ea6cd76 commit 472fe49

3 files changed (+119, -70 lines)


cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 118 additions & 69 deletions
@@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
     int thread_idx = ThreadingPolicy::offset();
     int local_token_idx = ThreadingPolicy::token_idx();

-    if (local_token_idx >= local_num_tokens)
+    if (local_num_tokens == 0)
     {
-        return;
-    }
-
-    // Prepare per-policy shared-memory tiles for this token
-    extern __shared__ int smem[];
-    int* smem_topk_target_ranks;
-    int* smem_topk_send_indices;
-    int warps_per_block = blockDim.x / warpSize;
-    if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
-    {
-        int lane_id = threadIdx.x / warpSize;
-        smem_topk_target_ranks = smem + lane_id * TOP_K;
-        smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
+        // Special case: If local_num_tokens == 0,
+        // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
+        // Other threads should return.
+        if (local_token_idx > 0)
+            return;
     }
     else
     {
-        smem_topk_target_ranks = smem;
-        smem_topk_send_indices = smem + TOP_K;
-    }
-
-    uint64_t already_copied = 0;
-    for (int k = 0; k < TOP_K; k++)
-    {
-        int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
-        // Use contiguous partitioning to determine target rank
-        int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
+        // Threads that do not have a token to process should return.
+        if (local_token_idx >= local_num_tokens)
+            return;
+
+        // Prepare per-policy shared-memory tiles for this token
+        extern __shared__ int smem[];
+        int* smem_topk_target_ranks;
+        int* smem_topk_send_indices;
+        int warps_per_block = blockDim.x / warpSize;
+        if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
+        {
+            int lane_id = threadIdx.x / warpSize;
+            smem_topk_target_ranks = smem + lane_id * TOP_K;
+            smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
+        }
+        else
+        {
+            smem_topk_target_ranks = smem;
+            smem_topk_send_indices = smem + TOP_K;
+        }

-        if (already_copied & (1ULL << target_rank))
+        uint64_t already_copied = 0;
+        for (int k = 0; k < TOP_K; k++)
         {
+            int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
+            // Use contiguous partitioning to determine target rank
+            int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
+
+            if (already_copied & (1ULL << target_rank))
+            {
+                if (thread_idx == 0)
+                {
+                    ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
+                    ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
+                    // Mirror to shared memory immediately
+                    smem_topk_target_ranks[k] = -1;
+                    smem_topk_send_indices[k] = -1;
+                }
+                continue;
+            }
+
+            // Only one thread per warp should increment the counter
+            int dst_token_idx;
             if (thread_idx == 0)
             {
-                ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
-                ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
+                dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
+
+                ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
+                ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
                 // Mirror to shared memory immediately
-                smem_topk_target_ranks[k] = -1;
-                smem_topk_send_indices[k] = -1;
+                smem_topk_target_ranks[k] = target_rank;
+                smem_topk_send_indices[k] = dst_token_idx;
             }
-            continue;
+            already_copied |= 1ULL << target_rank;
         }
+        // Sync before dispatching data
+        ThreadingPolicy::sync();

-        // Only one thread per warp should increment the counter
-        int dst_token_idx;
-        if (thread_idx == 0)
+        // Read staged routing once into registers per thread
+        int topk_target_ranks[TOP_K];
+        int topk_send_indices[TOP_K];
+#pragma unroll
+        for (int k = 0; k < TOP_K; ++k)
         {
-            dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);
-
-            ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
-            ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
-            // Mirror to shared memory immediately
-            smem_topk_target_ranks[k] = target_rank;
-            smem_topk_send_indices[k] = dst_token_idx;
+            topk_target_ranks[k] = smem_topk_target_ranks[k];
+            topk_send_indices[k] = smem_topk_send_indices[k];
         }
-        already_copied |= 1ULL << target_rank;
-    }
-    // Sync before dispatching data
-    ThreadingPolicy::sync();

-    // Read staged routing once into registers per thread
-    int topk_target_ranks[TOP_K];
-    int topk_send_indices[TOP_K];
-#pragma unroll
-    for (int k = 0; k < TOP_K; ++k)
-    {
-        topk_target_ranks[k] = smem_topk_target_ranks[k];
-        topk_send_indices[k] = smem_topk_send_indices[k];
-    }
+        // Perform a single source load and TOP_K fanout per payload
+        for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
+        {
+            uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
+            int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
+            uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;

-    // Perform a single source load and TOP_K fanout per payload
-    for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
-    {
-        uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
-        int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
-        uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
+            vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
+                payload_idx, ptrs, topk_target_ranks, topk_send_indices);
+        }

-        vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
-            ptrs, topk_target_ranks, topk_send_indices);
+        ThreadingPolicy::sync();
     }

-    ThreadingPolicy::sync();
-
     bool is_first_warp = threadIdx.x / warpSize == 0;
     if (is_first_warp)
     {
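The guard above funnels all per-token routing and dispatch work into the else branch, so a rank with zero local tokens keeps only the threads of token slot 0 alive and they fall straight through to the completion logic that follows. Below is a minimal, self-contained CUDA sketch of this pattern under hypothetical names (zeroWorkGuardKernel, completion_flag); it illustrates the idea and is not the TensorRT-LLM kernel.

// Minimal sketch (hypothetical names) of the zero-work guard pattern used above:
// keep exactly one work slot alive when num_items == 0 so the completion
// signalling at the end of the kernel still executes on an empty rank.
__global__ void zeroWorkGuardKernel(int num_items, int* completion_flag)
{
    int item_idx = blockIdx.x; // one block per item, mirroring BlockPolicy

    if (num_items == 0)
    {
        // Keep item slot 0 alive purely to participate in the final signalling.
        if (item_idx > 0)
            return;
    }
    else
    {
        // Normal path: slots without an item to process exit early.
        if (item_idx >= num_items)
            return;

        // ... per-item routing and copy work would go here ...
        __syncthreads();
    }

    // Reached by real items and by the kept-alive slot alike, so ranks that
    // poll completion_flag never wait on a kernel that exited without signalling.
    if (threadIdx.x == 0 && item_idx == 0)
        atomicExch(completion_flag, 1);
}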
@@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
         bool is_last_token = false;
         if (lane_id == 0)
         {
-            int cnt = atomicAdd(ptrs.local_token_counter, 1);
-            is_last_token = cnt + 1 == local_num_tokens;
+            if (local_num_tokens != 0)
+            {
+                int cnt = atomicAdd(ptrs.local_token_counter, 1);
+                is_last_token = cnt + 1 == local_num_tokens;
+            }
+            else
+            {
+                is_last_token = true;
+            }
         }
         is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
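The last-token election above relies on the counter: with local_num_tokens == 3, for instance, the three atomicAdd calls return 0, 1 and 2, and only the warp observing cnt == 2 satisfies cnt + 1 == local_num_tokens. With zero tokens no increment ever happens, so the kept-alive warp has to declare itself last outright, which is what the hunk does. A self-contained sketch of that election under a hypothetical name (electLastToken):

// Sketch (hypothetical helper) of the last-token election above: one lane per
// token increments a shared counter; with zero tokens nothing is ever counted,
// so the kept-alive warp elects itself as "last" directly.
__device__ bool electLastToken(int* local_token_counter, int local_num_tokens)
{
    if (local_num_tokens == 0)
        return true; // empty rank: signal completion immediately

    int cnt = atomicAdd(local_token_counter, 1); // returns 0 .. local_num_tokens - 1
    return cnt + 1 == local_num_tokens;          // true only for the final token
}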

@@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     // Validate parameters
     TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
     TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
-    TLLM_CHECK(params.local_num_tokens > 0);
+    TLLM_CHECK(params.local_num_tokens >= 0);
     TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);

     // Prepare kernel pointers struct
@@ -568,6 +585,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     if (params.one_block_per_token)
     {
         int grid_size = params.local_num_tokens;
+        // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
+        if (grid_size == 0)
+        {
+            grid_size = 1;
+        }
         int shared_bytes = 2 * params.top_k * (int) sizeof(int);
         SWITCH_TOP_K(params.top_k, TOP_K,
             moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
     else
     {
         int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
+        // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
+        if (grid_size == 0)
+        {
+            grid_size = 1;
+        }
         int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
         SWITCH_TOP_K(params.top_k, TOP_K,
             moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
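Both launch branches apply the same clamp, which amounts to taking max(grid_size, 1) so that an empty rank still launches one block and reaches the cross-rank synchronization. A hedged sketch of that sizing rule as a standalone helper (dispatchGridSize is a hypothetical name; the real launcher additionally selects TOP_K and the threading policy):

#include <algorithm>

// Hypothetical helper mirroring the clamping added in both branches above.
inline int dispatchGridSize(int localNumTokens, int warpsPerBlock, bool oneBlockPerToken)
{
    int gridSize = oneBlockPerToken
        ? localNumTokens                                        // BlockPolicy: one block per token
        : (localNumTokens + warpsPerBlock - 1) / warpsPerBlock; // WarpPolicy: one warp per token
    // A zero-sized grid would skip the kernel launch entirely, so the rank would
    // never take part in the cross-rank synchronization; launch at least one block.
    return std::max(gridSize, 1);
}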
@@ -897,9 +924,19 @@ __global__ void moeA2ACombineKernel(
     int local_token_idx = ThreadingPolicy::token_idx();
     int const size_per_token = elements_per_token * sizeof(T);

-    if (local_token_idx >= local_num_tokens)
+    if (local_num_tokens == 0)
     {
-        return;
+        // Special case: If local_num_tokens == 0,
+        // we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
+        // Other threads should return.
+        if (local_token_idx > 0)
+            return;
+    }
+    else
+    {
+        // Threads that do not have a token to process should return.
+        if (local_token_idx >= local_num_tokens)
+            return;
     }

 #if !DISABLE_SYNC_FOR_PROFILING
@@ -951,6 +988,9 @@ __global__ void moeA2ACombineKernel(
     __syncthreads();
 #endif

+    if (local_num_tokens == 0)
+        return;
+
     // Get output location for this token (using src_data_ptrs[0] as output)
     T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;
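Combined with the entry guard earlier in moeA2ACombineKernel, the ordering is the point: on an empty rank the kept-alive block still executes the synchronization section guarded by DISABLE_SYNC_FOR_PROFILING and only then returns, before any per-token output is read or written. A small self-contained CUDA sketch of that ordering under a hypothetical name (combineZeroTokenSketch):

// Sketch (hypothetical kernel): an empty rank participates in the sync phase,
// then exits before touching per-token output it does not have.
__global__ void combineZeroTokenSketch(int local_num_tokens, float* output)
{
    int token_idx = blockIdx.x; // one block per token

    if (local_num_tokens == 0)
    {
        if (token_idx > 0)
            return; // keep only token slot 0 alive for the sync phase
    }
    else if (token_idx >= local_num_tokens)
    {
        return; // no token to combine for this block
    }

    // ... cross-rank synchronization would happen here for every surviving block ...
    __syncthreads();

    if (local_num_tokens == 0)
        return; // nothing to combine on this rank; leave right after the sync

    output[token_idx] = 0.0f; // placeholder for the real top-k reduction
}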

@@ -1003,14 +1043,23 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
     // Validate parameters
     TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
     TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
-    TLLM_CHECK(params.local_num_tokens > 0);
+    TLLM_CHECK(params.local_num_tokens >= 0);
     TLLM_CHECK(params.elements_per_token > 0);

     // Configure kernel launch
     int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
     int const kWarpsPerBlock = kBlockSize / 32; // warpSize
     int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
     int grid_size_block = params.local_num_tokens;
+    // If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
+    if (grid_size_warp == 0)
+    {
+        grid_size_warp = 1;
+    }
+    if (grid_size_block == 0)
+    {
+        grid_size_block = 1;
+    }

     // Prepare kernel pointers struct for combine
     CombineKernelPointers kernel_ptrs = {}; // Zero-initialize

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 0 additions & 1 deletion
@@ -186,7 +186,6 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
     MoeA2ADataOffsets const& offsets = *reinterpret_cast<MoeA2ADataOffsets const*>(metainfo.data_ptr<int64_t>());

     int64_t localNumTokens = tokenSelectedExperts.size(0);
-    TORCH_CHECK(localNumTokens > 0, "localNumTokens must be positive");
     TORCH_CHECK(runtimeMaxTokensPerRank > 0, "runtimeMaxTokensPerRank must be positive");
     TORCH_CHECK(epRank >= 0 && epRank < epSize, "epRank must be in the range [0, epSize)");
     TORCH_CHECK(topK > 0 && topK <= kMaxTopK, "topK must be in the range (0, kMaxTopK]");

tests/unittest/_torch/multi_gpu/test_moe_a2a.py

Lines changed: 1 addition & 0 deletions
@@ -566,6 +566,7 @@ def test_dispatch(self, mpi_pool_executor, all_num_tokens, top_k):
         (4, [32, 32, 32, 32], 4),
         (4, [1, 1, 1, 1], 2),
         (8, [640, 640, 640, 640, 640, 640, 640, 640], 4),
+        (4, [32, 0, 16, 0], 2),
     ],
     indirect=["mpi_pool_executor"])
     def test_combine(self, mpi_pool_executor, all_num_tokens, top_k):
