Skip to content

Commit 524ad7b

Browse files
committed
[#8813][fix] Add missing event for block onboard for the kv cache transfer manager
Authored-by: @josephrocca Co-authored-by: eopXD <[email protected]> Signed-off-by: eopXD <[email protected]>
1 parent 497a070 commit 524ad7b

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,26 @@ class KVCacheBlock
325325

326326
size_t getHash() const;
327327

328+
//! \brief Store the CUDA event that tracks completion of this block's asynchronous onboard transfer, replacing any event already stored.
329+
//! \param event CUDA event to associate with this block; ownership is moved into the block, leaving the caller's object moved-from
330+
void setPendingOnboardEvent(runtime::CudaEvent&& event)
331+
{
332+
mPendingOnboardEvent = std::move(event);
333+
}
334+
335+
//! \brief Get the pending onboard event if one exists, without transferring ownership.
336+
//! \return Non-owning pointer to the event stored inside this block, or nullptr if no event is pending; do not use it after clearPendingOnboardEvent()
337+
runtime::CudaEvent const* getPendingOnboardEvent() const
338+
{
339+
return mPendingOnboardEvent ? &mPendingOnboardEvent.value() : nullptr;
340+
}
341+
342+
//! \brief Clear the pending onboard event, destroying the stored event if there is one; does nothing when no event is pending.
343+
void clearPendingOnboardEvent()
344+
{
345+
mPendingOnboardEvent.reset();
346+
}
347+
328348
private:
329349
// Linear ID of block independent of pool
330350
IdType mBlockId;
@@ -365,6 +385,8 @@ class KVCacheBlock
365385
std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
366386
// Hash for the event manager
367387
size_t mHash;
388+
// Event marking completion of this block's onboard copy; empty when no onboard is pending
389+
std::optional<runtime::CudaEvent> mPendingOnboardEvent;
368390
};
369391

370392
class GenerationRequest

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,13 @@ void WindowBlockManager::addSequence(
14831483

14841484
void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx)
14851485
{
1486+
if (auto const* onboardEvent = block->getPendingOnboardEvent())
1487+
{
1488+
// Make sure block is onboarded before used
1489+
mBufferManager.getStream().wait(*onboardEvent);
1490+
block->clearPendingOnboardEvent();
1491+
}
1492+
14861493
auto const requestId = sequence.getRequestId();
14871494
block->incRefCount();
14881495
if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0)

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr cons
224224
mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]);
225225
}
226226
copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);
227+
tr::CudaEvent onboardEvent;
228+
mOnboardManager.getStream().record(onboardEvent);
229+
block->setPendingOnboardEvent(std::move(onboardEvent));
227230
}
228231

229232
void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,

0 commit comments

Comments
 (0)