From 1d484eb05ac914aeab8812c721b32a503c241c7f Mon Sep 17 00:00:00 2001 From: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Date: Mon, 3 Nov 2025 21:14:32 +0000 Subject: [PATCH 1/9] Add refresh blocks Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> . --- cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp | 3 ++- cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp | 3 ++- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index c3bccf87b47..0949a3963f1 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -465,7 +465,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, - nb::call_guard()); + nb::call_guard()) + .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard()); nb::bind_vector(m, "CacheBlockIds") .def("__getstate__", [](CacheBlockIds const& v) { return nb::make_tuple(v); }) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 320659a1d09..88fbaf3a195 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -467,7 +467,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def("get_newly_allocated_block_ids", &BaseKVCacheManager::getNewlyAllocatedBlockIds, py::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, - py::call_guard()); + py::call_guard()) + .def("refresh_blocks", 
&BaseKVCacheManager::refreshBlocks, py::call_guard()); py::enum_(m, "CacheType") .value("SELF", tbk::CacheType::kSELF) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 6298db0146d..be9396e0d2c 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -441,6 +441,8 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): self.kv_connector_manager.update_state_after_alloc( req, block_ids) + self.impl.refresh_blocks() + for req in generation_batch: self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): From 307479babf92c5388b23ad1b9d5417167f71344a Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Thu, 6 Nov 2025 19:45:27 +0000 Subject: [PATCH 2/9] Fix transfer manager synchronization issues Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> . --- .../batch_manager/kvCacheManager.h | 12 ++ .../batch_manager/kvCacheTransferManager.h | 12 +- .../batch_manager/allocateKvCache.cpp | 2 + .../batch_manager/kvCacheManager.cpp | 13 ++ .../batch_manager/kvCacheTransferManager.cpp | 119 ++++++++++++++++-- .../nanobind/batch_manager/kvCacheManager.cpp | 6 + .../pybind/batch_manager/kvCacheManager.cpp | 6 + .../_torch/pyexecutor/resource_manager.py | 2 +- 8 files changed, 155 insertions(+), 17 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index cc227e75ca1..be94e13740a 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -737,6 +737,9 @@ class WindowBlockManager return mBufferManager; } + //! \brief Sync internal streams used by transfer manager with buffer manager stream + void syncTransferManagerWithBufferManager(); + //! 
\brief Perform per-request bookkeeping void refreshBlocks(); @@ -1136,6 +1139,9 @@ class BlockManager return isCyclicWindowSize && isBeamSearch; } + //! \brief Sync internal streams used by transfer manager with buffer manager stream + void syncTransferManagerWithBufferManager(); + //! \brief Perform per-request bookkeeping void refreshBlocks(); @@ -1332,6 +1338,7 @@ class BaseKVCacheManager [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 layer_idx) const = 0; [[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0; + virtual void syncTransferManagerWithBufferManager() = 0; virtual void refreshBlocks() = 0; virtual void flushIterationEvents() = 0; @@ -1677,6 +1684,11 @@ class KVCacheManager : public BaseKVCacheManager return mBlockManager.getPoolLayerIdx(layer_idx); } + void syncTransferManagerWithBufferManager() override + { + mBlockManager.syncTransferManagerWithBufferManager(); + } + //! \brief Perform per-iteration bookkeeping void refreshBlocks() override { diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h index 45f615cafe7..6d52e166dd2 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h @@ -46,7 +46,11 @@ class KVCacheTransferManager int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const& directory = ""); - //! \brief Synchronize the offload/onboard streams with the bufferManager stream. + //! \brief Synchronize internal streams with bufferManager stream. + //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing any block copies. 
This method must be called before the first call to KVCacheManager::addSequence in every step. + void syncWithBufferManager(); + + //! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method must be called after last call to KVCacheManager::addSequence in every step. void syncTransfers(); private: @@ -75,8 +79,10 @@ class KVCacheTransferManager runtime::BufferManager mOnboardManager; runtime::BufferManager mOffloadManager; - // Track the block ids offloaded in this iteration. - std::unordered_map mPendingOffloads; + // Track reads and writes for blocks. Note that it is the memory pool index that + // identifies the raw memory blocks involved in I/O, not the block Id. + std::unordered_map mPendingReads; + std::unordered_map mPendingWrites; // Reference to parent loopback agent std::shared_ptr mLoopbackAgent; int mDeviceId; diff --git a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp index c0482deb554..211abe78186 100644 --- a/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp +++ b/cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp @@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); NVTX3_SCOPED_RANGE(allocateKvCache); + kvCacheManager.syncTransferManagerWithBufferManager(); + for (auto const& llmReq : contextRequests) { if (llmReq->isFirstContextChunk()) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 81a4746467d..430f40f12b2 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1170,6 +1170,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& return numMatchedTokens; } + +void 
BlockManager::syncTransferManagerWithBufferManager() +{ + for (auto& [_, manager] : mWindowBlockManagers) + { + manager.syncTransferManagerWithBufferManager(); + } +} + +void WindowBlockManager::syncTransferManagerWithBufferManager() +{ + mTransferManager->syncWithBufferManager(); +} void BlockManager::refreshBlocks() { diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index 35868b35ac1..bb3421f255c 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -219,47 +219,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, } } +// +// Note about recording events to wait for cudaMempyAsync calls between blocks: +// The memory copy involves raw memory blocks, which are pointed to by the +// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex() +// as the raw memory block identifier. Earlier versions of this code used getBlockId() +// when recording events, this is just wrong. getBlockId() returns the logical block id, +// which has nothing to do with the raw memory block pointers involved in a cudaMemcpy. +// + +// +// Notes about need for synchronization: +// +// Earlier versions of this code relied on decoder implicitly syncing GPU with CPU. +// This is inherently dangerous, it is not given that decoder will always explicitly sync +// GPU with CPU for every step, a major design goal of ongoing work is to avoid this. +// To make the code future proof, we introdduce a new method SyncWithBufferManager() +// that ensures that internal copy streams will wait for prefill and decode kernels +// that have already been scheduled. +// +// Earlier versions of this code did not account for all possible cases were a new block copy +// needed to wait for a previously scheduled copy to finish. 
For instance, it is possible +// that two primary blocks are offloaded to the same secondary block in a single step, +// scheduling the second offloading without waiting for the first one to finish leads to +// a corrupted block after offloading. It is possible that partial reuse will copy +// from a block that is currently being onboarded, scheduling the partial copy without +// waiting for the onboarding to finish will lead to a corrupted block. To handle all +// possible cases needing synchronization we record separate events for reads and writes +// to a block. When a new block copy is scheduled, we wait for all writes to the source +// block and all reads and writes to a destination block. +// +// As before, syncTransfers() must be called after last call to KVCacheManager::addSequence. +// Failing to do so will lead to corrupted blocks eventually. +// + void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block, std::vector const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::string const& directory) { - if (mode != executor::KvCacheTransferMode::DRAM - && mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end()) + // Wait for any pending writes before reading from offloadBlock + auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); + if (offloadBlockPendingWriteItr != mPendingWrites.end()) { - TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk", - offloadBlock->getBlockId()); - return; + mOnboardManager.getStream().wait(offloadBlockPendingWriteItr->second); + // Don't erase, we are not changing state of offloadBlock } - - if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end()) + // Wait for any pending reads before overwriting block + auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex()); + if (blockPendingReadItr != mPendingReads.end()) { - 
mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]); + mOnboardManager.getStream().wait(blockPendingReadItr->second); + mPendingReads.erase(blockPendingReadItr); } + // Wait for any pending writes before overwriting block + auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex()); + if (blockPendingWriteItr != mPendingWrites.end()) + { + mOnboardManager.getStream().wait(blockPendingWriteItr->second); + mPendingWrites.erase(blockPendingWriteItr); + } + copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory); + + // Record new pending read from offloadBlock + mPendingReads[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOnboardManager.getStream().record(mPendingReads[offloadBlock->getMemoryPoolBlockIndex()]); + // Record new pending write to block + mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]); } void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock, std::vector const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::string const& directory) { - mPendingOffloads[block->getBlockId()] = tr::CudaEvent(); + // Wait for any pending writes before reading from block + auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex()); + if (blockPendingWriteItr != mPendingWrites.end()) + { + mOffloadManager.getStream().wait(blockPendingWriteItr->second); + // Don't erase, we are not changing state of block + } + // Wait for any pending reads before overwriting offloadBlock + auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex()); + if (offloadBlockPendingReadItr != mPendingReads.end()) + { + mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second); + mPendingReads.erase(offloadBlockPendingReadItr); + } + // Wait for any pending writes before overwriting 
offloadBlock + auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); + if (offloadBlockPendingWriteItr != mPendingWrites.end()) + { + mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second); + mPendingWrites.erase(offloadBlockPendingWriteItr); + } + copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory); - mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]); + + // Record new pending read from block + mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]); + // Record new pending write to offloadBlock + mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]); +} + +void KVCacheTransferManager::syncWithBufferManager() +{ + tr::CudaEvent readyForOffloadEvent; + mBufferManager.getStream().record(readyForOffloadEvent); + mOffloadManager.getStream().wait(readyForOffloadEvent); + + tr::CudaEvent readyForOnboardEvent; + mBufferManager.getStream().record(readyForOnboardEvent); + mOnboardManager.getStream().wait(readyForOnboardEvent); + + // Once we synchronize, clear our list of pending thransfers. + mPendingReads.clear(); + mPendingWrites.clear(); } void KVCacheTransferManager::syncTransfers() { tr::CudaEvent offloadEvent; mOffloadManager.getStream().record(offloadEvent); + mBufferManager.getStream().wait(offloadEvent); tr::CudaEvent onboardEvent; mOnboardManager.getStream().record(onboardEvent); - - mBufferManager.getStream().wait(offloadEvent); mBufferManager.getStream().wait(onboardEvent); // Once we synchronize, clear our list of pending thransfers. 
- mPendingOffloads.clear(); + mPendingReads.clear(); + mPendingWrites.clear(); } } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 0949a3963f1..c6814078d21 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -229,6 +229,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx); } + void syncTransferManagerWithBufferManager() override + { + NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager); + } + void refreshBlocks() override { NB_OVERRIDE_PURE(refreshBlocks); @@ -466,6 +471,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, nb::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, nb::call_guard()) .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard()); nb::bind_vector(m, "CacheBlockIds") diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 88fbaf3a195..29a2ab09d66 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -232,6 +232,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx); } + void syncTransferManagerWithBufferManager() override + { + NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager); + } + void refreshBlocks() override { PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks); @@ -468,6 +473,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) 
py::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, py::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, nb::call_guard()) .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard()); py::enum_(m, "CacheType") diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index be9396e0d2c..96b3b61c217 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -413,6 +413,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests # allocate KV Cache + self.impl.sync_transfer_manager_with_buffer_manager() for req in context_batch: req_beam_width = req.sampling_config.beam_width if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[ @@ -440,7 +441,6 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): block_ids = self.get_cache_indices(req) self.kv_connector_manager.update_state_after_alloc( req, block_ids) - self.impl.refresh_blocks() for req in generation_batch: From 87205845ad3012d72a8dbac16d19ffe67a13c4af Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Tue, 11 Nov 2025 02:57:20 +0000 Subject: [PATCH 3/9] Bug fix Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> --- cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 49b25011e22..6336c261ee5 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -242,7 +242,7 @@ class 
PyKvCacheManager : public tbk::BaseKVCacheManager void syncTransferManagerWithBufferManager() override { - NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager); + PYBIND11_OVERLOAD_PURE(syncTransferManagerWithBufferManager); } void refreshBlocks() override From 99653bf42cf272b9298c767e2bf7c929201190bc Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Tue, 11 Nov 2025 03:01:15 +0000 Subject: [PATCH 4/9] Another fix Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> --- cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 6336c261ee5..3d11ce4ef25 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -490,7 +490,7 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) py::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, py::call_guard()) - .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, nb::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, py::call_guard()) .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard()) .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard()) From 05255ff671d1f356fd6b07f0d166b7683ad309a3 Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Tue, 11 Nov 2025 04:47:07 +0000 Subject: [PATCH 5/9] precommit run Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> --- 
.../tensorrt_llm/batch_manager/kvCacheTransferManager.h | 8 ++++++-- cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp | 2 +- .../nanobind/batch_manager/kvCacheManager.cpp | 3 ++- cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp | 3 ++- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h index 6d52e166dd2..00540dc671e 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h @@ -47,10 +47,14 @@ class KVCacheTransferManager std::string const& directory = ""); //! \brief Synchronize internal streams with bufferManager stream. - //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step. + //! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the + //! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing + //! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step. void syncWithBufferManager(); - //! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method must be called after last call to KVCacheManager::addSequence in every step. + //! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode + //! 
kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method + //! must be called after last call to KVCacheManager::addSequence in every step. void syncTransfers(); private: diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 4a18b0ebba1..28aa35c3af7 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -1339,7 +1339,7 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& return numMatchedTokens; } - + void BlockManager::syncTransferManagerWithBufferManager() { for (auto& [_, manager] : mWindowBlockManagers) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 64016676623..7a3bcae7cf1 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -486,7 +486,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) nb::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, nb::call_guard()) - .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, nb::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, + nb::call_guard()) .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard()) .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard()) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 3d11ce4ef25..a94dbe7694a 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ 
b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -490,7 +490,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) py::call_guard()) .def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents, py::call_guard()) - .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, py::call_guard()) + .def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, + py::call_guard()) .def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard()) .def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard()) .def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard()) From 1369f0fa64926f5dcf05ab88fd1601ecbec4fae2 Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Wed, 12 Nov 2025 16:05:53 +0000 Subject: [PATCH 6/9] Fix broken pybind Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> --- cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index a94dbe7694a..6ab03315e1a 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -242,7 +242,7 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager void syncTransferManagerWithBufferManager() override { - PYBIND11_OVERLOAD_PURE(syncTransferManagerWithBufferManager); + PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager); } void refreshBlocks() override From 1cfa88cd2135ba81c10b9787a82f4a6279f5b8c2 Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Wed, 12 Nov 2025 17:29:22 +0000 Subject: [PATCH 7/9] Move 
refreshBlocks call to account for addToken calls Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> --- .../batch_manager/kvCacheTransferManager.cpp | 40 +++++++++---------- .../_torch/pyexecutor/resource_manager.py | 9 ++++- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index dbc3b2f93fa..7309be94378 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -211,22 +211,22 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, // Note about recording events to wait for cudaMempyAsync calls between blocks: // The memory copy involves raw memory blocks, which are pointed to by the // memory pool block index. When recording events, you must use getMemoryPoolBlockIndex() -// as the raw memory block identifier. Earlier versions of this code used getBlockId() -// when recording events, this is just wrong. getBlockId() returns the logical block id, -// which has nothing to do with the raw memory block pointers involved in a cudaMemcpy. +// as the raw memory block identifier. Using getBlockId() when recording events is wrong. +// getBlockId() returns the logical block id, which has nothing to do with the raw memory +// block pointers involved in a cudaMemcpy. // // // Notes about need for synchronization: // -// Earlier versions of this code relied on decoder implicitly syncing GPU with CPU. -// This is inherently dangerous, it is not given that decoder will always explicitly sync -// GPU with CPU for every step, a major design goal of ongoing work is to avoid this. -// To make the code future proof, we introdduce a new method SyncWithBufferManager() -// that ensures that internal copy streams will wait for prefill and decode kernels -// that have already been scheduled. 
+// Relying on decoder syncing GPU with CPU to ensure that blocks are ready +// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder +// that may not synchronize or synchronize at a later point in the execution stream. +// To avoid synchronization issues caused by changes to decoder design we introduce +// a new method SyncWithBufferManager() that ensures that internal copy streams +// will wait for prefill and decode kernels that have already been scheduled. // -// Earlier versions of this code did not account for all possible cases were a new block copy +// Earlier versions of this code did not account for all possible cases where a new block copy // needed to wait for a previously scheduled copy to finish. For instance, it is possible // that two primary blocks are offloaded to the same secondary block in a single step, // scheduling the second offloading without waiting for the first one to finish leads to @@ -241,16 +241,16 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, // Failing to do so will lead to corrupted blocks eventually. 
// -void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block, +void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block, std::vector const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode, std::string const& directory) { - // Wait for any pending writes before reading from offloadBlock - auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex()); - if (offloadBlockPendingWriteItr != mPendingWrites.end()) + // Wait for any pending writes before reading from offloadedBlock + auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex()); + if (offloadedBlockPendingWriteItr != mPendingWrites.end()) { - mOnboardManager.getStream().wait(offloadBlockPendingWriteItr->second); - // Don't erase, we are not changing state of offloadBlock + mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second); + // Don't erase, we are not changing state of offloadedBlock } // Wait for any pending reads before overwriting block auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex()); @@ -267,11 +267,11 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr cons mPendingWrites.erase(blockPendingWriteItr); } - copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory); + copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory); - // Record new pending read from offloadBlock - mPendingReads[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); - mOnboardManager.getStream().record(mPendingReads[offloadBlock->getMemoryPoolBlockIndex()]); + // Record new pending read from offloadedBlock + mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent(); + mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]); // Record new pending write to block 
mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent(); mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]); diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index b6d007cc421..3d7f9379ad3 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -408,8 +408,11 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): with request_context(self.is_draft, scheduled_batch): context_batch = scheduled_batch.context_requests generation_batch = scheduled_batch.generation_requests - # allocate KV Cache + + # wait for all pending work to finish before launching offload/onboard/partial copy self.impl.sync_transfer_manager_with_buffer_manager() + + # allocate KV Cache for req in context_batch: req_beam_width = req.sampling_config.beam_width if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[ @@ -437,13 +440,15 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests): block_ids = self.get_cache_indices(req) self.kv_connector_manager.update_state_after_alloc( req, block_ids) - self.impl.refresh_blocks() for req in generation_batch: self.impl.add_token(req.py_request_id) for _ in range(get_draft_token_length(req)): self.impl.add_token(req.py_request_id) + # prefill and generation kernels wait for scheduled offload/onboard/partial copy work before launching + self.impl.refresh_blocks() + if self.kv_connector_manager is not None: self.kv_connector_manager.build_scheduler_output( scheduled_batch, self) From 03d4aac90be142488f2c4ebb36dc1f3d834916c2 Mon Sep 17 00:00:00 2001 From: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Wed, 12 Nov 2025 17:30:28 +0000 Subject: [PATCH 8/9] precommit run Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp | 6 +++--- 
1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index 7309be94378..c3b16af7678 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -211,8 +211,8 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, // Note about recording events to wait for cudaMempyAsync calls between blocks: // The memory copy involves raw memory blocks, which are pointed to by the // memory pool block index. When recording events, you must use getMemoryPoolBlockIndex() -// as the raw memory block identifier. Using getBlockId() when recording events is wrong. -// getBlockId() returns the logical block id, which has nothing to do with the raw memory +// as the raw memory block identifier. Using getBlockId() when recording events is wrong. +// getBlockId() returns the logical block id, which has nothing to do with the raw memory // block pointers involved in a cudaMemcpy. // @@ -223,7 +223,7 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, // for offload/onboard/partial copy is dangerous. We have an asynchronous decoder // that may not synchronize or synchronize at a later point in the execution stream. // To avoid synchronization issues caused by changes to decoder design we introduce -// a new method SyncWithBufferManager() that ensures that internal copy streams +// a new method SyncWithBufferManager() that ensures that internal copy streams // will wait for prefill and decode kernels that have already been scheduled. 
// // Earlier versions of this code did not account for all possible cases where a new block copy From 9f04777ec49b3640c204fb46bb63f8d7df1e4864 Mon Sep 17 00:00:00 2001 From: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:21:25 -0600 Subject: [PATCH 9/9] Update cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com> Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com> --- cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp index c3b16af7678..495f6a3ed34 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp @@ -222,8 +222,8 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst, // Relying on decoder syncing GPU with CPU to ensure that blocks are ready // for offload/onboard/partial copy is dangerous. We have an asynchronous decoder // that may not synchronize or synchronize at a later point in the execution stream. -// To avoid synchronization issues caused by changes to decoder design we introduce -// a new method SyncWithBufferManager() that ensures that internal copy streams +// To avoid synchronization issues caused by changes to decoder design we rely on +// KVCacheTransferManager::syncWithBufferManager() that ensures that internal copy streams // will wait for prefill and decode kernels that have already been scheduled. // // Earlier versions of this code did not account for all possible cases where a new block copy