Skip to content

Commit 95049ee

Browse files
thorjohnsen, Tabrizian, Funatiq
authored
[https://nvbugs/5627710][fix] Fix synchronization bugs in KvCacheTransferManager that can cause corrupted blocks (#9056)
Signed-off-by: thorjohnsen <41591019+thorjohnsen@users.noreply.github.com> Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com> Co-authored-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com> Co-authored-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent b86256e commit 95049ee

File tree

8 files changed

+171
-18
lines changed

8 files changed

+171
-18
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,9 @@ class WindowBlockManager
824824
return mBufferManager;
825825
}
826826

827+
//! \brief Sync internal streams used by transfer manager with buffer manager stream
828+
void syncTransferManagerWithBufferManager();
829+
827830
//! \brief Perform per-request bookkeeping
828831
void refreshBlocks();
829832

@@ -1313,6 +1316,9 @@ class BlockManager
13131316
//! \brief Store newest block for reuse
13141317
void storeNewBlock(GenerationRequest& sequence, OptionalRef<LlmRequest const> llmRequest);
13151318

1319+
//! \brief Sync internal streams used by transfer manager with buffer manager stream
1320+
void syncTransferManagerWithBufferManager();
1321+
13161322
//! \brief Perform per-request bookkeeping
13171323
void refreshBlocks();
13181324

@@ -1584,6 +1590,7 @@ class BaseKVCacheManager
15841590
[[nodiscard]] virtual runtime::ITensor::SharedPtr getIndexerKCachePool() const = 0;
15851591
[[nodiscard]] virtual SizeType32 getPoolLayerIdx(SizeType32 layer_idx) const = 0;
15861592

1593+
virtual void syncTransferManagerWithBufferManager() = 0;
15871594
virtual void refreshBlocks() = 0;
15881595
virtual void flushIterationEvents() = 0;
15891596
virtual void resetReuseState() = 0;
@@ -1965,6 +1972,11 @@ class KVCacheManager : public BaseKVCacheManager
19651972
return mBlockManager.getPoolLayerIdx(layer_idx);
19661973
}
19671974

1975+
void syncTransferManagerWithBufferManager() override
1976+
{
1977+
mBlockManager.syncTransferManagerWithBufferManager();
1978+
}
1979+
19681980
//! \brief Perform per-iteration bookkeeping
19691981
void refreshBlocks() override
19701982
{

cpp/include/tensorrt_llm/batch_manager/kvCacheTransferManager.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,15 @@ class KVCacheTransferManager
4646
int numTokensToCopy = 0, executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM,
4747
std::string const& directory = "");
4848

49-
//! \brief Synchronize the offload/onboard streams with the bufferManager stream.
49+
//! \brief Synchronize internal streams with bufferManager stream.
50+
//! \details The buffer manager uses the same stream as the prefill and decode kernels. This method ensures that the
51+
//! internal kernels used for offloading and onboarding will wait for prefill and decode kernels before performing
52+
//! any block copies. This method must be called before the first call to KVCacheManager::addSequence in every step.
53+
void syncWithBufferManager();
54+
55+
//! \brief Synchronize bufferManager stream with internal streams. This method ensures that prefill and decode
56+
//! kernels for next step will wait for offloading and onboarding work that has already been scheduled. This method
57+
//! must be called after the last call to KVCacheManager::addSequence in every step.
5058
void syncTransfers();
5159

5260
private:
@@ -75,8 +83,10 @@ class KVCacheTransferManager
7583
runtime::BufferManager mOnboardManager;
7684
runtime::BufferManager mOffloadManager;
7785

78-
// Track the block ids offloaded in this iteration.
79-
std::unordered_map<int32_t, tr::CudaEvent> mPendingOffloads;
86+
// Track reads and writes for blocks. Note that it is the memory pool index that
87+
// identifies the raw memory blocks involved in I/O, not the block Id.
88+
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingReads;
89+
std::unordered_map<kernels::KVCacheIndex::UnderlyingType, tr::CudaEvent> mPendingWrites;
8090
// Reference to parent loopback agent
8191
std::shared_ptr<kvc::BaseLoopbackAgent> mLoopbackAgent;
8292
int mDeviceId;

cpp/tensorrt_llm/batch_manager/allocateKvCache.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ void tensorrt_llm::batch_manager::AllocateKvCache::operator()(BaseKVCacheManager
2626
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
2727
NVTX3_SCOPED_RANGE(allocateKvCache);
2828

29+
kvCacheManager.syncTransferManagerWithBufferManager();
30+
2931
for (auto const& llmReq : contextRequests)
3032
{
3133
if (llmReq->isFirstContextChunk())

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,19 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
13361336
return numMatchedTokens;
13371337
}
13381338

1339+
void BlockManager::syncTransferManagerWithBufferManager()
1340+
{
1341+
for (auto& [_, manager] : mWindowBlockManagers)
1342+
{
1343+
manager.syncTransferManagerWithBufferManager();
1344+
}
1345+
}
1346+
1347+
void WindowBlockManager::syncTransferManagerWithBufferManager()
1348+
{
1349+
mTransferManager->syncWithBufferManager();
1350+
}
1351+
13391352
void BlockManager::refreshBlocks()
13401353
{
13411354
for (auto& [_, manager] : mWindowBlockManagers)

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 108 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -207,47 +207,140 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
207207
}
208208
}
209209

210-
void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr const& block,
210+
//
211+
// Note about recording events to wait for cudaMemcpyAsync calls between blocks:
212+
// The memory copy involves raw memory blocks, which are pointed to by the
213+
// memory pool block index. When recording events, you must use getMemoryPoolBlockIndex()
214+
// as the raw memory block identifier. Using getBlockId() when recording events is wrong.
215+
// getBlockId() returns the logical block id, which has nothing to do with the raw memory
216+
// block pointers involved in a cudaMemcpy.
217+
//
218+
219+
//
220+
// Notes about need for synchronization:
221+
//
222+
// Relying on decoder syncing GPU with CPU to ensure that blocks are ready
223+
// for offload/onboard/partial copy is dangerous. We have an asynchronous decoder
224+
// that may not synchronize or synchronize at a later point in the execution stream.
225+
// To avoid synchronization issues caused by changes to decoder design we rely on
226+
// KVCacheTransferManager::syncWithBufferManager() that ensures that internal copy streams
227+
// will wait for prefill and decode kernels that have already been scheduled.
228+
//
229+
// Earlier versions of this code did not account for all possible cases where a new block copy
230+
// needed to wait for a previously scheduled copy to finish. For instance, it is possible
231+
// that two primary blocks are offloaded to the same secondary block in a single step,
232+
// scheduling the second offloading without waiting for the first one to finish leads to
233+
// a corrupted block after offloading. It is possible that partial reuse will copy
234+
// from a block that is currently being onboarded, scheduling the partial copy without
235+
// waiting for the onboarding to finish will lead to a corrupted block. To handle all
236+
// possible cases needing synchronization we record separate events for reads and writes
237+
// to a block. When a new block copy is scheduled, we wait for all writes to the source
238+
// block and all reads and writes to a destination block.
239+
//
240+
// As before, syncTransfers() must be called after the last call to KVCacheManager::addSequence.
241+
// Failing to do so will lead to corrupted blocks eventually.
242+
//
243+
244+
void KVCacheTransferManager::onboard(BlockPtr const& offloadedBlock, BlockPtr const& block,
211245
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
212246
std::string const& directory)
213247
{
214-
if (mode != executor::KvCacheTransferMode::DRAM
215-
&& mPendingOffloads.find(offloadBlock->getBlockId()) == mPendingOffloads.end())
248+
// Wait for any pending writes before reading from offloadedBlock
249+
auto offloadedBlockPendingWriteItr = mPendingWrites.find(offloadedBlock->getMemoryPoolBlockIndex());
250+
if (offloadedBlockPendingWriteItr != mPendingWrites.end())
216251
{
217-
TLLM_LOG_DEBUG("Skipping onboard for block %d because it was never previously offloaded to disk",
218-
offloadBlock->getBlockId());
219-
return;
252+
mOnboardManager.getStream().wait(offloadedBlockPendingWriteItr->second);
253+
// Don't erase, we are not changing state of offloadedBlock
220254
}
221-
222-
if (mPendingOffloads.find(offloadBlock->getBlockId()) != mPendingOffloads.end())
255+
// Wait for any pending reads before overwriting block
256+
auto blockPendingReadItr = mPendingReads.find(block->getMemoryPoolBlockIndex());
257+
if (blockPendingReadItr != mPendingReads.end())
258+
{
259+
mOnboardManager.getStream().wait(blockPendingReadItr->second);
260+
mPendingReads.erase(blockPendingReadItr);
261+
}
262+
// Wait for any pending writes before overwriting block
263+
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
264+
if (blockPendingWriteItr != mPendingWrites.end())
223265
{
224-
mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]);
266+
mOnboardManager.getStream().wait(blockPendingWriteItr->second);
267+
mPendingWrites.erase(blockPendingWriteItr);
225268
}
226-
copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);
269+
270+
copyBlock(offloadedBlock, block, pools, false, numTokensToCopy, mode, directory);
271+
272+
// Record new pending read from offloadedBlock
273+
mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
274+
mOnboardManager.getStream().record(mPendingReads[offloadedBlock->getMemoryPoolBlockIndex()]);
275+
// Record new pending write to block
276+
mPendingWrites[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
277+
mOnboardManager.getStream().record(mPendingWrites[block->getMemoryPoolBlockIndex()]);
227278
}
228279

229280
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
230281
std::vector<KVCacheBlockPool> const& pools, int numTokensToCopy, executor::KvCacheTransferMode mode,
231282
std::string const& directory)
232283
{
233-
mPendingOffloads[block->getBlockId()] = tr::CudaEvent();
284+
// Wait for any pending writes before reading from block
285+
auto blockPendingWriteItr = mPendingWrites.find(block->getMemoryPoolBlockIndex());
286+
if (blockPendingWriteItr != mPendingWrites.end())
287+
{
288+
mOffloadManager.getStream().wait(blockPendingWriteItr->second);
289+
// Don't erase, we are not changing state of block
290+
}
291+
// Wait for any pending reads before overwriting offloadBlock
292+
auto offloadBlockPendingReadItr = mPendingReads.find(offloadBlock->getMemoryPoolBlockIndex());
293+
if (offloadBlockPendingReadItr != mPendingReads.end())
294+
{
295+
mOffloadManager.getStream().wait(offloadBlockPendingReadItr->second);
296+
mPendingReads.erase(offloadBlockPendingReadItr);
297+
}
298+
// Wait for any pending writes before overwriting offloadBlock
299+
auto offloadBlockPendingWriteItr = mPendingWrites.find(offloadBlock->getMemoryPoolBlockIndex());
300+
if (offloadBlockPendingWriteItr != mPendingWrites.end())
301+
{
302+
mOffloadManager.getStream().wait(offloadBlockPendingWriteItr->second);
303+
mPendingWrites.erase(offloadBlockPendingWriteItr);
304+
}
305+
234306
copyBlock(block, offloadBlock, pools, true, numTokensToCopy, mode, directory);
235-
mOffloadManager.getStream().record(mPendingOffloads[block->getBlockId()]);
307+
308+
// Record new pending read from block
309+
mPendingReads[block->getMemoryPoolBlockIndex()] = tr::CudaEvent();
310+
mOffloadManager.getStream().record(mPendingReads[block->getMemoryPoolBlockIndex()]);
311+
// Record new pending write to offloadBlock
312+
mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()] = tr::CudaEvent();
313+
mOffloadManager.getStream().record(mPendingWrites[offloadBlock->getMemoryPoolBlockIndex()]);
314+
}
315+
316+
void KVCacheTransferManager::syncWithBufferManager()
317+
{
318+
tr::CudaEvent readyForOffloadEvent;
319+
mBufferManager.getStream().record(readyForOffloadEvent);
320+
mOffloadManager.getStream().wait(readyForOffloadEvent);
321+
322+
tr::CudaEvent readyForOnboardEvent;
323+
mBufferManager.getStream().record(readyForOnboardEvent);
324+
mOnboardManager.getStream().wait(readyForOnboardEvent);
325+
326+
// Once we synchronize, clear our list of pending transfers.
327+
mPendingReads.clear();
328+
mPendingWrites.clear();
236329
}
237330

238331
void KVCacheTransferManager::syncTransfers()
239332
{
240333
tr::CudaEvent offloadEvent;
241334
mOffloadManager.getStream().record(offloadEvent);
335+
mBufferManager.getStream().wait(offloadEvent);
242336

243337
tr::CudaEvent onboardEvent;
244338
mOnboardManager.getStream().record(onboardEvent);
245-
246-
mBufferManager.getStream().wait(offloadEvent);
247339
mBufferManager.getStream().wait(onboardEvent);
248340

249341
// Once we synchronize, clear our list of pending transfers.
250-
mPendingOffloads.clear();
342+
mPendingReads.clear();
343+
mPendingWrites.clear();
251344
}
252345

253346
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
235235
NB_OVERRIDE_PURE(getPoolLayerIdx, layer_idx);
236236
}
237237

238+
void syncTransferManagerWithBufferManager() override
239+
{
240+
NB_OVERRIDE_PURE(syncTransferManagerWithBufferManager);
241+
}
242+
238243
void refreshBlocks() override
239244
{
240245
NB_OVERRIDE_PURE(refreshBlocks);
@@ -481,6 +486,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m)
481486
nb::call_guard<nb::gil_scoped_release>())
482487
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
483488
nb::call_guard<nb::gil_scoped_release>())
489+
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
490+
nb::call_guard<nb::gil_scoped_release>())
491+
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, nb::call_guard<nb::gil_scoped_release>())
484492
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, nb::call_guard<nb::gil_scoped_release>())
485493
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, nb::call_guard<nb::gil_scoped_release>())
486494
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, nb::call_guard<nb::gil_scoped_release>());

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,11 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
240240
PYBIND11_OVERLOAD_PURE(SizeType32, tbk::BaseKVCacheManager, getPoolLayerIdx, layer_idx);
241241
}
242242

243+
void syncTransferManagerWithBufferManager() override
244+
{
245+
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, syncTransferManagerWithBufferManager);
246+
}
247+
243248
void refreshBlocks() override
244249
{
245250
PYBIND11_OVERLOAD_PURE(void, tbk::BaseKVCacheManager, refreshBlocks);
@@ -485,6 +490,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
485490
py::call_guard<py::gil_scoped_release>())
486491
.def("flush_iteration_events", &BaseKVCacheManager::flushIterationEvents,
487492
py::call_guard<py::gil_scoped_release>())
493+
.def("sync_transfer_manager_with_buffer_manager", &BaseKVCacheManager::syncTransferManagerWithBufferManager,
494+
py::call_guard<py::gil_scoped_release>())
495+
.def("refresh_blocks", &BaseKVCacheManager::refreshBlocks, py::call_guard<py::gil_scoped_release>())
488496
.def("get_last_block_id", &BaseKVCacheManager::getLastBlockId, py::call_guard<py::gil_scoped_release>())
489497
.def("unpin_blocks_by_id", &BaseKVCacheManager::unpinBlocksById, py::call_guard<py::gil_scoped_release>())
490498
.def("reset_reuse_state", &BaseKVCacheManager::resetReuseState, py::call_guard<py::gil_scoped_release>());

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,10 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
434434
with request_context(self.is_draft, scheduled_batch):
435435
context_batch = scheduled_batch.context_requests
436436
generation_batch = scheduled_batch.generation_requests
437+
438+
# wait for all pending work to finish before launching offload/onboarding/partial copy
439+
self.impl.sync_transfer_manager_with_buffer_manager()
440+
437441
# allocate KV Cache
438442
for req in context_batch:
439443
req_beam_width = req.sampling_config.beam_width
@@ -475,6 +479,9 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
475479
for _ in range(get_draft_token_length(req)):
476480
self.impl.add_token(req.py_request_id)
477481

482+
# prefill and generation kernels wait for scheduled offload/onboard/partial copy work before launching
483+
self.impl.refresh_blocks()
484+
478485
if self.kv_connector_manager is not None:
479486
self.kv_connector_manager.build_scheduler_output(
480487
scheduled_batch, self)

0 commit comments

Comments
 (0)