-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathkvCacheManager.cpp
More file actions
executable file
·3220 lines (2895 loc) · 136 KB
/
kvCacheManager.cpp
File metadata and controls
executable file
·3220 lines (2895 loc) · 136 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/evictionPolicy.h"
#include "tensorrt_llm/batch_manager/kvCacheTransferManager.h"
#include "tensorrt_llm/batch_manager/radixBlockTree.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <algorithm>
#include <limits>
#include <map>
#include <optional>
#include <utility>
namespace tc = tensorrt_llm::common;
namespace tk = tensorrt_llm::kernels;
namespace tle = tensorrt_llm::executor;
using namespace tle::kv_cache;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
using namespace tensorrt_llm::batch_manager::eviction_policy;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
namespace
{
//! \brief Collect every block of a sequence by following the prev-in-sequence links from its tail.
//! \param lastBlock Last block of the sequence; the walk starts here. May be nullptr.
//! \return Blocks in sequence order (first block first). The walk stops at a null link or at the
//!         block whose id is KVCacheBlock::kCachedBlocksRootId; that block is not included.
std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
{
    auto const belongsToSequence = [](BlockPtr const& candidate)
    { return candidate != nullptr && candidate->getBlockId() != KVCacheBlock::kCachedBlocksRootId; };

    // Gather tail-to-head, then flip once so the caller receives sequence order.
    std::vector<BlockPtr> sequenceBlocks;
    for (auto cursor = lastBlock; belongsToSequence(cursor); cursor = cursor->getPrevBlockInSeq())
    {
        sequenceBlocks.push_back(cursor);
    }
    std::reverse(sequenceBlocks.begin(), sequenceBlocks.end());
    return sequenceBlocks;
}
} // namespace
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
//! \brief Construct a block that is not referenced, not attached to the lookup tree and not full.
//! \param blockId Identifier stored in mBlockId.
//! \param blockIdx Memory-pool index stored in mMemoryPoolBlockIndex.
//! Both reference counts start at zero, the priority starts at
//! executor::KvCacheRetentionConfig::kDefaultRetentionPriority, and duration/expiration are unset.
KVCacheBlock::KVCacheBlock(IdType blockId, tk::KVCacheIndex blockIdx)
: mBlockId(blockId)
, mMemoryPoolBlockIndex{blockIdx}
, mRefCount(0)
, mSchedulingRefCount(0)
, mLookupNode{nullptr}
, mWindowSize{std::numeric_limits<int>::max()}
// sentinel: unattached; valid sizes are >= 1 or kRecurrentStates (-1)
, mIsPlaceholder{false}
, mIsFull{false}
, mPriority{executor::KvCacheRetentionConfig::kDefaultRetentionPriority}
, mDurationMs{std::nullopt}
, mExpirationTime{std::nullopt}
, mHash{0}
{
}
//! \brief Build a placeholder block: it carries an id but no usable memory-pool slot.
//! \param blockId Identifier assigned to the new block.
//! \return Shared pointer to a block for which isPlaceholder() returns true.
BlockPtr KVCacheBlock::createPlaceholder(IdType blockId)
{
    // The pool index is the largest value of the underlying index type, i.e. an out-of-range
    // sentinel. It is not meant to reach the GPU; consumers are expected to check isPlaceholder()
    // first, and any accidental use should fail as an obvious out-of-bounds access.
    using PoolIndexType = tk::KVCacheIndex::UnderlyingType;
    constexpr auto kInvalidPoolIndex = std::numeric_limits<PoolIndexType>::max();

    auto placeholder = std::make_shared<KVCacheBlock>(blockId, tk::KVCacheIndex{kInvalidPoolIndex});
    placeholder->mIsPlaceholder = true;
    return placeholder;
}
//! \brief Return the placeholder flag (false after construction, set to true by createPlaceholder()).
bool KVCacheBlock::isPlaceholder() const
{
return mIsPlaceholder;
}
//! \brief Copy the current reference count into the scheduling reference count.
void KVCacheBlock::startScheduling()
{
mSchedulingRefCount = mRefCount;
}
//! \brief Return the identifier this block was constructed with.
KVCacheBlock::IdType KVCacheBlock::getBlockId() const
{
return mBlockId;
}
//! \brief Snapshot the child blocks registered under this block's lookup node for its window size.
//! \return Map from child key to child block; empty when the block is not attached to a lookup node.
NextBlockMap KVCacheBlock::getNextBlocks() const
{
    NextBlockMap children;
    if (mLookupNode)
    {
        for (auto const& [childKey, childBlock] : mLookupNode->getChildKeyValues(mWindowSize))
        {
            children.emplace(childKey, childBlock);
        }
    }
    return children;
}
//! \brief Register this block as the value of \p node for \p windowSize.
//! \param node Lookup node that will hold the block.
//! \param windowSize Window-size slot of the node to occupy.
//! If the block is already attached, its previous slot is cleared first. Both the clearing of the
//! old slot and the claiming of the new one are verified with TLLM_CHECK_WITH_INFO.
void KVCacheBlock::attachToLookupNode(radix_block_tree::LookupNodePtr node, int windowSize)
{
    auto const blockIdForLog = static_cast<int>(mBlockId);

    // Release the slot held in the previous node, if any.
    if (mLookupNode)
    {
        auto const previousSlotCleared = mLookupNode->clearValue(mWindowSize);
        TLLM_CHECK_WITH_INFO(previousSlotCleared,
            "attachToLookupNode: block %d expected prior lookup slot to be occupied (clearValue returned false)",
            blockIdForLog);
    }

    // Claim the new slot without overwriting; members are updated only once the claim succeeded.
    auto const slotClaimed = node->trySetValue(windowSize, shared_from_this(), /*overwrite=*/false);
    TLLM_CHECK_WITH_INFO(slotClaimed,
        "attachToLookupNode: block %d found lookup slot already occupied by another block", blockIdForLog);

    mLookupNode = std::move(node);
    mWindowSize = windowSize;
}
//! \brief Remove this block from its lookup node and return it to the unattached state.
//! Does nothing when the block is not attached. Afterwards mLookupNode is null and mWindowSize is
//! back at the std::numeric_limits<int>::max() sentinel.
void KVCacheBlock::detachFromLookupNode()
{
    if (mLookupNode)
    {
        // Per the original note, clearValue also cascade-prunes ancestor nodes that become empty.
        auto const slotWasOccupied = mLookupNode->clearValue(mWindowSize);
        TLLM_CHECK_WITH_INFO(slotWasOccupied,
            "detachFromLookupNode: block %d expected lookup slot to be occupied (clearValue returned false)",
            static_cast<int>(mBlockId));

        mLookupNode = nullptr;
        mWindowSize = std::numeric_limits<int>::max();
    }
}
//! \brief Wire this block to the root lookup node for the given window size.
//! \param rootNode Root node of the lookup tree.
//! \param windowSize Window-size slot of the root node to occupy.
//! Unlike attachToLookupNode, the members are assigned before the slot is written, the slot is
//! written with overwrite=true, and the outcome is only logged (not checked).
void KVCacheBlock::setAsRoot(radix_block_tree::LookupNodePtr rootNode, int windowSize)
{
mLookupNode = rootNode;
mWindowSize = windowSize;
// Store the root block itself in the root node so that direct children can find it
// via getPrevBlock() (root->getParentNode() returns nullptr, so the chain stops here).
auto const wasUpdated = rootNode->trySetValue(windowSize, shared_from_this(), /*overwrite=*/true);
TLLM_LOG_DEBUG("setAsRoot: block %d wired to root slot for windowSize=%d (wasUpdated=%d)",
static_cast<int>(mBlockId), windowSize, static_cast<int>(wasUpdated));
}
//! \brief Return the raw value of this block's memory-pool index (mMemoryPoolBlockIndex.get()).
tk::KVCacheIndex::UnderlyingType KVCacheBlock::getMemoryPoolBlockIndex() const
{
return mMemoryPoolBlockIndex.get();
}
//! \brief Return a copy of the extra keys stored in this block's key (mBlockKey.extraKeys).
std::vector<MmKey> KVCacheBlock::getExtraKeys() const
{
return mBlockKey.extraKeys;
}
//! \brief Forward to mMemoryPoolBlockIndex.isPrimary().
bool KVCacheBlock::isPrimary() const
{
return mMemoryPoolBlockIndex.isPrimary();
}
//! \brief Exchange the memory-pool index of this block with that of \p otherBlock.
//! \param otherBlock Block to swap with; dereferenced unconditionally, so it must not be null.
void KVCacheBlock::swapMemoryPoolBlockOffset(std::shared_ptr<KVCacheBlock> otherBlock)
{
std::swap(mMemoryPoolBlockIndex, otherBlock->mMemoryPoolBlockIndex);
}
//! \brief Increment the reference count by one.
void KVCacheBlock::incRefCount()
{
mRefCount++;
}
//! \brief Decrement the reference count by one.
//! TLLM_CHECK_WITH_INFO verifies hasRefs() first, so the count is never taken below zero here.
void KVCacheBlock::decRefCount()
{
TLLM_CHECK_WITH_INFO(
hasRefs(), "Can't remove link from block (id=%d) that is not allocated", static_cast<int>(mBlockId));
mRefCount--;
}
//! \brief Decrement the scheduling reference count by one.
//! TLLM_CHECK_WITH_INFO verifies hasSchedulingRefs() first.
//! NOTE(review): unlike decRefCount, the failure message does not include the block id.
void KVCacheBlock::decSchedulingRefCount()
{
TLLM_CHECK_WITH_INFO(hasSchedulingRefs(), "Can't remove link from block that is not allocated");
mSchedulingRefCount--;
}
//! \brief Return true when the reference count is greater than zero.
bool KVCacheBlock::hasRefs() const
{
return mRefCount > 0;
}
//! \brief Return true when the block has more than one reference or is attached to a lookup node.
bool KVCacheBlock::isShared() const
{
// Block is considered shared if it has multiple references or is registered in the
// lookup tree (i.e., it is cached for reuse by future requests).
// Note: mCachedBlocksRoot also has mLookupNode set (via setAsRoot), but it is never
// placed in the eviction queue — enforced by an assertion in LRUEvictionPolicy::releaseBlock.
return mRefCount > 1 || mLookupNode != nullptr;
}
//! \brief Return true when the scheduling reference count is greater than zero.
bool KVCacheBlock::hasSchedulingRefs() const
{
return mSchedulingRefCount > 0;
}
//! \brief Store the block key together with the flag reported by isFull().
//! \param blockKey Key copied into mBlockKey.
//! \param isFull Value stored in mIsFull.
void KVCacheBlock::setBlockKey(BlockKey const& blockKey, bool isFull)
{
mBlockKey = blockKey;
mIsFull = isFull;
}
//! \brief Return a copy of the block key.
BlockKey KVCacheBlock::getBlockKey()
{
return mBlockKey;
}
//! \brief Set the retention priority of this block.
void KVCacheBlock::setPriority(executor::RetentionPriority priority)
{
mPriority = priority;
}
//! \brief Return the retention priority of this block.
executor::RetentionPriority KVCacheBlock::getPriority() const
{
return mPriority;
}
//! \brief Return the stored duration in milliseconds, or std::nullopt when none is set.
std::optional<std::chrono::milliseconds> KVCacheBlock::getDurationMs() const
{
return mDurationMs;
}
//! \brief Store the duration in milliseconds; std::nullopt clears it.
void KVCacheBlock::setDurationMs(std::optional<std::chrono::milliseconds> durationMs)
{
mDurationMs = durationMs;
}
//! \brief Store the expiration time, expressed as a steady_clock duration; std::nullopt clears it.
void KVCacheBlock::setExpirationTime(std::optional<std::chrono::steady_clock::time_point::duration> expirationTime)
{
mExpirationTime = expirationTime;
}
//! \brief Return the stored expiration time, or std::nullopt when none is set.
std::optional<std::chrono::steady_clock::time_point::duration> KVCacheBlock::getExpirationTime() const
{
return mExpirationTime;
}
//! \brief Overwrite the stored hash with an explicit value.
void KVCacheBlock::setHash(size_t hash)
{
mHash = hash;
}
//! \brief Recompute the stored hash from this block's key and its predecessor's hash.
//! The predecessor is the previous block in the sequence; when there is none, 0 is used as seed.
void KVCacheBlock::setHash()
{
    size_t parentHash = 0;
    if (mPrevBlockInSeq)
    {
        parentHash = mPrevBlockInSeq->getHash();
    }
    mHash = BlockKeyHasher()(mBlockKey, parentHash);
}
//! \brief Return the stored hash (0 after construction until one of the setHash overloads is called).
size_t KVCacheBlock::getHash() const
{
return mHash;
}
//! \brief Return a reference to the unique tokens stored in this block's key.
//! The reference points into mBlockKey and is therefore tied to this block's lifetime.
VecUniqueTokens const& KVCacheBlock::getUniqueTokens() const
{
return mBlockKey.uniqueTokens;
}
//! \brief Look up the block stored in the parent lookup node for this block's window size.
//! \return The parent's block, or nullptr when this block is unattached, when its node has no
//!         parent (i.e. it is the root), or when the parent holds no block for this window size.
BlockPtr KVCacheBlock::getPrevBlock() const
{
    if (mLookupNode)
    {
        if (auto const parentNode = mLookupNode->getParentNode())
        {
            auto parentBlock = parentNode->getValue(mWindowSize);
            return parentBlock.value_or(nullptr);
        }
    }
    return nullptr;
}
//! \brief Return a reference to the previous-block-in-sequence pointer (may be null).
BlockPtr const& KVCacheBlock::getPrevBlockInSeq() const
{
return mPrevBlockInSeq;
}
//! \brief Set the previous-block-in-sequence pointer; the argument is moved into the member.
void KVCacheBlock::setPrevBlockInSeq(BlockPtr prevBlock)
{
mPrevBlockInSeq = std::move(prevBlock);
}
//! \brief Register \p block as the child of this block under \p blockKey.
//! \param blockKey Key identifying the child node below this block's lookup node.
//! \param block Block to attach to that child node.
//! No-op when this block is unattached, or when the child node already stores a block for this
//! block's window size (the existing entry wins).
void KVCacheBlock::addNextBlock(BlockKey const& blockKey, BlockPtr block)
{
    if (mLookupNode)
    {
        // Reuse the child node for this key if it exists, otherwise create it.
        auto childNode = mLookupNode->findOrInsertChild(blockKey, mLookupNode);
        bool const slotIsFree = !childNode->getValue(mWindowSize).has_value();
        if (slotIsFree)
        {
            block->attachToLookupNode(childNode, mWindowSize);
        }
    }
}
//! \brief Search the children of this block for a block whose key matches \p blockKey.
//! \param blockKey Key to look for.
//! \param enablePartialReuse Also consider partially matching child nodes when no exact node exists.
//! \param copyOnPartialReuse When true, any partially matching block qualifies; otherwise only a
//!        partial match that has no references and is a leaf qualifies.
//! \return Tuple (isPartialMatch, numMatchedTokens, block). {false, 0, nullptr} means no match.
//!         For an exact node match, isPartialMatch is true when the stored block is not full.
//!         An exact node without a stored block yields no match; partial matches are not tried then.
std::tuple<bool, SizeType32, BlockPtr> KVCacheBlock::findMatchingBlock(
    BlockKey const& blockKey, bool enablePartialReuse, bool copyOnPartialReuse) const
{
    if (!mLookupNode || blockKey.uniqueTokens.empty())
    {
        return {false, 0, nullptr};
    }

    // 1) Exact key match.
    if (auto const exactMatch = mLookupNode->findMatchingNode(blockKey); exactMatch.has_value())
    {
        auto const storedBlock = exactMatch->node->getValue(mWindowSize);
        if (!storedBlock.has_value() || !(*storedBlock))
        {
            return {false, 0, nullptr};
        }
        auto matched = *storedBlock;
        auto const numMatchedTokens = static_cast<SizeType32>(blockKey.uniqueTokens.size());
        return {!matched->isFull(), numMatchedTokens, matched};
    }

    if (!enablePartialReuse)
    {
        return {false, 0, nullptr};
    }

    // 2) Partial matches; per the original note, candidates arrive sorted longest-first.
    for (auto const& candidate : mLookupNode->findPartiallyMatchingNodes(blockKey))
    {
        auto const storedBlock = candidate.node->getValue(mWindowSize);
        if (storedBlock.has_value() && *storedBlock)
        {
            auto partial = *storedBlock;
            bool const reusable = copyOnPartialReuse || (!partial->hasRefs() && partial->isLeaf());
            if (reusable)
            {
                return {true, static_cast<SizeType32>(candidate.key.uniqueTokens.size()), partial};
            }
        }
    }
    return {false, 0, nullptr};
}
//! \brief Detach a leaf block from the lookup tree.
//! TLLM_CHECK enforces that isLeaf() holds before the block is detached.
void KVCacheBlock::freeLeafBlock()
{
// assure that this is a leaf block
TLLM_CHECK(isLeaf());
// Detach from the lookup tree; cascade pruning removes empty ancestor nodes.
detachFromLookupNode();
}
//! \brief Remove the child entry stored under \p blockKey from this block's lookup node.
//! \param blockKey Key of the child to remove.
//! No-op when this block is unattached. A missing key is only logged at debug level.
void KVCacheBlock::removeNextBlock(BlockKey const& blockKey)
{
    if (!mLookupNode)
    {
        return;
    }
    // Per the original note, clearNode removes the child entry and cascade-prunes upward when
    // the child node becomes empty.
    auto const childRemoved = mLookupNode->clearNode(blockKey);
    if (childRemoved)
    {
        return;
    }
    TLLM_LOG_DEBUG("removeNextBlock: key not found for block %d; node may have been pruned already",
        static_cast<int>(mBlockId));
}
// Iterative DFS over the subtree rooted at this block's children.
//
// Algorithm:
// 1. Push immediate children onto a stack and do DFS, collecting every
// reachable descendant in pre-order (parent before children).
// 2. Detach in *reverse* order (children before parents). This is
// required because detachFromLookupNode() triggers cascade pruning:
// when a node becomes empty (no value, no children) it is removed from
// its parent. If we detached in collection order (parents first), a
// parent node could be cascade-pruned away before we had a chance to
// look up its children in step 1. By detaching leaves first, cascade
// propagation only moves upward after all descendants are already gone.
void KVCacheBlock::detachDescendantsFromLookupTree()
{
if (!mLookupNode)
{
return;
}
std::vector<BlockPtr> descendants;
std::vector<BlockPtr> stack;
for (auto const& [key, block] : mLookupNode->getChildKeyValues(mWindowSize))
{
stack.push_back(block);
}
while (!stack.empty())
{
auto current = std::move(stack.back());
stack.pop_back();
if (current->mLookupNode)
{
for (auto const& [key, block] : current->mLookupNode->getChildKeyValues(current->mWindowSize))
{
stack.push_back(block);
}
}
TLLM_LOG_DEBUG("KVCacheBlock::detachDescendantsFromLookupTree - detaching block %d", current->getBlockId());
descendants.push_back(std::move(current));
}
// Detach leaves first so cascade-prune works correctly.
for (auto it = descendants.rbegin(); it != descendants.rend(); ++it)
{
(*it)->detachFromLookupNode();
}
}
//! \brief Detach this block and its whole subtree from the lookup tree.
//! Descendants are detached first, then the block itself.
void KVCacheBlock::freeBlockAndAllDescendants()
{
detachDescendantsFromLookupTree();
detachFromLookupNode();
}
//! \brief Return the flag last stored by setBlockKey() (false after construction).
bool KVCacheBlock::isFull() const
{
return mIsFull;
}
//! \brief Return true when this block has no children in the lookup tree.
//! A block that is not attached to any lookup node counts as a leaf.
bool KVCacheBlock::isLeaf() const
{
    bool const hasChildBlocks = mLookupNode && mLookupNode->hasChildren();
    return !hasChildBlocks;
}
// This function calculates the number of block a layer should have, given
// the total free memory and the window size of each layer.
// For example, if we have 1 layer of window size 1024, and 2 layer of window
// size 2048, and 3 layers of 4096.
// Each layer of window size 1024 should have
// 1024 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks.
// Each layer of window size 2048 should have
// 2048 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks.
// Each layer of window size 4096 should have
// 4096 / (1024 + 2048 * 2 + 4096 * 3) proportion of the total blocks.
// NOTE: Currently the use of this function is not used for
// BaseKVCacheManager::calculateMaxNumBlocks because the we want to first
// achieve identical performance as assuming all layers as full attention.
std::map<SizeType32, float> BlockManager::calculateWindowSizeToShare(
std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers,
std::map<SizeType32, SizeType32> const& windowSizeToCacheSizePerToken)
{
if (windowSizeToLayers.size() == 1)
{
return {{windowSizeToLayers.begin()->first, 1.0f}};
}
std::map<SizeType32, float> windowSizeToContribution;
SizeType32 cacheSizePerTokenTotal
= std::accumulate(windowSizeToCacheSizePerToken.begin(), windowSizeToCacheSizePerToken.end(), SizeType32{0},
[](auto sum, auto const& windowSize) { return sum + windowSize.second; });
for (auto const& [windowSize, cacheSizePerToken] : windowSizeToCacheSizePerToken)
{
auto const cacheSizeWeight = static_cast<float>(cacheSizePerToken) / cacheSizePerTokenTotal;
windowSizeToContribution[windowSize] = cacheSizeWeight;
}
for (auto const& [windowSize, _] : windowSizeToLayers)
{
windowSizeToContribution.at(windowSize) *= windowSize;
}
auto const windowSizesTotalSum = std::accumulate(windowSizeToContribution.begin(), windowSizeToContribution.end(),
0.0, [](auto sum, auto const& windowSize) { return sum + windowSize.second; });
std::map<SizeType32, float> windowSizeToShare;
for (auto const& [windowSize, windowSizeSum] : windowSizeToContribution)
{
float const fraction = windowSizeSum / windowSizesTotalSum;
TLLM_CHECK(0.0f < fraction && fraction <= 1.0f);
windowSizeToShare[windowSize] = fraction;
}
auto total = std::accumulate(windowSizeToShare.begin(), windowSizeToShare.end(), 0.0f,
[](auto sum, auto const& windowSize) { return sum + windowSize.second; });
TLLM_CHECK(total == 1.0f);
return windowSizeToShare;
}
//! \brief Create one WindowBlockManager per distinct attention window size and record per-window metadata.
//!
//! Visible steps:
//!  1. Create the loopback agent when \p agentConfig is set, otherwise leave it null.
//!  2. Group layers by window size (BaseKVCacheManager::groupLayersByWindowSize) and fill
//!     mLayerToWindowSize.
//!  3. For each window size, construct a WindowBlockManager with the primary/secondary block
//!     counts taken from \p blocksPerWindow; a window is treated as SWA when
//!     windowSize < maxSequenceLength.
//!  4. For each manager, record the absolute-pool mappings and a WindowSizeMetadata entry,
//!     including the maximum number of blocks per sequence.
//!  5. Verify that mWindowBlockManagers and mWindowSizeToMetadata iterate over the same keys in
//!     the same order.
//!
//! \param blocksPerWindow Map from window size to (primary blocks, secondary blocks); must have an
//!        entry for every window size in use (accessed with at()).
//! \param kvCacheConnectorManager Only supported with a single window size (checked below).
//! \note \p maxBeamWidth is not referenced in this constructor body.
BlockManager::BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences,
std::shared_ptr<runtime::CudaStream> stream, SizeType32 maxSequenceLength, SizeType32 maxBeamWidth,
std::vector<SizeType32> const& maxAttentionWindowVec,
std::optional<TempAttentionWindowInputs> const& tempAttentionWindowInputs, nvinfer1::DataType dtype,
SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
std::optional<BaseAgentConfig> agentConfig, bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize,
SizeType32 indexerKCacheIndexHeadDim)
: mNumLayers{static_cast<SizeType32>(numKvHeadsPerLayer.size())}
, mTokensPerBlock{tokensPerBlock}
, mEventManager{std::move(eventManager)}
, mStream{stream}
, mCacheType{cacheType}
, mIsEnableIndexerKCache{enableIndexerKCache}
, mIndexerKCacheQuantBlockSize{indexerKCacheQuantBlockSize}
, mIndexerKCacheIndexHeadDim{indexerKCacheIndexHeadDim}
{
// The loopback agent is optional; it is created with the "nixl" backend name only when a
// configuration was supplied.
if (agentConfig.has_value())
mLoopbackAgent = makeLoopbackAgent("nixl", &agentConfig.value());
else
mLoopbackAgent = nullptr;
auto const uniqueWindowSizeToLayers
= BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, mNumLayers);
TLLM_CHECK_WITH_INFO(kvCacheConnectorManager == nullptr || uniqueWindowSizeToLayers.size() == 1,
"KV Cache Connector is not supported with multiple window sizes");
auto const numUniqueWindowSizes = static_cast<SizeType32>(uniqueWindowSizeToLayers.size());
mIsVariableWindow = numUniqueWindowSizes > 1;
// Variable GQA: the layers do not all share the same number of KV heads.
mIsVariableGQA = std::unordered_set(numKvHeadsPerLayer.begin(), numKvHeadsPerLayer.end()).size() > 1;
mLayerToWindowSize.resize(mNumLayers);
for (auto const& [windowSize, layersWithWindowSize] : uniqueWindowSizeToLayers)
{
if (windowSize > maxSequenceLength)
{
TLLM_LOG_WARNING("[kv cache manager] window size %d is greater than max sequence length %d", windowSize,
maxSequenceLength);
}
for (auto& layerIdx : layersWithWindowSize)
{
mLayerToWindowSize.at(layerIdx) = windowSize;
}
auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize);
TLLM_CHECK(allottedPrimaryBlocks > 0); // Every window size needs at least one primary block.
mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer,
sizePerHead, tokensPerBlock, /*isSWA=*/windowSize < maxSequenceLength, allottedPrimaryBlocks,
allottedSecondaryBlocks, maxNumSequences, stream, onboardBlocks, cacheType, secondaryOffloadMinPriority,
mEventManager, enablePartialReuse, copyOnPartialReuse, kvCacheConnectorManager, mLookupTree, mLoopbackAgent,
enableIndexerKCache, indexerKCacheQuantBlockSize, indexerKCacheIndexHeadDim);
}
auto const numAllPools = getNumPools();
mAbsolutePoolToWindowSize.reserve(numAllPools);
mAbsolutePoolToRelativePoolIndex.reserve(numAllPools);
// Running offset of the first pool of the current window within the flattened list of all pools.
auto absolutePoolsOffset = SizeType32{0};
for (auto const& [windowSize, manager] : mWindowBlockManagers)
{
auto const numPools = manager.getNumPools();
for (auto i = 0; i < numPools; ++i)
{
mAbsolutePoolToWindowSize.push_back(windowSize);
mAbsolutePoolToRelativePoolIndex.push_back(i);
}
// SWA allocates blocks linearly, and we need as many blocks as full attention,
// where full attention has windowSize = maxSequenceLength.
auto const maxTokenNum = std::max(windowSize, maxSequenceLength) + sinkBubbleLength;
auto const temporaryAttentionWindow = manager.calculateTemporaryAttentionWindow(tempAttentionWindowInputs);
// Consider the temporaryAttentionWindow when allocating blocks.
// Current tempAttentionWindow calculation does not consider the
// concept of SWA right now at most occupying maxSequenceLength of
// blocks. So the calculation of maxToken + tempAttention will exceed
// maxSequenceLength. A temporary resolution here is to cap the
// calculation to maxSequenceLength. I will proceed with a follow-up
// MR to remove the tempAttentionWindow concept.
auto const maxBlocksPerSeq
= tc::ceilDiv(std::min(maxSequenceLength, maxTokenNum + temporaryAttentionWindow), tokensPerBlock);
auto const [allottedPrimaryBlocks, allottedSecondaryBlocks] = blocksPerWindow.at(windowSize);
mWindowSizeToMetadata[windowSize] = WindowSizeMetadata{allottedPrimaryBlocks, allottedSecondaryBlocks,
absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq, manager.getMaxNumBlocks(),
temporaryAttentionWindow, windowSize, manager.isSWA()};
TLLM_LOG_INFO(
"Max KV cache blocks per sequence: %d [window size=%d], tokens per block=%d, primary blocks=%d, secondary "
"blocks=%d, max sequence length=%d",
maxBlocksPerSeq, windowSize, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks,
maxSequenceLength);
TLLM_LOG_DEBUG(
"%s Metadata: %s", manager.getLogPrefix().c_str(), mWindowSizeToMetadata[windowSize].toString().c_str());
absolutePoolsOffset += numPools;
}
// Both containers are iterated in lock-step elsewhere, so their key order must match.
TLLM_CHECK_WITH_INFO(mWindowBlockManagers.size() == mWindowSizeToMetadata.size()
&& std::equal(mWindowBlockManagers.cbegin(), mWindowBlockManagers.cend(), mWindowSizeToMetadata.cbegin(),
mWindowSizeToMetadata.cend(),
[](auto const& window1, auto const& window2) { return window1.first == window2.first; }),
"Iteration order of window sizes between mWindowBlockManagers and mWindowSizeToMetadata *must* be ensured. "
"Maybe you tried changing either of them to an std::unordered_map?");
}
//! \brief Set up the pools, blocks, eviction policy and lookup-tree root for one attention window size.
//!
//! Visible steps:
//!  1. Group the managed layers by their number of KV heads; each distinct head count gets one
//!     pool, and each layer records its pool index and its index within that pool.
//!  2. Optionally add block-scale pools (FP4 builds with an FP4 dtype) and indexer K-cache pools.
//!  3. Create blocksInPrimaryPool + blocksInSecondaryPool KVCacheBlock objects with consecutive ids.
//!  4. Create and initialize an LRUEvictionPolicy over all blocks.
//!  5. Emit a "created" event when an event manager is present.
//!  6. Wire the dummy root block into the shared lookup tree for this window size.
//!
//! \param managedLayers Indices of the layers handled by this manager; each is used to index
//!        \p numKvHeadsPerLayer with at().
//! \param lookupTree Shared lookup tree; only its address is stored.
//! NOTE(review): several initializers read members initialized earlier in the list (mBufferManager,
//! mLoopbackAgent, mCacheType, mWindowSize). This relies on the member declaration order in the
//! header, which is not visible here — confirm.
//! NOTE(review): mLogPrefix formats mWindowSize with %u although the parameter is SizeType32 —
//! confirm this is intended for negative sentinel window sizes.
WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 windowSize,
std::vector<SizeType32> const& managedLayers, std::vector<SizeType32> const& numKvHeadsPerLayer,
SizeType32 sizePerHead, SizeType32 tokensPerBlock, bool isSWA, SizeType32 blocksInPrimaryPool,
SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream,
bool onboardBlocks, CacheType cacheType, std::optional<executor::RetentionPriority> secondaryOffloadMinPriority,
std::shared_ptr<KVCacheEventManager> eventManager, bool enablePartialReuse, bool copyOnPartialReuse,
std::shared_ptr<kv_connector::KvCacheConnectorManager> kvCacheConnectorManager,
radix_block_tree::UnifiedBlockTree& lookupTree, std::shared_ptr<kvc::BaseLoopbackAgent> loopbackAgent,
bool enableIndexerKCache, SizeType32 indexerKCacheQuantBlockSize, SizeType32 indexerKCacheIndexHeadDim)
: mDataType{dtype}
, mWindowSize{windowSize}
, mNumPrimaryBlocks{blocksInPrimaryPool}
, mNumSecondaryBlocks{blocksInSecondaryPool}
, mOnboardBlocks(onboardBlocks)
, mBufferManager{std::move(stream)}
, mSchedulingNumFreeBlocks{0}
, mTokensPerBlock{tokensPerBlock}
, mIsSWA{isSWA}
, mLookupTree{&lookupTree}
// Use an out-of-range pool index for the dummy root block; it is never submitted to the GPU.
// The illegal value (INT32_MAX) ensures accidental use triggers an obvious OOB failure.
, mCachedBlocksRoot{std::make_shared<KVCacheBlock>(KVCacheBlock::kCachedBlocksRootId,
tk::KVCacheIndex{std::numeric_limits<tk::KVCacheIndex::UnderlyingType>::max()})}
, mCacheType{cacheType}
, mEventManager(std::move(eventManager))
, mLoopbackAgent{loopbackAgent}
, mTransferManager{std::make_shared<KVCacheTransferManager>(mBufferManager, mLoopbackAgent)}
, mAllocTotalBlocks{0}
, mAllocNewBlocks{0}
, mReusedBlocks{0}
, mReusedUniqueBlocks{0}
, mMissedBlocks{0}
, mKVFactor{mCacheType == CacheType::kSELFKONLY ? 1 : 2}
, mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)}
, mReusedTokens{0.0}
, mTotalInputTokens{0.0}
, mEnablePartialReuse{enablePartialReuse}
, mCopyOnPartialReuse{copyOnPartialReuse}
, mKvCacheConnectorManager{std::move(kvCacheConnectorManager)}
, mEnableIndexerKCache{enableIndexerKCache}
, mIndexerKCacheQuantBlockSize{indexerKCacheQuantBlockSize}
, mIndexerKCacheIndexHeadDim{indexerKCacheIndexHeadDim}
{
// Key: number of KV heads. Value: how many managed layers have that head count so far.
std::map<SizeType32, SizeType32> numLayersPerPool;
for (auto const layerIdx : managedLayers)
{
// Post-increment yields the count before this layer was added, i.e. the layer's index
// within its pool.
auto const& layerIndexWithinPool = numLayersPerPool[numKvHeadsPerLayer.at(layerIdx)]++;
mLayerToIndexWithinPool[layerIdx] = layerIndexWithinPool;
}
auto numEltsPerContainer = getNumEltsPerContainer();
#ifdef ENABLE_FP4
if (numEltsPerContainer == 2)
{
TLLM_CHECK_WITH_INFO(sizePerHead % 2 == 0, "sizePerHead must be divisible by 2 for 4-bit KV cache.");
}
#endif
// One pool per distinct head count, created in ascending head-count order (std::map iteration).
size_t poolIndex = 0;
for (auto const [numKvHeads, numLayers] : numLayersPerPool)
{
for (auto const layerIdx : managedLayers)
{
if (numKvHeadsPerLayer.at(layerIdx) == numKvHeads)
{
mLayerToPoolIndex[layerIdx] = poolIndex;
}
}
mPools.emplace_back(numLayers, mKVFactor, numKvHeads, sizePerHead / numEltsPerContainer, tokensPerBlock);
++poolIndex;
}
#ifdef ENABLE_FP4
// TODO(miovine): make the block size configurable. Should we have an additional argument
// to specify FP4 related parameters (scale dtypes, etc)? This can also be passed
// in the constructor.
constexpr SizeType32 kQuantBlockSizeNVFP4 = 16;
if (dtype == nvinfer1::DataType::kFP4)
{
createBlockScalePools(kQuantBlockSizeNVFP4);
}
#endif
if (mEnableIndexerKCache)
{
createIndexerKCachePools();
}
// Create free blocks
// Ids [0, blocksInPrimaryPool) come first, followed by the secondary blocks with ids offset by
// blocksInPrimaryPool. The second KVCacheIndex argument differs between the two groups
// (false / true) — presumably the secondary-pool flag; confirm against KVCacheIndex.
mAllBlocksById.reserve(blocksInPrimaryPool + blocksInSecondaryPool);
for (KVCacheBlock::IdType blockId = 0; blockId < blocksInPrimaryPool; ++blockId)
{
mAllBlocksById.emplace_back(std::make_shared<KVCacheBlock>(blockId, tk::KVCacheIndex{blockId, false}));
}
for (KVCacheBlock::IdType blockId = 0; blockId < blocksInSecondaryPool; ++blockId)
{
mAllBlocksById.emplace_back(
std::make_shared<KVCacheBlock>(blocksInPrimaryPool + blockId, tk::KVCacheIndex{blockId, true}));
}
mAllocatedBlocksPerSeq.reserve(maxNumSequences);
mEvictionPolicy = std::make_shared<LRUEvictionPolicy>();
mEvictionPolicy->initialize(
mAllBlocksById, {blocksInPrimaryPool, blocksInSecondaryPool}, secondaryOffloadMinPriority);
if (mEventManager)
{
mEventManager->enqueueCreatedEvent({blocksInPrimaryPool, blocksInSecondaryPool}, mWindowSize);
}
// Wire the dummy root block into the shared lookup tree so that direct children
// can navigate to it via getPrevBlock() and blockInRadixTree() returns true for them.
mCachedBlocksRoot->setAsRoot(mLookupTree->getRoot(), mWindowSize);
}
//! \brief Log the block allocation and reuse statistics gathered over this manager's lifetime.
//!
//! Every ratio is guarded against a zero denominator so the log never shows inf or NaN:
//!  - reused unique blocks percentage = reused unique blocks / newly allocated blocks * 100
//!  - cache hit rate                  = reused blocks / (reused blocks + missed blocks)
//!  - reused tokens percentage        = reused tokens / total input tokens * 100
//! A ratio whose denominator is zero is reported as 0.
WindowBlockManager::~WindowBlockManager()
{
    // The divisor is mAllocNewBlocks, so it has to be part of the guard; the two original
    // conditions are kept so previously reported zeros stay zeros.
    bool const hasUniqueReuseStats = mReusedUniqueBlocks != 0 && mAllocTotalBlocks != 0 && mAllocNewBlocks != 0;
    float const reusedUniqueBlocksPercentage = hasUniqueReuseStats
        ? static_cast<float>(mReusedUniqueBlocks) / static_cast<float>(mAllocNewBlocks) * 100
        : 0.0f;

    // mReusedBlocks > 0 makes the denominator strictly positive.
    float const cacheHitRate = mReusedBlocks == 0
        ? 0.0f
        : static_cast<float>(mReusedBlocks) / (static_cast<float>(mReusedBlocks + mMissedBlocks));

    // Without any input tokens the percentage would be 0/0; report 0 instead.
    auto const reusedTokensPercentage = mTotalInputTokens > 0 ? 100.0 * mReusedTokens / mTotalInputTokens : 0.0;

    TLLM_LOG_DEBUG("%s - total allocated blocks: %lu ", mLogPrefix.c_str(), mAllocTotalBlocks);
    TLLM_LOG_DEBUG("%s - allocated new blocks: %lu ", mLogPrefix.c_str(), mAllocNewBlocks);
    TLLM_LOG_DEBUG("%s - missed blocks: %lu ", mLogPrefix.c_str(), mMissedBlocks);
    TLLM_LOG_DEBUG("%s - reused blocks: %lu ", mLogPrefix.c_str(), mReusedBlocks);
    TLLM_LOG_DEBUG("%s - reused unique blocks: %lu ", mLogPrefix.c_str(), mReusedUniqueBlocks);
    TLLM_LOG_DEBUG(
        "%s - reused unique blocks percentage (%%): %.2f ", mLogPrefix.c_str(), reusedUniqueBlocksPercentage);
    TLLM_LOG_DEBUG("%s - cache hit rate: %.2f ", mLogPrefix.c_str(), cacheHitRate);
    TLLM_LOG_DEBUG("%s - reused tokens: %.0f ", mLogPrefix.c_str(), mReusedTokens);
    TLLM_LOG_DEBUG("%s - reused tokens percentage (%%): %.2f ", mLogPrefix.c_str(), reusedTokensPercentage);
}
//! \brief Run the queue-integrity check of the manager responsible for \p windowSize.
//! \param windowSize Must be a key of mWindowBlockManagers (looked up with at()).
//! \return The result reported by that window's manager.
bool BlockManager::verifyQueueIntegrity(SizeType32 windowSize)
{
return mWindowBlockManagers.at(windowSize).verifyQueueIntegrity();
}
//! \brief Forward to the eviction policy's verifyQueueIntegrity() and return its result.
bool WindowBlockManager::verifyQueueIntegrity()
{
return mEvictionPolicy->verifyQueueIntegrity();
}
//! \brief Store the context (input-token) blocks of a sequence for reuse, once per window size.
//! \param sequence Sequence whose cache block ids are stored; only beam 0 is used.
//! \param llmRequest Request providing the unique tokens and the data needed to build block keys.
//!
//! All tokens except the last one are chopped into blocks of getTokensPerBlock() tokens, block keys
//! are built from them, and the keys are handed to the window's manager together with the block
//! ids of beam 0. The value returned by storeBlocks is intentionally discarded.
void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest)
{
    constexpr int beamIdx = 0; // no need to consider more than one beam for input tokens
    // Iterate in descending window-size order (largest/full-attention windows first).
    // This guarantees that the Stored event for the full-attention window is committed
    // before flushRemovedEvents fires for SWA windows, preserving the per-window
    // ordering guarantee: Removed events precede the Stored event for the same window,
    // and Stored(full) is not interleaved with Removed(SWA).
    for (auto it = mWindowBlockManagers.rbegin(); it != mWindowBlockManagers.rend(); ++it)
    {
        auto& [windowSize, manager] = *it;
        // Bind by reference: the ids are only read, so a per-window copy is not needed.
        auto const& cacheBlockIds = sequence.getCacheBlockIds(windowSize);
        auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx);
        // size() is a size_t, which requires %zu in a printf-style format string.
        TLLM_LOG_DEBUG("storeContextBlocks for request %lu on window %d with %zu unique tokens",
            llmRequest.mRequestId, windowSize, uniqueTokens.size());
        // The block keys are rebuilt for every window because they are moved into storeBlocks.
        auto blockedUniqueTokens
            = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false);
        auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
        (void) manager.storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
    }
}
// Appends one block-scale pool descriptor for every regular KV pool currently in mPools.
// Pools that already hold block scales or the indexer K cache are skipped.
//
// quantBlockSize: number of elements that share one scale; sizePerHead times the
//                 elements-per-container count must be divisible by it.
void WindowBlockManager::createBlockScalePools(SizeType32 quantBlockSize)
{
    SizeType32 const eltsPerContainer = getNumEltsPerContainer();
    // Snapshot the pool count: descriptors appended below must not be visited by this loop.
    SizeType32 const initialPoolCount = mPools.size();
    for (SizeType32 poolIdx = 0; poolIdx < initialPoolCount; ++poolIdx)
    {
        // Re-fetched by index every iteration because emplace_back may reallocate mPools.
        auto& sourcePool = mPools[poolIdx];
        bool const isAuxiliaryPool = sourcePool.containsIndexerKCache || sourcePool.containsBlockScales;
        if (!isAuxiliaryPool)
        {
            auto const elementsPerHead = sourcePool.sizePerHead * eltsPerContainer;
            TLLM_CHECK_WITH_INFO(elementsPerHead % quantBlockSize == 0,
                "Cannot use FP4 quantization since kvPool.sizePerHead is not divisible by FP4 quantBlockSize.");
            auto blockScaleSizePerHead = elementsPerHead / quantBlockSize;
            // Same layout as the source pool except for the per-head size; buffers are allocated later.
            mPools.emplace_back(sourcePool.numLayers, sourcePool.kvFactor, sourcePool.numKvHeads,
                blockScaleSizePerHead, sourcePool.tokensPerBlock,
                /*primaryPool=*/nullptr,
                /*secondaryPool=*/nullptr,
                /*containsBlockScales=*/true,
                /*containsIndexerKCache=*/false);
        }
    }
}
void WindowBlockManager::createIndexerKCachePools()
{
SizeType32 numPools = mPools.size();
for (SizeType32 i = 0; i < numPools; ++i)
{
auto& kvPool = mPools[i];
if (kvPool.containsIndexerKCache || kvPool.containsBlockScales)
{
continue;
}
SizeType32 scaleSize = mIndexerKCacheIndexHeadDim / mIndexerKCacheQuantBlockSize * 4;
mPools.emplace_back(kvPool.numLayers, kvPool.kvFactor, 1, scaleSize + mIndexerKCacheIndexHeadDim,
kvPool.tokensPerBlock,
/*primaryPool=*/nullptr,
/*secondaryPool=*/nullptr,
/*containsBlockScales=*/false,
/*containsIndexerKCache=*/true);
}
}
// Asks every per-window manager to allocate its pools, passing useUvm through unchanged.
void BlockManager::allocatePools(bool useUvm)
{
    for (auto& windowEntry : mWindowBlockManagers)
    {
        auto& windowManager = windowEntry.second;
        windowManager.allocatePools(useUvm);
    }
}
void WindowBlockManager::allocatePools(bool useUvm)
{
constexpr nvinfer1::DataType kScaleDtypeNVFP4 = nvinfer1::DataType::kFP8;
// Allocate a memory pool backing the blocks for each numKvHeads
// TODO(oargov): allocate pools in a single buffer and split it, to avoid fragmentation
for (auto& pool : mPools)
{
auto blockSize = pool.blockSize;
auto poolDtype = pool.containsBlockScales ? kScaleDtypeNVFP4 : mDataType;
#ifdef ENABLE_FP4
auto const poolIsFP4 = poolDtype == nvinfer1::DataType::kFP4;
#else
auto const poolIsFP4 = false;
#endif
if (poolIsFP4)
{
poolDtype = nvinfer1::DataType::kINT8;
}
if (pool.containsIndexerKCache)
{
poolDtype = nvinfer1::DataType::kUINT8;
}
nvinfer1::Dims cacheShape;
cacheShape = ITensor::makeShape({mNumPrimaryBlocks, pool.numLayers, mKVFactor, blockSize});
TLLM_LOG_DEBUG("[%s] Allocating primary pool with %d blocks for %d layers with %d kv heads", mLogPrefix.c_str(),
mNumPrimaryBlocks, pool.numLayers, pool.numKvHeads);
if (useUvm)
pool.primaryPtr = BufferManager::managed(cacheShape, poolDtype);
else
pool.primaryPtr = mBufferManager.gpuSync(cacheShape, poolDtype);
if (mNumSecondaryBlocks > 0)
{
nvinfer1::Dims const cacheShapeOffload
= ITensor::makeShape({mNumSecondaryBlocks, pool.numLayers, mKVFactor, blockSize});
TLLM_LOG_DEBUG("[%s] Allocating secondary pool with %d blocks for %d layers with %d kv heads",
mLogPrefix.c_str(), mNumSecondaryBlocks, pool.numLayers, pool.numKvHeads);
pool.secondaryPtr = BufferManager::pinned(cacheShapeOffload, poolDtype);
}
}
}
// Asks every per-window manager to release its pools.
void BlockManager::releasePools()
{
    for (auto& windowEntry : mWindowBlockManagers)
    {
        auto& windowManager = windowEntry.second;
        windowManager.releasePools();
    }
}
void WindowBlockManager::releasePools()
{
for (auto& pool : mPools)
{
if (pool.primaryPtr)
{
pool.primaryPtr->release();
}
if (pool.secondaryPtr)
{
pool.secondaryPtr->release();
}
}
mBufferManager.getStream().synchronize();
mBufferManager.memoryPoolTrimTo(0);
}
// Forwards startScheduling() to every per-window manager.
void BlockManager::startScheduling()
{
    for (auto& windowEntry : mWindowBlockManagers)
    {
        auto& windowManager = windowEntry.second;
        windowManager.startScheduling();
    }
}
// Records the eviction policy's current number of free primary-level blocks in
// mSchedulingNumFreeBlocks, then calls startScheduling() on every block that is
// currently allocated to a sequence.
void WindowBlockManager::startScheduling()
{
    mSchedulingNumFreeBlocks = mEvictionPolicy->getNumFreeBlocks(kPrimaryLevel);

    // The key of each entry is not needed here; only its list of blocks is visited.
    for (auto& sequenceEntry : mAllocatedBlocksPerSeq)
    {
        auto& blocksOfSequence = sequenceEntry.second;
        for (auto& block : blocksOfSequence)
        {
            block->startScheduling();
        }
    }
}
// Delegates to the block's own freeLeafBlock().
void WindowBlockManager::freeLeafBlock(BlockPtr const& block)
{
    // The eviction policy needs blocks to remain linked to their old parents at the time
    // they are reclaimed, so that it can check whether the parent should be queued for
    // eviction.
    auto& leafBlock = *block;
    leafBlock.freeLeafBlock();
}
// Frees `block` together with all of its descendants from the radix tree. When an event
// manager is attached and the block is in the radix tree, a Removed event is enqueued first.
void WindowBlockManager::freeChildren(BlockPtr const& block)
{
    // Notify the event manager before the block is unlinked.
    if (mEventManager)
    {
        if (blockInRadixTree(block))
        {
            mEventManager->enqueueRemovedEvent(block, mWindowSize);
        }
    }

    // Free the block and all its descendants from the radix tree.
    block->freeBlockAndAllDescendants();
}
BlockPtr WindowBlockManager::getFreeBlock(GenerationRequest& sequence, executor::RetentionPriority priority,
std::optional<std::chrono::milliseconds> durationMs, executor::KvCacheTransferMode mode,
std::string const& directory)
{
// eviction policy get free primary block
auto [block, canOffload] = mEvictionPolicy->getFreeBlock(kPrimaryLevel);
if (block->getUniqueTokens().empty())
{
++mAllocNewBlocks;
}
++mAllocTotalBlocks;
// Offloading is an option only when these conditions are met:
// 1. Block contains state (evidenced by presence of tokens)
// 2. Eviction policy indicated block can be offloaded
// 3. At least one free block in secondary memory
// 4. Onboarding is enabled (allowing block to be brought back into primary)
if (!block->getUniqueTokens().empty() && canOffload && mEvictionPolicy->getNumFreeBlocks(kSecondaryLevel) > 0
&& mOnboardBlocks)
{
// Offload block in primary memory before repurposing
auto offloadBlock = std::get<0>(mEvictionPolicy->getFreeBlock(kSecondaryLevel));
// Claim both blocks BEFORE the swap so getCacheLevel() still reflects the
// actual free-queue each iterator belongs to. After swapMemoryPoolBlockOffset()
// isPrimary() is inverted for both blocks, so calling claimBlock() post-swap
// would make it erase from the wrong std::list -- undefined behaviour.
// This ordering matches WindowBlockManager::offloadBlock().
mEvictionPolicy->claimBlock(block); // block is PRIMARY -> erases from primary queue
mEvictionPolicy->claimBlock(offloadBlock); // offloadBlock is SECONDARY -> erases from secondary queue
mTransferManager->offload(block, offloadBlock, mPools, 0, mode, directory);
// swap linear block offsets (i.e. make block the offload block)