@@ -491,11 +491,14 @@ class BlockManager
491491
492492 void replaceSharedBlock (GenerationRequest& sequence, SizeType32 blockIdx);
493493
494+ // ! \brief Get the ids of all newly allocated (not reused) blocks for the sequence.
495+ std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds (GenerationRequest const & sequence) const ;
496+
494497 // ! \brief Release blocks of the sequence. Store blocks for reuse if llmRequest is provided.
495498 void releaseBlocks (GenerationRequest& sequence, OptionalRef<LlmRequest const > llmRequest = std::nullopt );
496499
497500 // ! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
498- void schedulingReleaseBlocks (GenerationRequest& sequence );
501+ void schedulingReleaseBlocks (LlmRequest::RequestIdType requestId );
499502
500503 // ! \brief Release last block in the sequence
501504 void releaseLastBlock (GenerationRequest& sequence);
@@ -658,6 +661,11 @@ class BlockManager
658661
659662 [[nodiscard]] static bool blockInRadixTree (BlockPtr const & block);
660663
//! \brief Report whether hash keys are enabled for this block manager (returns mEnableHashKey).
664+ [[nodiscard]] bool isEnableHashKey () const
665+ {
666+ return mEnableHashKey ;
667+ }
668+
661669private:
662670 // ! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
663671 void addBlockToBeam (BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -849,6 +857,7 @@ class BaseKVCacheManager
849857 virtual void rewindKVCache (LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) = 0;
850858
851859 [[nodiscard]] virtual GenerationRequest const & getSequence (LlmRequest::RequestIdType requestId) const = 0;
//! \brief Non-const overload of getSequence: yields a mutable reference to the GenerationRequest for requestId.
860+ [[nodiscard]] virtual GenerationRequest& getSequence (LlmRequest::RequestIdType requestId) = 0;
852861
853862 [[nodiscard]] virtual bool isCrossKv () const = 0;
854863
@@ -872,6 +881,10 @@ class BaseKVCacheManager
872881 std::vector<LlmRequest::RequestIdType> const & requestIds) const
873882 = 0;
874883
//! \brief Get the ids of the newly allocated blocks for the request with the given id.
//! \note Presumably "newly allocated" means not reused, as documented on
//! BlockManager::getNewlyAllocatedBlockIds -- confirm against the implementation.
884+ [[nodiscard]] virtual std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds (
885+ LlmRequest::RequestIdType requestId) const
886+ = 0;
887+
875888 [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const = 0;
876889 [[nodiscard]] virtual SizeType32 getPoolLayerIdx (SizeType32 layer_idx) const = 0;
877890
@@ -904,6 +917,8 @@ class BaseKVCacheManager
904917 // / @param outputLength The number of output tokens in each sequence in the batch.
905918 // / @return SizeType32 A number of sequences per batch.
906919 [[nodiscard]] virtual SizeType32 getMaxCapacityBatchSize (SizeType32 inputLength, SizeType32 outputLength) const = 0;
920+
//! \brief Return the CacheType of this KV cache manager.
921+ [[nodiscard]] virtual CacheType getCacheType () const = 0;
907922};
908923
909924class KVCacheManager : public BaseKVCacheManager
@@ -935,7 +950,7 @@ class KVCacheManager : public BaseKVCacheManager
935950 SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional<SizeType32> maxSequenceLength,
936951 bool enableBlockReuse = true , bool onboardBlocks = true , CacheType cacheType = CacheType::kSELF ,
937952 std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
938- std::shared_ptr<KVCacheEventManager> eventManager = nullptr );
953+ std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enableHashKey = false );
939954
940955 KVCacheManager (SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
941956 SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
@@ -1100,12 +1115,18 @@ class KVCacheManager : public BaseKVCacheManager
11001115 void rewindKVCache (LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override ;
11011116
11021117 [[nodiscard]] GenerationRequest const & getSequence (LlmRequest::RequestIdType requestId) const override ;
1118+ [[nodiscard]] GenerationRequest& getSequence (LlmRequest::RequestIdType requestId) override ;
11031119
//! \brief True when the block manager's cache type is CacheType::kCROSS.
11041120 [[nodiscard]] bool isCrossKv () const override
11051121 {
11061122 return mBlockManager .getCacheType () == CacheType::kCROSS ;
11071123 }
11081124
//! \brief Return the cache type reported by the underlying block manager.
1125+ [[nodiscard]] CacheType getCacheType () const override
1126+ {
1127+ return mBlockManager .getCacheType ();
1128+ }
1129+
11091130 // ! \brief Find first new block that must be allocated for context phase and return its concatenated token vector.
11101131 // ! \details Only full blocks are considered.
11111132 [[nodiscard]] std::optional<BlockKey> findNewContextBlock (
@@ -1148,6 +1169,8 @@ class KVCacheManager : public BaseKVCacheManager
11481169 std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds (
11491170 std::vector<LlmRequest::RequestIdType> const & requestIds) const override ;
11501171
//! \brief Get the ids of the newly allocated blocks for the request with the given id.
//! \details The return type is spelled with KVCacheBlock::IdType so that it matches the
//! BaseKVCacheManager declaration being overridden; an override's return type must be
//! identical to the base's, so both spellings necessarily name the same type.
1172+ std::vector<KVCacheBlock::IdType> getNewlyAllocatedBlockIds (LlmRequest::RequestIdType requestId) const override ;
1173+
11511174 runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const override ;
11521175
11531176 SizeType32 getPoolLayerIdx (SizeType32 layer_idx) const override
@@ -1219,6 +1242,8 @@ class KVCacheManager : public BaseKVCacheManager
12191242 bool mEnableHashKey ;
12201243 // Whether use one more block for each sequence
12211244 bool mUseOneMoreBlock ;
1245+ // Mutex to protect access to mSequences
1246+ mutable std::mutex mSequencesMtx ;
12221247 // buffers for static tensors, will be created after allocating pools
12231248 runtime::ITensor::SharedPtr mBlockPoolPointers ;
12241249 runtime::ITensor::SharedPtr mLayerToPoolMapping ;
0 commit comments