Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3f7cede
Update transformers to 4.53.0 (#5747)
Wanli-Jiang Jul 9, 2025
b61a717
[1/N][TRTLLM-5195][feat] Share PyTorch tensor between processes (#5396)
chang-l Jul 9, 2025
87fe44f
feat(models): Mistral3.1 VLM pytorch backend support (#5529)
2ez4bz Jul 9, 2025
3209b31
feat: Custom masking utils for Gemma3 VLM (#5853)
brb-nv Jul 9, 2025
76c3a12
[fix] WAR to fix the illegal memory access issue in moe gemm on SM120…
peaceh-nv Jul 10, 2025
f57b3d6
Waive unittest failures introduced by PR#5345 (removal of `Scaffoldin…
venkywonka Jul 10, 2025
6490a27
[feat] Add TensorRT-Engine Qwen3 (dense) model support (#5650)
gkswns0531 Jul 10, 2025
07f6da7
[TRTLLM-5530] chore: rename LLM.autotuner_enabled to enable_autotuner…
Superjomn Jul 10, 2025
e289a98
avoid nesting NCCL group in allgather and reduce scatter OPs (#5866)
QiJune Jul 10, 2025
3ec3ff1
chore: remove support for llmapi + TRT backend in Triton (#5856)
achartier Jul 10, 2025
7d21b55
[feat] Add TRTLLM MoE nvfp4 cubins for mid-high concurrency; attentio…
rosenrodt Jul 10, 2025
dc32f9a
[fix] fix tileN cannot % 16==0 & support sm89 deepgemm bmm (#5531)
CarstyYou Jul 10, 2025
055c4a9
[NvBug 5370718, 5371538] fix: Fix incremental detokenization (#5825)
syuoni Jul 10, 2025
3aa53ec
[None] - Waive L0 tests (#5915)
yiqingy0 Jul 10, 2025
8b9a030
[fix] Fix MoE workspace info by storing Torch tensor itself instead o…
jinyangyuan-nvidia Jul 10, 2025
7b09a41
fix: Make the bench serving script compatible with different usages (…
kaiyux Jul 10, 2025
41ef1ad
feat:enable kvcache to be reused during request generation (#4028)
narutolhy Jul 10, 2025
67a39db
infra: [TRTLLM-6054][TRTLLM-5804] Fix two known NSPECT high vulnerabi…
ZhanruiSunCh Jul 10, 2025
2e3cf42
[refactor] Simplification of Speculative decoding configs (#5639)
wili-65535 Jul 10, 2025
4d071eb
feat: binding type build argument (pybind, nanobind) (#5802)
Linda-Stadter Jul 10, 2025
c32c9e2
doc: Add instructions for running gemma in disaggregated serving (#5922)
Tabrizian Jul 10, 2025
c198402
[fix] Fix mistral unit tests due to transformers upgrade (#5904)
2ez4bz Jul 10, 2025
682acd4
[nvbugs/5321981] Cherrypick fix: Fix the Llama3.1 405B hanging issue.…
nvzhihanj Jul 10, 2025
aa4eebe
[enhance] Add the ability to write a request timeline. (#5258)
FrankD412 Jul 11, 2025
854655f
deepEP fp4 post quant all2all dispatch (#5881)
yilin-void Jul 11, 2025
0385f89
test: Fix Gemma3 unit tests due to transformers upgrade (#5921)
brb-nv Jul 11, 2025
fbb4cc7
[TRTLLM-4770][feat] Enhance cpp executor cmake to listen to ENABLE_MU…
WilliamTambellini Jul 11, 2025
37293e4
blog: add qwen3 disagg perf metrics (#5822)
Shixiaowei02 Jul 11, 2025
c5fb692
Refactor the rest routing part for the routing kernels in the MoE TRT…
ChristinaZ Jul 11, 2025
4935957
[TRTLLM-5673] Doc: ensure the disagg doc is up to date (#5938)
Shixiaowei02 Jul 11, 2025
f4e0425
doc: update the link of the diagram (#5953)
Shixiaowei02 Jul 11, 2025
509363d
tests: update sanity tests & fix tests (#5906)
xinhe-nv Jul 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@
[submodule "3rdparty/xgrammar"]
path = 3rdparty/xgrammar
url = https://github.com/mlc-ai/xgrammar.git
[submodule "3rdparty/nanobind"]
path = 3rdparty/nanobind
url = https://github.com/wjakob/nanobind
1 change: 1 addition & 0 deletions 3rdparty/nanobind
Submodule nanobind added at a0ed25
4 changes: 4 additions & 0 deletions constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ h11>=0.16.0
tornado>=6.5.0
# WAR against https://github.com/advisories/GHSA-5rjg-fvgr-3xxf
setuptools>=78.1.1
# WAR against https://github.com/advisories/GHSA-8qvm-5x2c-j2w7
protobuf>=4.25.8
# WAR against https://github.com/advisories/GHSA-33p9-3p43-82vq
jupyter-core>=5.8.1
25 changes: 20 additions & 5 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ project(tensorrt_llm LANGUAGES CXX)

# Build options
option(BUILD_PYT "Build in PyTorch TorchScript class mode" ON)
option(BUILD_PYBIND "Build Python bindings for C++ runtime and batch manager"
ON)
option(BUILD_TESTS "Build Google tests" ON)
option(BUILD_BENCHMARKS "Build benchmarks" ON)
option(BUILD_MICRO_BENCHMARKS "Build C++ micro benchmarks" OFF)
Expand Down Expand Up @@ -68,6 +66,11 @@ endif()
add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
add_compile_definitions("TLLM_ENABLE_CUDA")

set(BINDING_TYPE
"pybind"
CACHE STRING
"Binding type of Python bindings for C++ runtime and batch manager")

set(INTERNAL_CUTLASS_KERNELS_PATH
""
CACHE
Expand Down Expand Up @@ -195,7 +198,14 @@ set(TRT_LIB TensorRT::NvInfer)
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)

set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
if(BINDING_TYPE STREQUAL "pybind")
add_subdirectory(${3RDPARTY_DIR}/pybind11
${CMAKE_CURRENT_BINARY_DIR}/pybind11)
endif()
if(BINDING_TYPE STREQUAL "nanobind")
add_subdirectory(${3RDPARTY_DIR}/nanobind
${CMAKE_CURRENT_BINARY_DIR}/nanobind)
endif()

# include as system to suppress warnings
include_directories(
Expand All @@ -206,8 +216,13 @@ include_directories(
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${3RDPARTY_DIR}/NVTX/include
${3RDPARTY_DIR}/json/include
${3RDPARTY_DIR}/pybind11/include)
${3RDPARTY_DIR}/json/include)
if(BINDING_TYPE STREQUAL "pybind")
include_directories(${3RDPARTY_DIR}/pybind11/include)
endif()
if(BINDING_TYPE STREQUAL "nanobind")
include_directories(${3RDPARTY_DIR}/nanobind/include)
endif()

if(${CUDAToolkit_VERSION} VERSION_GREATER_EQUAL "11")
add_definitions("-DENABLE_BF16")
Expand Down
12 changes: 12 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ class WindowBlockManager

void storeBlocksForReuse(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Release blocks of the sequence.
void releaseBlocks(GenerationRequest& sequence);

Expand Down Expand Up @@ -1092,6 +1094,9 @@ class BlockManager
//! \brief Store context blocks
void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest);

//! \brief Store newest block for reuse
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

[[nodiscard]] static bool isUseOneMoreBlock(
SizeType32 windowSize, std::optional<SizeType32> maxSequenceLength, SizeType32 maxBeamWidth)
{
Expand Down Expand Up @@ -1262,6 +1267,10 @@ class BaseKVCacheManager
//! \details These blocks become reusable from next step.
virtual void storeContextBlocks(LlmRequest const& llmRequest) = 0;

//! \brief Store newest block for reuse.
//! \details This block becomes reusable from the next step.
virtual void storeNewBlock(LlmRequest const& llmRequest) = 0;

//! \brief Get the block ids of a request [per beam] **for a given window size block manager**
[[nodiscard]] virtual std::vector<std::vector<SizeType32>> const& getCacheBlockIds(
LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
Expand Down Expand Up @@ -1568,6 +1577,9 @@ class KVCacheManager : public BaseKVCacheManager
//! \details These blocks become reusable from next step.
void storeContextBlocks(LlmRequest const& llmRequest) override;

//! \brief Store the newest block for reuse
void storeNewBlock(LlmRequest const& llmRequest) override;

[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 inputLength, SizeType32 outputLength) const override;
Expand Down
6 changes: 5 additions & 1 deletion cpp/tensorrt_llm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,14 @@ if(BUILD_PYT)
add_subdirectory(thop)
endif()

if(BUILD_PYBIND)
if(BINDING_TYPE STREQUAL "pybind")
add_subdirectory(pybind)
endif()

if(BINDING_TYPE STREQUAL "nanobind")
add_subdirectory(nanobind)
endif()

if(BUILD_DEEP_EP)
add_subdirectory(deep_ep)
endif()
Expand Down
82 changes: 82 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,77 @@ void BlockManager::releaseBlocks(GenerationRequest& sequence, OptionalRef<LlmReq
}
}

void BlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
{
    // The newest block is offered for reuse only when all of the following hold:
    // - a request was supplied, so the tokens covered by the block can be identified;
    // - beam search is disabled (beam width == 1);
    // - the sequence never became cyclic — neither at admission time (context longer than
    //   the max attention window) nor later during generation. A sequence counts as cyclic
    //   as soon as its *minimum* window size is crossed, even if larger windows were not.
    // NOTE(review): the "block reuse enabled" precondition is checked by the caller
    // (KVCacheManager::storeNewBlock), not here — confirm all call sites do so.
    bool const reuseDisallowed = sequence.getBeamWidth() != 1 || !llmRequest.has_value() || sequence.isCyclic();
    if (reuseDisallowed)
    {
        return;
    }
    // Delegate to every per-window-size manager.
    for (auto& [windowSize, windowManager] : mWindowBlockManagers)
    {
        windowManager.storeNewBlock(sequence, llmRequest);
    }
}

// Offer the most recently filled kv-cache block of `sequence` for reuse.
// Called once per generation step; a block is only stored when the sequence has just
// completed a full block of tokens (see the `usableSize % mTokensPerBlock` check below).
// Preconditions (enforced by BlockManager::storeNewBlock): beam width == 1, request
// present, sequence not cyclic.
void WindowBlockManager::storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
{
    auto constexpr beamIdx = 0;
    auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
    auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize);

    // Nothing to store for an empty sequence.
    if (uniqueTokens.size() == 0)
    {
        return;
    }

    // TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
    // have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
    // the last token's state is not filled yet.
    auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
    // Only act at block boundaries: a partially filled block cannot be reused yet.
    if (usableSize % mTokensPerBlock != 0)
    {
        return;
    }
    auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
    auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);
    // Too few blocks to inspect a (previous, last) pair, or fewer allocated block ids than
    // keys — fall back to storing everything.
    if (blockKeys.size() < 2 || cacheBlockIds[beamIdx].size() < blockKeys.size())
    {
        // store all blocks
        TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str());
        storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
        return;
    }

    // Last completed block and the one before it. A null getPrevBlock() is used as the
    // "not yet linked into the radix tree" marker — TODO confirm this invariant holds for
    // the root block as well.
    auto lastBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 1]);
    auto prevBlock = mAllBlocksById.at(cacheBlockIds[beamIdx][blockKeys.size() - 2]);

    // If the previous block is not in the radix tree, we need to store all blocks
    if (prevBlock->getPrevBlock() == nullptr)
    {
        TLLM_LOG_DEBUG("%s::storeNewBlock - store all blocks", mLogPrefix.c_str());
        storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
        return;
    }

    if (lastBlock->getPrevBlock() != nullptr)
    {
        // The last block is already linked into the radix tree, so there is nothing to store.
        TLLM_LOG_DEBUG("%s::storeNewBlock - no need to store", mLogPrefix.c_str());
        return;
    }
    // Previous block is in the tree but the last one is not: store just the tail.
    TLLM_LOG_DEBUG("%s::storeNewBlock - store the last block", mLogPrefix.c_str());
    storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}

void WindowBlockManager::storeBlocksForReuse(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest)
{
auto constexpr beamIdx = 0;
Expand Down Expand Up @@ -1964,6 +2035,17 @@ void KVCacheManager::storeContextBlocks(LlmRequest const& llmRequest)
}
}

// Make the newest filled kv-cache block of this request's sequence reusable.
// No-op when block reuse is disabled, when beam search is active, or when the
// sequence has gone cyclic.
void KVCacheManager::storeNewBlock(LlmRequest const& llmRequest)
{
    auto& sequence = getSequence(llmRequest.mRequestId);
    bool const eligibleForReuse = sequence.getBeamWidth() == 1 && !sequence.isCyclic();
    if (!(mEnableBlockReuse && eligibleForReuse))
    {
        return;
    }
    mBlockManager.storeNewBlock(sequence, llmRequest);
}

void KVCacheManager::removeSequence(RequestIdType requestId, OptionalRef<LlmRequest const> llmRequest)
{
TLLM_LOG_TRACE("[%s]::%s start", isCrossKv() ? "CROSS" : "SELF", __PRETTY_FUNCTION__);
Expand Down
20 changes: 20 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,25 @@ void TrtGptModelInflightBatching::storeContextBlocks(std::shared_ptr<LlmRequest>
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

// Forward the "store newest block" request to the self-attention kv-cache manager
// and, when present, the cross-attention kv-cache manager.
void TrtGptModelInflightBatching::storeNewBlock(std::shared_ptr<LlmRequest> const& llmReq)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Make the newest completed kv-cache block reusable immediately after each
    // generation step (instead of waiting for the request to finish).

    if (mKvCacheManager)
    {
        mKvCacheManager->storeNewBlock(*llmReq);
    }
    if (mCrossKvCacheManager)
    {
        mCrossKvCacheManager->storeNewBlock(*llmReq);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtGptModelInflightBatching::resetIterationStats()
{
mLastIterationStatsIFB = IterationStatsIFB{mMicroBatchId};
Expand Down Expand Up @@ -1099,6 +1118,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
}
else if (llmReq->isGenerationInProgressState())
{
storeNewBlock(llmReq);
TLLM_LOG_DEBUG("request with ID %lu forwards a step in decoder gen phase", llmReq->mRequestId);
}
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ class TrtGptModelInflightBatching : public TrtGptModel
//! These blocks become reusable from next step.
void storeContextBlocks(std::shared_ptr<LlmRequest> const& req);

//! @brief Store newest kv cache block for reuse.
//! The block becomes reusable from the next step.
void storeNewBlock(std::shared_ptr<LlmRequest> const& req);

//! @brief Set LayerProfiler to collect performance per layer.
void setLayerProfiler() override;

Expand Down
Loading