Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,26 @@ class KVCacheBlock

size_t getHash() const;

//! \brief Attach an onboard event to this block so asynchronous transfer completion can be tracked.
//! \param event CUDA event to store; ownership is moved into the block, replacing any event already held.
void setPendingOnboardEvent(runtime::CudaEvent&& event)
{
    // Construct the stored event in place from the caller's rvalue.
    mPendingOnboardEvent.emplace(std::move(event));
}

//! \brief Access the pending onboard event, if the block currently holds one.
//! \return Non-owning pointer to the stored event, or nullptr when no event is stored.
runtime::CudaEvent const* getPendingOnboardEvent() const
{
    if (mPendingOnboardEvent.has_value())
    {
        // The pointer refers to the event stored inside this block.
        return &(*mPendingOnboardEvent);
    }
    return nullptr;
}

//! \brief Drop the pending onboard event, leaving the block with no event stored.
void clearPendingOnboardEvent()
{
    mPendingOnboardEvent = std::nullopt;
}

private:
// Linear ID of block independent of pool
IdType mBlockId;
Expand Down Expand Up @@ -365,6 +385,8 @@ class KVCacheBlock
std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
// Hash for the event manager
size_t mHash;
// Event recorded after this block's onboard copy was enqueued; empty when no onboard is pending
std::optional<runtime::CudaEvent> mPendingOnboardEvent;
};

class GenerationRequest
Expand Down
7 changes: 7 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1483,6 +1483,13 @@ void WindowBlockManager::addSequence(

void WindowBlockManager::addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx)
{
if (auto const* onboardEvent = block->getPendingOnboardEvent())
{
// Make sure block is onboarded before used
mBufferManager.getStream().wait(*onboardEvent);
block->clearPendingOnboardEvent();
}

auto const requestId = sequence.getRequestId();
block->incRefCount();
if (sequence.getCacheBlockIds(mWindowSize).at(beamIdx).size() == 0)
Expand Down
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ void KVCacheTransferManager::onboard(BlockPtr const& offloadBlock, BlockPtr cons
mOnboardManager.getStream().wait(mPendingOffloads[offloadBlock->getBlockId()]);
}
copyBlock(offloadBlock, block, pools, false, numTokensToCopy, mode, directory);
tr::CudaEvent onboardEvent;
mOnboardManager.getStream().record(onboardEvent);
block->setPendingOnboardEvent(std::move(onboardEvent));
}

void KVCacheTransferManager::offload(BlockPtr const& block, BlockPtr const& offloadBlock,
Expand Down