51 commits
64bce0f
re-implement micro batch scheduler and capacity scheduler in python
QiJune Dec 17, 2025
034fffb
refine
QiJune Dec 17, 2025
927b417
enable SimpleUnifiedScheduler
QiJune Dec 17, 2025
3609b20
fix
QiJune Dec 17, 2025
c901b21
fix
QiJune Dec 17, 2025
84cebc9
fix
QiJune Dec 17, 2025
490f8e9
fix
QiJune Dec 17, 2025
4e62403
fix
QiJune Dec 17, 2025
87caccb
fix
QiJune Dec 17, 2025
d1aebe7
fix
QiJune Dec 17, 2025
4d1f530
fix
QiJune Dec 17, 2025
641236d
fix
QiJune Dec 17, 2025
162d59e
enable py scheduler
QiJune Dec 17, 2025
707fb4a
support bert
QiJune Dec 17, 2025
fbc8486
fix
QiJune Dec 18, 2025
d344670
fix
QiJune Dec 18, 2025
6617a47
fix
QiJune Dec 18, 2025
63c09c6
fix
QiJune Dec 18, 2025
c2bffa5
fix
QiJune Dec 18, 2025
2a3a7f2
fix gemma
QiJune Dec 19, 2025
2f30b99
fix lora
QiJune Dec 19, 2025
066b653
[TRTLLM-9880][feat] Include torch compile tests in QA test list (#10149)
liji-nv Dec 22, 2025
f0bd60a
[https://nvbugs/5684820][fix] fix the detokenizer issue for DeepSeek-…
lfr-0531 Dec 22, 2025
f8501f3
[None][infra] Check in most recent lock file from nightly pipeline
tensorrt-cicd Dec 22, 2025
237fd0e
[https://nvbugs/5666821][chore] unwaive tests. (#9958)
yuxianq Dec 22, 2025
d30ee81
[None][chore] Remove closed bugs (#10182)
xinhe-nv Dec 22, 2025
7421224
[None][fix] NVFP4 linear method's weight and weight_scale padding (#1…
JadoTu Dec 22, 2025
9e9523c
[https://nvbugs/5762016][chore] Skip a ray test (#10194)
shuyixiong Dec 22, 2025
c87f1a6
[https://nvbugs/5503479][fix] update trtllm-gen kernels to address fe…
PerkzZheng Dec 22, 2025
ea6cd76
[None][refactor] simplify get_stats and get_kvcache_events with rpc (…
Superjomn Dec 22, 2025
472fe49
[None][chore] NVLinkOneSided AlltoAll Support zero local_num_tokens. …
bobboli Dec 22, 2025
a6a8898
[TRTLLM-9409][feat] Pass MRoPE tensors for EPD disagg (#9758)
2ez4bz Dec 22, 2025
0f308e9
[None][chore] Remove logprobs constraint on trtllm-serve pytorch back…
LinPoly Dec 22, 2025
ba14a93
[None][infra] Waive failed cases on 12/22 (#10200)
EmmaQiaoCh Dec 22, 2025
aaa87ab
[TRTLLM-7906][feat] Support multiple post process for Responses API (…
JunyiXu-nv Dec 22, 2025
12e1cb8
[#9717][chore] Refactor MoE code to use enums (#9910)
tcherckez-nvidia Dec 22, 2025
ccc64da
[TRTLLM-9847][fix] WAR fix hanging fused allreduce. (#10087)
greg-kwasniewski1 Dec 22, 2025
0d2500c
[TRTLLM-9677][feat] Support DeepSeek-V3.2 tool parser (#10126)
lfr-0531 Dec 23, 2025
f05af48
[https://nvbugs/5747674][fix] Add contiguous() before view() in load_…
farazkh80 Dec 23, 2025
648196f
[TRTLLM-9432][feat] Reduce synchronization and recompilation for qwen…
yuantailing Dec 23, 2025
696f754
[None][fix] avoid implicit cudaStreamSynchronize in sample_async. (#1…
yuxianq Dec 23, 2025
1e82ff7
[TRTLLM-9989][fix] Fix tvm_ffi aaarch64 issue. (#10199)
limin2021 Dec 23, 2025
621156a
[None][chore] Fix GB300 support issues (#10196)
fredricz-20070104 Dec 23, 2025
18f8b22
[None][infra] Check in most recent lock file from nightly pipeline
tensorrt-cicd Dec 23, 2025
5bc7ffe
[None][test] Add qa tests for RTX 6K (#10210)
pamelap-nvidia Dec 23, 2025
d691371
[TRTLLM-9091] [feat] Replace GenAI-Perf with AIPerf (#9310)
lkomali Dec 23, 2025
77b591f
[None][chore] Add failed cases into waives.txt (#10177)
xinhe-nv Dec 23, 2025
53db3b2
[https://nvbugs/5741884][fix] unwaive disagg sampler (#10189)
chuangz0 Dec 23, 2025
59b05dc
[None][chore] Bump version to 1.2.0rc7 (#10216)
yiqingy0 Dec 23, 2025
cc1323b
[None][fix] Fix the bug for top_k=10 in NVLinkOneSided AlltoAll. (#10…
bobboli Dec 23, 2025
9fce092
Merge branch 'main' into py_scheduler
lancelly Dec 23, 2025
The diff you're trying to view is too large. We only load the first 3000 changed files.
2 changes: 1 addition & 1 deletion README.md
@@ -10,7 +10,7 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.<
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads)
[![torch](https://img.shields.io/badge/torch-2.9.0-green)](https://pytorch.org)
[![version](https://img.shields.io/badge/release-1.2.0rc6-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-1.2.0rc7-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)

[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](https://nvidia.github.io/TensorRT-LLM/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
192 changes: 120 additions & 72 deletions cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu
@@ -362,88 +362,98 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
int thread_idx = ThreadingPolicy::offset();
int local_token_idx = ThreadingPolicy::token_idx();

if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
}

// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}

uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;

// Prepare per-policy shared-memory tiles for this token
extern __shared__ int smem[];
int* smem_topk_target_ranks;
int* smem_topk_send_indices;
int warps_per_block = blockDim.x / warpSize;
if constexpr (std::is_same<ThreadingPolicy, WarpPolicy>::value)
{
int lane_id = threadIdx.x / warpSize;
smem_topk_target_ranks = smem + lane_id * TOP_K;
smem_topk_send_indices = smem + warps_per_block * TOP_K + lane_id * TOP_K;
}
else
{
smem_topk_target_ranks = smem;
smem_topk_send_indices = smem + TOP_K;
}

if (already_copied & (1ULL << target_rank))
uint64_t already_copied = 0;
for (int k = 0; k < TOP_K; k++)
{
int expert_id = token_selected_experts[local_token_idx * TOP_K + k];
// Use contiguous partitioning to determine target rank
int target_rank = compute_target_rank_id(expert_id, num_experts_per_rank);

if (already_copied & (1ULL << target_rank))
{
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
}
continue;
}

// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
{
ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = -1;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = -1;
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);

ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = -1;
smem_topk_send_indices[k] = -1;
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
}
continue;
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();

// Only one thread per warp should increment the counter
int dst_token_idx;
if (thread_idx == 0)
// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
dst_token_idx = atomicAdd(&ptrs.send_counters[target_rank], 1);

ptrs.topk_target_ranks[local_token_idx * TOP_K + k] = target_rank;
ptrs.topk_send_indices[local_token_idx * TOP_K + k] = dst_token_idx;
// Mirror to shared memory immediately
smem_topk_target_ranks[k] = target_rank;
smem_topk_send_indices[k] = dst_token_idx;
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}
already_copied |= 1ULL << target_rank;
}
// Sync before dispatching data
ThreadingPolicy::sync();

// Read staged routing once into registers per thread
int topk_target_ranks[TOP_K];
int topk_send_indices[TOP_K];
#pragma unroll
for (int k = 0; k < TOP_K; ++k)
{
topk_target_ranks[k] = smem_topk_target_ranks[k];
topk_send_indices[k] = smem_topk_send_indices[k];
}
// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;

// Perform a single source load and TOP_K fanout per payload
for (int payload_idx = 0; payload_idx < num_payloads; payload_idx++)
{
uint8_t const* src_data = static_cast<uint8_t const*>(ptrs.src_data_ptrs[payload_idx]);
int bytes_per_token = ptrs.payload_bytes_per_token[payload_idx];
uint8_t const* src_ptr = src_data + local_token_idx * bytes_per_token;
vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank,
payload_idx, ptrs, topk_target_ranks, topk_send_indices);
}

vectorized_dispatch<TOP_K, ThreadingPolicy>(src_ptr, bytes_per_token, rank_id, max_tokens_per_rank, payload_idx,
ptrs, topk_target_ranks, topk_send_indices);
ThreadingPolicy::sync();
}

ThreadingPolicy::sync();

bool is_first_warp = threadIdx.x / warpSize == 0;
if (is_first_warp)
{
@@ -452,8 +462,15 @@ __global__ void moeA2ADispatchKernel(int32_t const* token_selected_experts, // [
bool is_last_token = false;
if (lane_id == 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
if (local_num_tokens != 0)
{
int cnt = atomicAdd(ptrs.local_token_counter, 1);
is_last_token = cnt + 1 == local_num_tokens;
}
else
{
is_last_token = true;
}
}
is_last_token = __shfl_sync(0xffffffff, is_last_token, 0);
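For reference, the completion check in the hunk above uses a grid-wide "last one to finish" pattern: each participating warp's lane 0 bumps a shared counter, and the warp that brings it to the expected total knows all work on this rank has been dispatched; with zero local tokens, the single warp kept alive declares itself last so the cross-rank signaling still fires. A minimal standalone sketch of that pattern (illustrative only, not part of the diff; the counter, flag, and kernel name are placeholders):

// Illustrative sketch (not part of the diff): the "last one to finish" pattern
// used in the completion check above. Counter, flag, and kernel name are
// placeholders; the counter is assumed to be zeroed before launch.
__global__ void lastWarpFlagSketch(int expected_units, int* counter, int* completion_flag)
{
    int lane_id = threadIdx.x % warpSize;
    bool is_last = false;
    if (lane_id == 0)
    {
        if (expected_units != 0)
        {
            // One increment per participating unit; whoever brings the count
            // to expected_units is the last to finish.
            int cnt = atomicAdd(counter, 1);
            is_last = (cnt + 1 == expected_units);
        }
        else
        {
            // Nothing to count on this rank: the lone surviving unit is trivially last.
            is_last = true;
        }
    }
    // Broadcast lane 0's verdict to the whole warp.
    is_last = __shfl_sync(0xffffffff, is_last, 0);
    if (is_last && lane_id == 0)
    {
        *completion_flag = 1; // e.g. tell peer ranks that this rank's dispatch is done
    }
}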

@@ -523,7 +540,7 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.num_payloads > 0 && params.num_payloads <= kMaxPayloads);

// Prepare kernel pointers struct
@@ -568,6 +585,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
if (params.one_block_per_token)
{
int grid_size = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<BlockPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
@@ -577,6 +599,11 @@ void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params)
else
{
int grid_size = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size == 0)
{
grid_size = 1;
}
int shared_bytes = 2 * kWarpsPerBlock * params.top_k * (int) sizeof(int);
SWITCH_TOP_K(params.top_k, TOP_K,
moeA2ADispatchKernel<WarpPolicy, TOP_K><<<grid_size, kBlockSize, shared_bytes, params.stream>>>(
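Both launch paths above apply the same grid-sizing rule: size the grid from the token count, but never launch an empty grid, because a zero-token rank must still run one block to take part in the cross-rank synchronization. A minimal host-side sketch of that rule (illustrative only; helper names are hypothetical, the real code uses ceilDiv and explicit if-blocks):

// Illustrative host-side sketch (not part of the diff); helper names are hypothetical.
#include <algorithm>

inline int ceilDivSketch(int a, int b)
{
    return (a + b - 1) / b;
}

inline int dispatchGridSize(int local_num_tokens, int warps_per_block, bool one_block_per_token)
{
    int grid = one_block_per_token ? local_num_tokens
                                   : ceilDivSketch(local_num_tokens, warps_per_block);
    // A zero-token rank still launches one block so it can join the cross-rank sync.
    return std::max(grid, 1);
}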
@@ -626,6 +653,7 @@ __device__ void vectorized_combine_impl(
// Load directly into the per-k accumulator; reduce across k below
acc[k].load(recv_buffer + base_token + offset);
}
// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 16)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
Expand Down Expand Up @@ -710,9 +738,7 @@ __device__ void vectorized_combine_impl(
a0[j] += a8[j];
}
}

// Reduce acc[TOP_K] into acc[0]
if constexpr (TOP_K == 8)
else if constexpr (TOP_K == 8)
{
T* a0 = reinterpret_cast<T*>(&acc[0]);
T* a1 = reinterpret_cast<T*>(&acc[1]);
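The branches above hand-unroll a pairwise tree reduction of the TOP_K per-expert accumulators into acc[0]. A loop-based equivalent is sketched below, assuming TOP_K is a power of two and a plain 2-D accumulator array in place of the kernel's vector types (illustrative only, not part of the diff):

// Illustrative sketch (not part of the diff): loop-based equivalent of the
// unrolled "reduce acc[TOP_K] into acc[0]" branches above. Assumes TOP_K is a
// power of two and a plain 2-D accumulator array instead of the vector types.
template <typename T, int TOP_K, int kElemsPerAcc>
__device__ void reduceTopKIntoFirst(T (&acc)[TOP_K][kElemsPerAcc])
{
    // Pairwise tree reduction: 16 -> 8 -> 4 -> 2 -> 1 partial sums.
#pragma unroll
    for (int stride = TOP_K / 2; stride > 0; stride /= 2)
    {
#pragma unroll
        for (int k = 0; k < stride; ++k)
        {
#pragma unroll
            for (int j = 0; j < kElemsPerAcc; ++j)
            {
                acc[k][j] += acc[k + stride][j];
            }
        }
    }
    // acc[0] now holds the sum of all TOP_K expert contributions.
}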
@@ -897,9 +923,19 @@ __global__ void moeA2ACombineKernel(
int local_token_idx = ThreadingPolicy::token_idx();
int const size_per_token = elements_per_token * sizeof(T);

if (local_token_idx >= local_num_tokens)
if (local_num_tokens == 0)
{
return;
// Special case: If local_num_tokens == 0,
// we need to keep the threads where local_token_idx == 0 alive to participate in the synchronization.
// Other threads should return.
if (local_token_idx > 0)
return;
}
else
{
// Threads that do not have a token to process should return.
if (local_token_idx >= local_num_tokens)
return;
}

#if !DISABLE_SYNC_FOR_PROFILING
@@ -951,6 +987,9 @@ __global__ void moeA2ACombineKernel(
__syncthreads();
#endif

if (local_num_tokens == 0)
return;

// Get output location for this token (using src_data_ptrs[0] as output)
T* token_output = static_cast<T*>(ptrs.src_data_ptrs[0]) + local_token_idx * elements_per_token;
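The combine kernel's zero-token handling above mirrors the dispatch kernel: one thread slot per rank stays alive through the cross-rank synchronization and only then bails out, before any output is written. A stripped-down control-flow sketch (illustrative only; the real kernel uses ThreadingPolicy::token_idx() and its own sync helpers):

// Illustrative control-flow sketch (not part of the diff); token indexing and
// synchronization details are stand-ins for the kernel's ThreadingPolicy helpers.
__global__ void combineZeroTokenPathSketch(int local_num_tokens)
{
    int local_token_idx = blockIdx.x; // stand-in for ThreadingPolicy::token_idx()

    if (local_num_tokens == 0)
    {
        // Keep only token slot 0 alive so this rank still reaches the barrier.
        if (local_token_idx > 0)
            return;
    }
    else if (local_token_idx >= local_num_tokens)
    {
        return; // no token assigned to this slot
    }

    // ... cross-rank synchronization happens here; every rank must reach it ...

    if (local_num_tokens == 0)
        return; // synced, but there is nothing to combine or write on this rank

    // ... per-token combine and output write would follow ...
}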

@@ -1003,14 +1042,23 @@ void moe_a2a_combine_launch(MoeA2ACombineParams const& params)
// Validate parameters
TLLM_CHECK(params.top_k > 0 && params.top_k <= kMaxTopK);
TLLM_CHECK(params.ep_size > 0 && params.ep_size <= kMaxRanks);
TLLM_CHECK(params.local_num_tokens > 0);
TLLM_CHECK(params.local_num_tokens >= 0);
TLLM_CHECK(params.elements_per_token > 0);

// Configure kernel launch
int const kBlockSize = tensorrt_llm::common::getEnvMoeA2ACombineBlockSize();
int const kWarpsPerBlock = kBlockSize / 32; // warpSize
int grid_size_warp = ceilDiv(params.local_num_tokens, kWarpsPerBlock);
int grid_size_block = params.local_num_tokens;
// If local_num_tokens is 0, we still need to launch a minimal kernel to participate in the synchronization.
if (grid_size_warp == 0)
{
grid_size_warp = 1;
}
if (grid_size_block == 0)
{
grid_size_block = 1;
}

// Prepare kernel pointers struct for combine
CombineKernelPointers kernel_ptrs = {}; // Zero-initialize
Git LFS files not shown (20 files).