1717
1818#pragma once
1919
20+ #include " cacheTransBuffer.h"
2021#include " dataTransceiver.h"
21- #include " tensorrt_llm/batch_manager/cacheTransBuffer.h"
2222#include " tensorrt_llm/batch_manager/kvCacheManager.h"
2323#include " tensorrt_llm/batch_manager/kvCacheUtils.h"
2424#include " tensorrt_llm/common/envUtils.h"
2525#include " tensorrt_llm/common/logger.h"
26+ #include " tensorrt_llm/executor/cacheCommunicator.h"
2627#include " tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
2728#include " tensorrt_llm/executor/dataTransceiverState.h"
2829#include " tensorrt_llm/runtime/bufferManager.h"
@@ -60,13 +61,54 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
6061
6162BlockRange getBlockRangeForReceiving (BaseKVCacheManager* cacheManager, LlmRequest const & llmRequest);
6263
63- // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
64- // parallel topology is completely identical, making it the preferred method.
65- class CacheFormatter final : public IOFormatter
64+ // Used to support the cache transmission with different layouts and different protocols.
// Abstract interface: every member function declared below is pure virtual, so a concrete formatter
// must override all of them before it can be instantiated.
65+ class BaseCacheFormatter
6666{
6767public:
68+ using SizeType32 = tensorrt_llm::runtime::SizeType32;
6869 using CacheState = executor::kv_cache::CacheState;
6970
// NOTE(review): not documented in this header; presumably the sending-side entry point that writes the
// request's cache out over `connections` — confirm against the implementations.
71+ virtual void formatOutput (LlmRequest const & llmRequest,
72+ std::vector<executor::kv_cache::Connection const *> const & connections, CacheState const & selfConfig,
73+ SizeType32 selfIdx, CacheState const & destConfig, runtime::BufferManager const & bufferManager)
74+ = 0;
75+
// NOTE(review): not documented in this header; takes the same parameter list as formatOutput and is
// presumably the receiving-side counterpart — confirm against the implementations.
76+ virtual void formatInput (LlmRequest const & llmRequest,
77+ std::vector<executor::kv_cache::Connection const *> const & connections, CacheState const & selfConfig,
78+ SizeType32 selfIdx, CacheState const & destConfig, runtime::BufferManager const & bufferManager)
79+ = 0;
80+
81+ /// @brief Determine whether this formatter is applicable to the source and target.
82+ /// @param selfConfig Source data arrangement.
83+ /// @param destConfig Target data arrangement.
84+ /// @return Whether this formatter is applicable to the source and target.
85+ [[nodiscard]] virtual bool inquireSupport (CacheState const & selfConfig, CacheState const & destConfig) const = 0;
86+
87+ /// @brief Obtain the indices of the counterparts that need to be actually communicated with.
88+ /// @param selfConfig Source data arrangement.
89+ /// @param selfIdx The sequential index of the current executor process within the entire parallel group.
90+ /// @param destConfig Target data arrangement.
91+ /// @return The indices of the counterparts.
92+ [[nodiscard]] virtual std::vector<SizeType32> getCounterparts (
93+ CacheState const & selfConfig, SizeType32 selfIdx, CacheState const & destConfig) const
94+ = 0;
95+
// Accessor for the BaseKVCacheManager pointer exposed by the concrete formatter; declared noexcept.
96+ [[nodiscard]] virtual BaseKVCacheManager* getCacheManager () const noexcept = 0;
97+
// NOTE(review): not documented in this header; takes a connection list and returns a connection list,
// so it presumably selects the connections to receive from for the given
// (selfConfig, selfIdx, destConfig) combination — confirm against the implementations.
98+ [[nodiscard]] virtual std::vector<executor::kv_cache::Connection const *> pickRecvConnections (
99+ std::vector<executor::kv_cache::Connection const *> const & connections, CacheState const & selfConfig,
100+ SizeType32 selfIdx, CacheState const & destConfig) const
101+ = 0;
102+
103+ /// @brief Destructor. Virtual so that destroying a derived formatter through a BaseCacheFormatter pointer is well-defined.
104+ virtual ~BaseCacheFormatter () = default ;
105+ };
106+
107+ // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
108+ // parallel topology is completely identical, making it the preferred method.
109+ class CacheFormatter final : public BaseCacheFormatter
110+ {
111+ public:
70112 CacheFormatter (BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager)
71113 : mCacheManager {cacheManager}
72114 , mCacheTransBufferManager {cacheTransBufferManager}
@@ -91,7 +133,7 @@ class CacheFormatter final : public IOFormatter
91133 return executor::kv_cache::targetIRanks (destConfig, selfConfig, selfIdx).mIRanks ;
92134 }
93135
94- BaseKVCacheManager* getCacheManager () const noexcept
// Returns the cache manager pointer held in mCacheManager; overrides the base-class accessor.
136+ [[nodiscard]] BaseKVCacheManager* getCacheManager () const noexcept override
95137 {
96138 return mCacheManager ;
97139 }
@@ -102,11 +144,12 @@ class CacheFormatter final : public IOFormatter
102144 SizeType32 selfIdx, CacheState const & destConfig) const override ;
103145
104146private:
105- BaseKVCacheManager* mCacheManager {};
106-
147+ BaseKVCacheManager* mCacheManager ;
107148 CacheTransBufferManager* mCacheTransBufferManager ;
108-
109149 KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath ()};
110150};
111151
// Factory declaration: returns an owning std::unique_ptr to a BaseCacheFormatter built from the given
// cache manager and cache-transfer buffer manager pointers. `isMLA` defaults to false.
// NOTE(review): the definition is not visible here; presumably `isMLA` selects an MLA-specific
// formatter implementation — confirm in the corresponding source file.
152+ std::unique_ptr<BaseCacheFormatter> createCacheFormatter (
153+ BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA = false );
154+
112155} // namespace tensorrt_llm::batch_manager::kv_cache_manager
0 commit comments