Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1d484eb
Add refresh blocks
Tabrizian Nov 3, 2025
307479b
Fix transfer manager synchronization issues
thorjohnsen Nov 6, 2025
c4dc529
Fix merge issues
thorjohnsen Nov 11, 2025
8720584
Bug fix
thorjohnsen Nov 11, 2025
99653bf
Another fix
thorjohnsen Nov 11, 2025
ffb90a4
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 11, 2025
b471cdf
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 11, 2025
05255ff
precommit run
thorjohnsen Nov 11, 2025
9ac9534
Merge branch 'user/tjohnsen/fix_5627710' of github.com:thorjohnsen/Te…
thorjohnsen Nov 11, 2025
9aef49c
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 11, 2025
1369f0f
Fix broken pybind
thorjohnsen Nov 12, 2025
559beb0
Merge branch 'user/tjohnsen/fix_5627710' of github.com:thorjohnsen/Te…
thorjohnsen Nov 12, 2025
1cfa88c
Move refreshBlocks call to account for addToken calls
thorjohnsen Nov 12, 2025
03d4aac
precommit run
thorjohnsen Nov 12, 2025
262d34c
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 13, 2025
f4ae208
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 13, 2025
9f04777
Update cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
thorjohnsen Nov 13, 2025
ba32c3b
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 13, 2025
748cbaf
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 14, 2025
279b038
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Nov 17, 2025
c6ba7b6
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 17, 2025
b34ea28
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Nov 20, 2025
c248977
Merge branch 'user/tjohnsen/fix_5627710' of github.com:thorjohnsen/Te…
thorjohnsen Nov 20, 2025
55a274c
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 20, 2025
0505ef3
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Nov 20, 2025
aab02d2
Merge remote-tracking branch 'upstream/main' into user/tjohnsen/fix_5…
thorjohnsen Nov 21, 2025
5867991
Merge branch 'main' into user/tjohnsen/fix_5627710
thorjohnsen Dec 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,9 @@ class WindowBlockManager
return mBufferManager;
}

//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();

//! \brief Perform per-request bookkeeping
void refreshBlocks();

Expand Down Expand Up @@ -1283,6 +1286,9 @@ class BlockManager
//! \brief Store newest block for reuse
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);

//! \brief Sync internal streams used by transfer manager with buffer manager stream
void syncTransferManagerWithBufferManager();

//! \brief Perform per-request bookkeeping
void refreshBlocks();

Expand Down Expand Up @@ -1546,6 +1552,7 @@ class BaseKVCacheManager
[[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;

virtual void syncTransferManagerWithBufferManager() = 0;
virtual void refreshBlocks() = 0;
virtual void flushIterationEvents() = 0;
virtual void resetReuseState() = 0;
Expand Down Expand Up @@ -1912,6 +1919,11 @@ class KVCacheManager : public BaseKVCacheManager
return mBlockManager.getPoolLayerIdx(layer_idx);
}

void syncTransferManagerWithBufferManager() override
{
mBlockManager.syncTransferManagerWithBufferManager();
}

//! \brief Perform per-iteration bookkeeping
void refreshBlocks() override
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ class KVCacheTransferManager
int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
std::string const& directory = "");

//! \brief Synchronize the offload/onboard streams with the bufferManager stream.
//! \brief Synchronize internal streams with bufferManager stream.
//! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step.
void syncWithBufferManager();

//! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method must be called after last call to KVCacheManager::addSequence in every step.
void syncTransfers();

private:
Expand Down Expand Up @@ -75,8 +79,10 @@ class KVCacheTransferManager
runtime::BufferManager mOnboardManager;
runtime::BufferManager mOffloadManager;

// Track the block ids offloaded in this iteration.
std::unordered_map<int32_t, tr::CudaEvent> mPendingOffloads;
// Track reads and writes for blocks. Note that it is the memory pool index that
// identifies the raw memory blocks involved in I/O, not the block Id.
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
// Reference to parent loopback agent
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
int mDeviceId;
Expand Down
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(allocateKvCache);

kvCacheManager.syncTransferManagerWithBufferManager();

for (auto const& llmReq : contextRequests)
{
if (llmReq->isFirstContextChunk())
Expand Down
13 changes: 13 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&

return numMatchedTokens;
}

void BlockManager::syncTransferManagerWithBufferManager()
{
    // Propagate the stream synchronization to every per-window block manager,
    // each of which owns its own transfer manager.
    for (auto& windowEntry : mWindowBlockManagers)
    {
        windowEntry.second.syncTransferManagerWithBufferManager();
    }
}

//! \brief Make the transfer manager's internal offload/onboard streams wait on the
//! buffer manager stream (thin delegation to KVCacheTransferManager::syncWithBufferManager).
void WindowBlockManager::syncTransferManagerWithBufferManager()
{
mTransferManager->syncWithBufferManager();
}

void BlockManager::refreshBlocks()
{
Expand Down
119 changes: 106 additions & 13 deletions cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
}
}

//
// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
// The memory copy involves raw memory blocks, which are pointed to by the
// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
// as the raw memory block identifier. Earlier versions of this code used getBlockId()
// when recording events, this is just wrong. getBlockId() returns the logical block id,
// which has nothing to do with the raw memory block pointers involved in a cudaMemcpy.
//

//
// Notes about need for synchronization:
//
// Earlier versions of this code relied on decoder implicitly syncing GPU with CPU.
// This is inherently dangerous, it is not given that decoder will always explicitly sync
// GPU with CPU for every step, a major design goal of ongoing work is to avoid this.
// To make the code future proof, we introduce a new method syncWithBufferManager()
// that ensures that internal copy streams will wait for prefill and decode kernels
// that have already been scheduled.
//
// Earlier versions of this code did not account for all possible cases where a new block copy
// needed to wait for a previously scheduled copy to finish. For instance, it is possible
// that two primary blocks are offloaded to the same secondary block in a single step,
// scheduling the second offloading without waiting for the first one to finish leads to
// a corrupted block after offloading. It is possible that partial reuse will copy
// from a block that is currently being onboarded, scheduling the partial copy without
// waiting for the onboarding to finish will lead to a corrupted block. To handle all
// possible cases needing synchronization we record separate events for reads and writes
// to a block. When a new block copy is scheduled, we wait for all writes to the source
// block and all reads and writes to a destination block.
//
// As before, syncTransfers() must be called after last call to KVCacheManager::addSequence.
// Failing to do so will lead to corrupted blocks eventually.
//

//! \brief Copy an offloaded (secondary) block back into a primary block.
//! \param offloadBlock Source block being read (its contents are not modified).
//! \param block Destination primary block being overwritten.
//! \param pools Memory pools describing the raw storage for both blocks.
//! \param numTokensToCopy Number of tokens to copy (0 => whole block).
//! \param mode Transfer mode (DRAM, disk, ...).
//! \param directory Directory used for file-backed transfer modes.
//!
//! Events are keyed by getMemoryPoolBlockIndex(), which identifies the raw
//! memory involved in the copy (getBlockId() is a logical id and must NOT be
//! used here — see the note above this function group).
void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block,
    std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
    std::string const& directory)
{
    auto const srcIdx = offloadBlock->getMemoryPoolBlockIndex();
    auto const dstIdx = block->getMemoryPoolBlockIndex();

    // Wait for any pending write before reading from offloadBlock.
    auto offloadBlockPendingWriteItr = mPendingWrites.find(srcIdx);
    if (offloadBlockPendingWriteItr != mPendingWrites.end())
    {
        mOnboardManager.getStream().wait(offloadBlockPendingWriteItr->second);
        // Don't erase: we are not changing the state of offloadBlock.
    }
    // Wait for any pending read before overwriting block.
    auto blockPendingReadItr = mPendingReads.find(dstIdx);
    if (blockPendingReadItr != mPendingReads.end())
    {
        mOnboardManager.getStream().wait(blockPendingReadItr->second);
        mPendingReads.erase(blockPendingReadItr);
    }
    // Wait for any pending write before overwriting block.
    auto blockPendingWriteItr = mPendingWrites.find(dstIdx);
    if (blockPendingWriteItr != mPendingWrites.end())
    {
        mOnboardManager.getStream().wait(blockPendingWriteItr->second);
        mPendingWrites.erase(blockPendingWriteItr);
    }

    copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);

    // Record the new pending read from offloadBlock.
    mPendingReads[srcIdx] = tr::CudaEvent();
    mOnboardManager.getStream().record(mPendingReads[srcIdx]);
    // Record the new pending write to block.
    mPendingWrites[dstIdx] = tr::CudaEvent();
    mOnboardManager.getStream().record(mPendingWrites[dstIdx]);
}

//! \brief Copy a primary block out to an offload (secondary) block.
//! \param block Source primary block being read (its contents are not modified).
//! \param offloadBlock Destination block being overwritten.
//! \param pools Memory pools describing the raw storage for both blocks.
//! \param numTokensToCopy Number of tokens to copy (0 => whole block).
//! \param mode Transfer mode (DRAM, disk, ...).
//! \param directory Directory used for file-backed transfer modes.
//!
//! Events are keyed by getMemoryPoolBlockIndex(), which identifies the raw
//! memory involved in the copy (never getBlockId() — see note above).
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
    std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
    std::string const& directory)
{
    auto const srcIdx = block->getMemoryPoolBlockIndex();
    auto const dstIdx = offloadBlock->getMemoryPoolBlockIndex();

    // Wait for any pending write before reading from block.
    auto blockPendingWriteItr = mPendingWrites.find(srcIdx);
    if (blockPendingWriteItr != mPendingWrites.end())
    {
        mOffloadManager.getStream().wait(blockPendingWriteItr->second);
        // Don't erase: we are not changing the state of block.
    }
    // Wait for any pending read before overwriting offloadBlock.
    auto offloadBlockPendingReadItr = mPendingReads.find(dstIdx);
    if (offloadBlockPendingReadItr != mPendingReads.end())
    {
        mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
        mPendingReads.erase(offloadBlockPendingReadItr);
    }
    // Wait for any pending write before overwriting offloadBlock.
    auto offloadBlockPendingWriteItr = mPendingWrites.find(dstIdx);
    if (offloadBlockPendingWriteItr != mPendingWrites.end())
    {
        mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
        mPendingWrites.erase(offloadBlockPendingWriteItr);
    }

    copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);

    // Record the new pending read from block.
    mPendingReads[srcIdx] = tr::CudaEvent();
    mOffloadManager.getStream().record(mPendingReads[srcIdx]);
    // Record the new pending write to offloadBlock.
    mPendingWrites[dstIdx] = tr::CudaEvent();
    mOffloadManager.getStream().record(mPendingWrites[dstIdx]);
}

void KVCacheTransferManager::syncWithBufferManager()
{
tr::CudaEvent readyForOffloadEvent;
mBufferManager.getStream().record(readyForOffloadEvent);
mOffloadManager.getStream().wait(readyForOffloadEvent);

tr::CudaEvent readyForOnboardEvent;
mBufferManager.getStream().record(readyForOnboardEvent);
mOnboardManager.getStream().wait(readyForOnboardEvent);

// Once we synchronize, clear our list of pending thransfers.
mPendingReads.clear();
mPendingWrites.clear();
}

//! \brief Make the buffer manager stream wait for all offload/onboard copies
//! already scheduled, so next-step prefill/decode kernels see consistent blocks.
//! Must be called after the last KVCacheManager::addSequence of every step.
void KVCacheTransferManager::syncTransfers()
{
    tr::CudaEvent offloadEvent;
    mOffloadManager.getStream().record(offloadEvent);

    tr::CudaEvent onboardEvent;
    mOnboardManager.getStream().record(onboardEvent);

    // Wait once per internal stream (the duplicated wait on offloadEvent and the
    // clear of the removed mPendingOffloads map were leftovers of the old design).
    mBufferManager.getStream().wait(offloadEvent);
    mBufferManager.getStream().wait(onboardEvent);

    // Once we synchronize, clear our list of pending transfers.
    mPendingReads.clear();
    mPendingWrites.clear();
}

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
7 changes: 7 additions & 0 deletions cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
}

// Trampoline: forward the pure-virtual call to the Python subclass implementation
// (mirrors the sibling NB_OVERRIDE_PURE(refreshBlocks) override below).
void syncTransferManagerWithBufferManager() override
{
NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
}

void refreshBlocks() override
{
NB_OVERRIDE_PURE(refreshBlocks);
Expand Down Expand Up @@ -481,6 +486,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
nb::call_guard<nb::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
nb::call_guard<nb::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, nb::call_guard<nb::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());
Expand Down
7 changes: 7 additions & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
}

void syncTransferManagerWithBufferManager() override
{
PYBIND11_OVERLOAD_PURE(syncTransferManagerWithBufferManager);
}

void refreshBlocks() override
{
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
Expand Down Expand Up @@ -485,6 +490,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
py::call_guard<py::gil_scoped_release>())
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
py::call_guard<py::gil_scoped_release>())
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager, nb::call_guard<nb::gil_scoped_release>())
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());
Expand Down
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
context_batch = scheduled_batch.context_requests
generation_batch = scheduled_batch.generation_requests
# allocate KV Cache
self.impl.sync_transfer_manager_with_buffer_manager()
for req in context_batch:
req_beam_width = req.sampling_config.beam_width
if 'cp_type' in self.mapping.cp_config and CpType.STAR == self.mapping.cp_config[
Expand Down Expand Up @@ -436,6 +437,7 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
block_ids = self.get_cache_indices(req)
self.kv_connector_manager.update_state_after_alloc(
req, block_ids)
self.impl.refresh_blocks()

for req in generation_batch:
self.impl.add_token(req.py_request_id)
Expand Down
Loading