Skip to content

Commit 1b588f8

Browse files
authored
feat: KV events for sliding window attention (NVIDIA#5580)
Signed-off-by: jthomson04 <[email protected]>
1 parent d61893d commit 1b588f8

File tree

7 files changed: 90 additions & 23 deletions

cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,13 +44,13 @@ class KVCacheEventManager
4444
KVCacheEventManager(KVCacheEventManager&& other) = delete;
4545
KVCacheEventManager& operator=(KVCacheEventManager&& other) = delete;
4646

47-
void enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel);
47+
void enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize);
4848

49-
void enqueueStoredEvent(std::vector<BlockPtr> const& blocks);
49+
void enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize);
5050

51-
void enqueueRemovedEvent(BlockPtr const& block);
51+
void enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize);
5252

53-
void enqueueUpdatedEvent(executor::KVCacheUpdatedData const& data);
53+
void enqueueUpdatedEvent(executor::KVCacheUpdatedData const& data, SizeType32 windowSize);
5454

5555
// Get events in mEvents. If there are no events, wait for a maximum of `timeout` milliseconds.
5656
std::deque<executor::KVCacheEvent> getEvents(std::optional<std::chrono::milliseconds> timeout);

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1709,12 +1709,14 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC
17091709
struct KVCacheEvent
17101710
{
17111711

1712-
KVCacheEvent(IdType eventId, KVCacheEventData data);
1712+
KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize);
17131713

17141714
/// @brief The unique id of this event
17151715
IdType eventId;
17161716
/// @brief The data corresponding to this event
17171717
KVCacheEventData data;
1718+
/// @brief The sliding window size
1719+
SizeType32 windowSize;
17181720
};
17191721

17201722
/// @brief Exposes a limited set of KV cache manager functionalities

cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -42,12 +42,13 @@ KVCacheEventManager::~KVCacheEventManager()
4242
mWorkerThread.join();
4343
}
4444

45-
void KVCacheEventManager::enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel)
45+
void KVCacheEventManager::enqueueCreatedEvent(
46+
std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize)
4647
{
47-
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}});
48+
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize});
4849
}
4950

50-
void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks)
51+
void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize)
5152
{
5253
if (blocks.empty())
5354
{
@@ -67,24 +68,26 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
6768
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
6869
}
6970

70-
enqueueEvent({mEventId++, data});
71+
enqueueEvent({mEventId++, data, windowSize});
7172
}
7273

73-
void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block)
74+
void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize)
7475
{
75-
if (!mEventQueue.empty() && std::holds_alternative<tle::KVCacheRemovedData>(mEventQueue.back().data))
76+
// We can only batch the removed block events if the same sliding window size is used.
77+
if (!mEventQueue.empty() && mEventQueue.back().windowSize == windowSize
78+
&& std::holds_alternative<tle::KVCacheRemovedData>(mEventQueue.back().data))
7679
{
7780
std::get<tle::KVCacheRemovedData>(mEventQueue.back().data).blockHashes.push_back(block->getHash());
7881
}
7982
else
8083
{
81-
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}});
84+
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
8285
}
8386
}
8487

85-
void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data)
88+
void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize)
8689
{
87-
enqueueEvent({mEventId++, data});
90+
enqueueEvent({mEventId++, data, windowSize});
8891
}
8992

9093
void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event)

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -552,7 +552,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind
552552
mAllBlocksById, {blocksInPrimaryPool, blocksInSecondaryPool}, secondaryOffloadMinPriority);
553553
if (mEventManager)
554554
{
555-
mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool});
555+
mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool}, mWindowSize);
556556
}
557557
}
558558

@@ -741,7 +741,7 @@ void WindowBlockManager::freeChildren(
741741
// Free block
742742
if (mEventManager && blockInRadixTree(block))
743743
{
744-
mEventManager->enqueueRemovedEvent(block);
744+
mEventManager->enqueueRemovedEvent(block, mWindowSize);
745745
}
746746

747747
claimLeafBlock(block, priority, durationMs);
@@ -776,7 +776,8 @@ BlockPtr WindowBlockManager::getFreeBlock(
776776
if (mEventManager && blockInRadixTree(block))
777777
{
778778
mEventManager->enqueueUpdatedEvent(
779-
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel));
779+
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel),
780+
mWindowSize);
780781
}
781782
mEvictionPolicy->releaseBlock(block); // append offload block to mFreeSecondaryBlocks queue
782783
block = offloadBlock;
@@ -881,7 +882,8 @@ void WindowBlockManager::onboardBlock(BlockPtr const& offloadBlock)
881882
if (mEventManager)
882883
{
883884
mEventManager->enqueueUpdatedEvent(
884-
tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel));
885+
tle::KVCacheUpdatedData(offloadBlock->getHash()).cacheLevelUpdated(kSecondaryLevel, kPrimaryLevel),
886+
mWindowSize);
885887
}
886888
mEvictionPolicy->releaseBlock(block); // append block to offload queue
887889
// offloadBlock is now in primary memory pool
@@ -908,7 +910,8 @@ void WindowBlockManager::offloadBlock(BlockPtr const& block)
908910
if (mEventManager && blockInRadixTree(block))
909911
{
910912
mEventManager->enqueueUpdatedEvent(
911-
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel));
913+
tle::KVCacheUpdatedData(block->getHash()).cacheLevelUpdated(kPrimaryLevel, kSecondaryLevel),
914+
mWindowSize);
912915
}
913916
mEvictionPolicy->releaseBlock(offloadBlock); // append offloadBlock to mFreePrimaryBlocks queue
914917
// block is now in secondary memory
@@ -980,7 +983,8 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector<BlockKey> const&
980983
{
981984
mEventManager->enqueueUpdatedEvent(
982985
tle::KVCacheUpdatedData(matchingBlock->getHash())
983-
.priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority));
986+
.priorityUpdated(matchingBlock->getPriority(), *perBlockRetentions[bi].retentionPriority),
987+
mWindowSize);
984988
}
985989
if (partialMatch)
986990
{
@@ -1275,7 +1279,7 @@ void WindowBlockManager::storeBlocks(
12751279
}
12761280
if (mEventManager)
12771281
{
1278-
mEventManager->enqueueStoredEvent(storedBlocks);
1282+
mEventManager->enqueueStoredEvent(storedBlocks, mWindowSize);
12791283
}
12801284
}
12811285

cpp/tensorrt_llm/executor/executor.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -132,9 +132,10 @@ std::optional<std::shared_ptr<KVCacheEventManager>> Executor::getKVCacheEventMan
132132
return mImpl->getKVCacheEventManager();
133133
}
134134

135-
KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data)
135+
KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize)
136136
: eventId{eventId}
137137
, data{std::move(data)}
138+
, windowSize{windowSize}
138139
{
139140
}
140141

cpp/tensorrt_llm/pybind/executor/bindings.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,8 @@ void initBindings(pybind11::module_& m)
239239

240240
py::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent")
241241
.def_readonly("event_id", &tle::KVCacheEvent::eventId)
242-
.def_readonly("data", &tle::KVCacheEvent::data);
242+
.def_readonly("data", &tle::KVCacheEvent::data)
243+
.def_readonly("window_size", &tle::KVCacheEvent::windowSize);
243244

244245
py::class_<tle::KVCacheEventManager, std::shared_ptr<tle::KVCacheEventManager>>(
245246
executor_kv_cache, "KVCacheEventManager")

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3401,6 +3401,62 @@ TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamBlocking)
34013401
EXPECT_TRUE(std::holds_alternative<tle::KVCacheStoredData>(events.front().data));
34023402
}
34033403

3404+
TEST_F(KVCacheManagerTest, KVCacheManagerEventStreamWindowSize)
3405+
{
3406+
auto constexpr numLayers = 2;
3407+
auto constexpr numHeads = 6;
3408+
auto constexpr sizePerHead = 16;
3409+
auto constexpr tokensPerBlock = 4;
3410+
auto constexpr maxBlocksPerSeq = 4;
3411+
auto constexpr maxNumSequences = 8;
3412+
auto blocksInPool = std::vector<SizeType32>{8, 2};
3413+
auto blocksInSlidingWindowPool = std::vector<SizeType32>{4, 2};
3414+
auto constexpr onboardBlocks = true;
3415+
auto constexpr dtype = nvinfer1::DataType::kHALF;
3416+
auto const stream = std::make_shared<tr::CudaStream>();
3417+
3418+
auto constexpr beamWidth = 1;
3419+
SizeType32 constexpr maxNewTokens{0};
3420+
tr::SamplingConfig const samplingConfig{beamWidth};
3421+
bool constexpr isStreaming{false};
3422+
3423+
auto const maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq;
3424+
auto const slidingWindow = tokensPerBlock * (maxBlocksPerSeq - 1);
3425+
3426+
auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPool[0], blocksInPool[1]}},
3427+
{slidingWindow, {blocksInSlidingWindowPool[0], blocksInSlidingWindowPool[1]}}};
3428+
3429+
KVCacheManager kvCacheManager(numLayers, numHeads, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences,
3430+
beamWidth, std::vector<BlockManager::SizeType32>{maxAttentionWindow, slidingWindow}, std::nullopt, dtype, 0,
3431+
stream, std::nullopt, true, onboardBlocks, CacheType::kSELF, std::nullopt,
3432+
std::make_unique<tlk::KVCacheEventManager>(1024));
3433+
kvCacheManager.allocatePools(false);
3434+
3435+
auto events = getEvents(kvCacheManager);
3436+
3437+
EXPECT_EQ(events.size(), 2);
3438+
3439+
EXPECT_EQ(events.front().windowSize, slidingWindow);
3440+
EXPECT_EQ(std::get<tle::KVCacheCreatedData>(events.front().data).numBlocksPerCacheLevel, blocksInSlidingWindowPool);
3441+
3442+
EXPECT_EQ(events.back().windowSize, maxAttentionWindow);
3443+
EXPECT_EQ(std::get<tle::KVCacheCreatedData>(events.back().data).numBlocksPerCacheLevel, blocksInPool);
3444+
3445+
auto inputTokens0 = std::make_shared<VecTokens>(VecTokens{0, 1, 2, 3, 4, 5, 6, 7});
3446+
auto llmRequest0 = std::make_shared<LlmRequest>(0, 0, inputTokens0, samplingConfig, true);
3447+
kvCacheManager.addSequence(0, inputTokens0->size(), beamWidth, llmRequest0);
3448+
kvCacheManager.storeContextBlocks(*llmRequest0);
3449+
3450+
events = getEvents(kvCacheManager);
3451+
3452+
EXPECT_EQ(events.size(), 2);
3453+
EXPECT_EQ(events.front().windowSize, slidingWindow);
3454+
EXPECT_TRUE(std::holds_alternative<tle::KVCacheStoredData>(events.front().data));
3455+
3456+
EXPECT_EQ(events.back().windowSize, maxAttentionWindow);
3457+
EXPECT_TRUE(std::holds_alternative<tle::KVCacheStoredData>(events.back().data));
3458+
}
3459+
34043460
TEST_F(KVCacheManagerTest, KVCacheTransferManagerConcurrencyTest)
34053461
{
34063462
auto const blockSize = 16384;

0 commit comments

Comments
 (0)