@@ -130,14 +130,17 @@ struct WindowSizeMetadata
130130 SizeType32 temporaryAttentionWindow; // Temporary kv cache length per sequence.
131131 // Only needed when chunked context + sliding window attention are used
132132 // together. And it should only be considered when allocating blocks.
133+ SizeType32 windowSize;
134+ bool isSWA;
133135
134136 std::string toString ()
135137 {
136138 return tensorrt_llm::common::fmtstr (
137139 " WindowSizeMetadata{ .allottedPrimaryBlocks=%d, .allottedSecondaryBlocks=%d, .absolutePoolsOffset=%d, "
138- " .numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d }" ,
140+ " .numPools=%d, .maxTokenNum=%d, .maxBlocksPerSeq=%d, .maxNumBlocks=%d, .temporaryAttentionWindow=%d, "
141+ " .windowSize=%d, .isSWA=%d }" ,
139142 allottedPrimaryBlocks, allottedSecondaryBlocks, absolutePoolsOffset, numPools, maxTokenNum, maxBlocksPerSeq,
140- maxNumBlocks, temporaryAttentionWindow);
143+ maxNumBlocks, temporaryAttentionWindow, windowSize, isSWA );
141144 }
142145};
143146
@@ -512,6 +515,8 @@ class GenerationRequest
512515 executor::KvCacheRetentionConfig mKvCacheRetentionConfig ;
513516 // Number of front blocks removed from the sequence
514517 SizeType32 mNumFrontBlocksRemoved ;
518+ // Set of used blocks by the sequence
519+ std::set<KVCacheBlock::IdType> mUsedBlocks ;
515520};
516521
517522// attach metadata to a pool pointer
@@ -628,15 +633,15 @@ class WindowBlockManager
628633 void releaseLastBlock (GenerationRequest& sequence);
629634
630635 // ! \brief Detach front block from the sequence
631- void detachFrontBlock (GenerationRequest& sequence, bool isEnableBlockReuse );
636+ void detachFrontBlock (GenerationRequest& sequence);
632637
633638 // ! \brief Add/detach block(s) to/from the sequence if needed
634639 // ! \details When we need a new block, we add it. For sliding window
635640 // ! attention (SWA), when a block goes out-of-window (OOW), we detach it
636- // ! and store it if reuse is enabled. If this called in the first step of
637- // ! the generation phase, we may detach more than a single block since
638- // ! there may be more than one context block that goes OOW.
639- void adjustBlocksIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse );
641+ // ! If this called in the first step of the generation phase, we may detach
642+ // ! more than a single block since there may be more than one context block
643+ // ! that goes OOW.
644+ void adjustBlocksIfNeeded (GenerationRequest& sequence);
640645
641646 [[nodiscard]] SizeType32 getWindowSize () const noexcept
642647 {
@@ -763,7 +768,7 @@ class WindowBlockManager
763768
764769 // ! \brief Bring offloaded block from secondary to primary memory.
765770 // ! \details Does nothing if block is already in primary memory.
766- void onboardBlock (BlockPtr const & offloadBlock,
771+ void onboardBlock (GenerationRequest& sequence, BlockPtr const & offloadBlock,
767772 executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const & directory = " " );
768773
769774 // ! \brief Bring block from primary to secondary memory.
@@ -826,6 +831,23 @@ class WindowBlockManager
826831 // ! \brief Unpin blocks by starting from a block id and walking prev pointers.
827832 void unpinBlocksById (KVCacheBlock::IdType blockId);
828833
834+ void initializeSequenceStorageValidity (LlmRequest::RequestIdType requestId)
835+ {
836+ mIsValidStoreForReuseSequence [requestId] = true ;
837+ }
838+
839+ void releaseSequenceStorageValidity (LlmRequest::RequestIdType requestId)
840+ {
841+ mIsValidStoreForReuseSequence .erase (requestId);
842+ }
843+
844+ // ! \brief Return whether this sequence is valid for store for reuse
845+ [[nodiscard]] bool isSequenceValidForStoreForReuse (LlmRequest::RequestIdType requestId) const
846+ {
847+ TLLM_CHECK_WITH_INFO (mIsValidStoreForReuseSequence .count (requestId) > 0 , " Sequence should be bookkeeped" );
848+ return mIsValidStoreForReuseSequence .at (requestId);
849+ }
850+
829851private:
830852 // ! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
831853 void addBlockToBeam (BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -842,18 +864,17 @@ class WindowBlockManager
842864 executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const & directory = " " );
843865
844866 // ! \brief Free block and all it's descendants. This makes block a claimed leaf block.
845- void freeChildren (BlockPtr const & block, executor::RetentionPriority priority,
846- std::optional<std::chrono::milliseconds> durationMs);
867+ void freeChildren (BlockPtr const & block);
847868
848869 // ! \brief Find block least likely to be reused, free it if necessary and return.
849- [[nodiscard]] BlockPtr getFreeBlock (
870+ // ! \param sequence Sequence which the free block is allocated for
871+ [[nodiscard]] BlockPtr getFreeBlock (GenerationRequest& sequence,
850872 executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority ,
851873 std::optional<std::chrono::milliseconds> durationMs = std::nullopt ,
852874 executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const & directory = " " );
853875
854- // ! \brief Free block from previous block and claim it from free blocks list.
855- void claimLeafBlock (BlockPtr const & block, std::optional<executor::RetentionPriority> priority = std::nullopt ,
856- std::optional<std::chrono::milliseconds> durationMs = std::nullopt );
876+ // ! \brief Calls KVCacheBlock::freeLeafBlock to remove block from search tree.
877+ void freeLeafBlock (BlockPtr const & block);
857878
858879 // ! \brief For FP4 quantization. Creates pool objects for FP4 block scalars.
859880 void createBlockScalePools (SizeType32 blockSize);
@@ -933,6 +954,14 @@ class WindowBlockManager
933954
934955 // Mutex for the cached blocks root
935956 std::mutex mCachedBlocksRootMutex ;
957+
958+ // Record which sequence is using the block
959+ std::map<KVCacheBlock::IdType, LlmRequest::RequestIdType> mBlockToSequence ;
960+ // Record whether a sequence has all blocks held valid.
961+ // The boolean value is set to true upon first encounter of a new sequence.
962+ // It may be invalidated to false when other sequence acquires a block that
963+ // is used by another sequence.
964+ std::map<LlmRequest::RequestIdType, bool > mIsValidStoreForReuseSequence ;
936965};
937966
938967class BlockManager
@@ -1008,7 +1037,7 @@ class BlockManager
10081037
10091038 // ! \brief Bring block from primary to secondary memory for window size.
10101039 // ! \details Does nothing if block is already in primary memory.
1011- void onboardBlock (BlockPtr const & offloadBlock, SizeType32 windowSize,
1040+ void onboardBlock (GenerationRequest& sequence, BlockPtr const & offloadBlock, SizeType32 windowSize,
10121041 executor::KvCacheTransferMode mode = executor::KvCacheTransferMode::DRAM, std::string const & directory = " " );
10131042
10141043 // ! \brief Bring block from primary to secondary memory for window size.
@@ -1239,10 +1268,52 @@ class BlockManager
12391268 // ! \brief Add/detach block(s) to/from the sequence if needed
12401269 // ! \details When we need a new block, we add it. For sliding window
12411270 // ! attention (SWA), when a block goes out-of-window (OOW), we detach it
1242- // ! and store it if reuse is enabled. If this called in the first step of
1243- // ! the generation phase, we may detach more than a single block since
1244- // ! there may be more than one context block that goes OOW.
1245- void adjustBlocksIfNeeded (GenerationRequest& sequence, bool isEnableBlockReuse);
1271+ // ! If this called in the first step of the generation phase, we may
1272+ // ! detach more than a single block since there may be more than one
1273+ // ! context block that goes OOW.
1274+ void adjustBlocksIfNeeded (GenerationRequest& sequence);
1275+
1276+ // ! \brief Return whether the sequence is already managed by the block manager
1277+ [[nodiscard]] bool isSequenceHeld (LlmRequest::RequestIdType requestId) const
1278+ {
1279+ return mManagedSequences .count (requestId) > 0 ;
1280+ }
1281+
1282+ // ! \brief Add a sequence to the managed sequences
1283+ // ! \details Take the sequence into account for the manager. Initialize
1284+ // ! sequence storage validity under all window sizes.
1285+ void holdSequence (LlmRequest::RequestIdType requestId)
1286+ {
1287+ mManagedSequences .insert (requestId);
1288+ for (auto const & [windowSize, metadata] : mWindowSizeToMetadata )
1289+ {
1290+ mWindowBlockManagers .at (windowSize).initializeSequenceStorageValidity (requestId);
1291+ }
1292+ }
1293+
1294+ // ! \brief Remove a sequence from the managed sequences.
1295+ // ! \details Remove sequence from the managed sequences and remove sequence
1296+ // ! storage
1297+ void releaseSequence (LlmRequest::RequestIdType requestId)
1298+ {
1299+ mManagedSequences .erase (requestId);
1300+ for (auto const & [windowSize, metadata] : mWindowSizeToMetadata )
1301+ {
1302+ mWindowBlockManagers .at (windowSize).releaseSequenceStorageValidity (requestId);
1303+ }
1304+ }
1305+
1306+ // ! \brief Return whether the sequence is still valid for store-for-reuse
1307+ // ! regarding the specific window size.
1308+ // ! \details Currently this utility function is only used under
1309+ // ! kvCacheManagerTest.cpp. Checking for store-for-reuse for each window
1310+ // ! size is done in an iterating fashion under BlockManager::releaseBlocks.
1311+ bool isSequenceValidForStoreForReuse (LlmRequest::RequestIdType requestId, SizeType32 windowSize) const
1312+ {
1313+ TLLM_CHECK_WITH_INFO (
1314+ mWindowBlockManagers .count (windowSize) > 0 , " Querying window size is not found under mWindowBlockManager" );
1315+ return mWindowBlockManagers .at (windowSize).isSequenceValidForStoreForReuse (requestId);
1316+ }
12461317
12471318private:
12481319 [[nodiscard]] WindowBlockManager const & windowManagerByLayer (SizeType32 layerIdx) const
@@ -1278,6 +1349,8 @@ class BlockManager
12781349 std::vector<SizeType32> mLayerToWindowSize ;
12791350 std::vector<SizeType32> mAbsolutePoolToWindowSize ;
12801351 std::vector<SizeType32> mAbsolutePoolToRelativePoolIndex ;
1352+ // Record what sequences are currently managed by the block manager
1353+ std::set<LlmRequest::RequestIdType> mManagedSequences ;
12811354};
12821355
12831356struct OffsetTableDimensions
0 commit comments