3737#include " tensorrt_llm/batch_manager/cacheFormatter.h"
3838#include " tensorrt_llm/batch_manager/cacheTransceiver.h"
3939#include " tensorrt_llm/batch_manager/contextProgress.h"
40- #include " tensorrt_llm/batch_manager/dataTransceiverImpl.h"
4140#include " tensorrt_llm/batch_manager/kvCacheManager.h"
41+ #include " tensorrt_llm/batch_manager/kvCacheType.h"
42+ #include " tensorrt_llm/batch_manager/kvCacheUtils.h"
4243#include " tensorrt_llm/batch_manager/llmRequest.h"
4344#include " tensorrt_llm/batch_manager/mlaCacheFormatter.h"
4445#include " tensorrt_llm/common/envUtils.h"
@@ -116,7 +117,6 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
116117 : mMpiGroupComm (std::addressof(tensorrt_llm::mpi::MpiComm::session()))
117118 , mCacheTransceiverConfig {cacheTransceiverConfig}
118119{
119- using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
120120 if (worldConfig.isPipelineParallel ())
121121 {
122122 mMpiGroupPipeParaComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
@@ -200,14 +200,12 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
200200 TLLM_THROW (" Unsupported cache transceiver backend type " );
201201 }
202202
203- using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
204203 auto makeFormatter = [cacheManager, isMLA, this ]()
205204 { return createCacheFormatter (cacheManager, mCacheTransBufferManager .get (), isMLA); };
206205
207- mDataResponder = std::make_unique<DataResponder>(
208- std::make_unique<DataSenderImpl>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ()));
209- mDataRequester = std::make_unique<DataRequester>(
210- std::make_unique<DataReceiverImpl>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ()));
206+ mCacheSender = std::make_unique<CacheSender>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ());
207+ mCacheReceiver
208+ = std::make_unique<CacheReceiver>(mManager .get (), *mCacheState , worldConfig.getRank (), makeFormatter ());
211209
212210 initializeCommState ();
213211}
@@ -223,7 +221,7 @@ CacheTransceiver::~CacheTransceiver()
223221
224222void CacheTransceiver::initializeCommState ()
225223{
226- mCommState = std::addressof (mDataResponder ->getCommState ());
224+ mCommState = std::addressof (mCacheSender ->getCommState ());
227225}
228226
229227void CacheTransceiver::setContextState (LlmRequest* llmRequest)
@@ -259,8 +257,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
259257 return ;
260258 }
261259 setContextState (llmRequest);
262- auto future = mDataResponder -> respondAndSendAsync (*llmRequest);
263- mResponderFutures .emplace_back (llmRequest, std::move (future));
260+ auto future = mCacheSender -> sendAsync (*llmRequest);
261+ mSenderFutures .emplace_back (llmRequest, std::move (future));
264262}
265263
266264void CacheTransceiver::respondAndSendLayerWise (
@@ -275,16 +273,16 @@ void CacheTransceiver::respondAndSendLayerWise(
275273
276274 llmRequest->setState (LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS );
277275 setContextState (llmRequest.get ());
278- auto future = mDataResponder -> respondAndSendAsync (*llmRequest);
279- mResponderFutures .emplace_back (llmRequest.get (), std::move (future));
276+ auto future = mCacheSender -> sendAsync (*llmRequest);
277+ mSenderFutures .emplace_back (llmRequest.get (), std::move (future));
280278 }
281279}
282280
283281void CacheTransceiver::requestAndReceiveSync (LlmRequest* llmRequest)
284282{
285283 TLLM_CHECK (llmRequest && llmRequest->isGenerationOnlyRequest ());
286284 {
287- auto future = mDataRequester -> requestAndReceiveAsync (*llmRequest);
285+ auto future = mCacheReceiver -> receiveAsync (*llmRequest);
288286 future.get ();
289287 }
290288 llmRequest->setState (LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE );
@@ -302,7 +300,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
302300 return ;
303301 }
304302
305- auto future = mDataRequester -> requestAndReceiveAsync (*llmRequest);
303+ auto future = mCacheReceiver -> receiveAsync (*llmRequest);
306304 mRequesterFutures .emplace_back (llmRequest, std::move (future));
307305 llmRequest->setState (LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS );
308306}
@@ -382,7 +380,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
382380 bool blockAll = !atLeastRequestNum.has_value ();
383381 auto syncComm = mCacheState ->getParallelConfig ().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm ;
384382 std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
385- for (auto && [request, future] : mResponderFutures )
383+ for (auto && [request, future] : mSenderFutures )
386384 {
387385 if (future.wait_for (std::chrono::milliseconds (0 )) == std::future_status::ready)
388386 {
@@ -422,16 +420,15 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
422420
423421 // Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
424422 // This will preserve the order of insertion for KVCache transfer requests.
425- for (auto it = mResponderFutures .begin ();
426- atLeastRequestNum.value_or (0 ) > static_cast <int >(toCompleteIdSet.size ()) && it != mResponderFutures .end ();
427- ++it)
423+ for (auto it = mSenderFutures .begin ();
424+ atLeastRequestNum.value_or (0 ) > static_cast <int >(toCompleteIdSet.size ()) && it != mSenderFutures .end (); ++it)
428425 {
429426 auto & [request, future] = *it;
430427 toCompleteIdSet.insert (request->mRequestId );
431428 }
432429
433430 // Complete all the requests in toCompleteIdSet
434- for (auto it = mResponderFutures .begin (); it != mResponderFutures .end ();)
431+ for (auto it = mSenderFutures .begin (); it != mSenderFutures .end ();)
435432 {
436433 auto & [request, future] = *it;
437434 if (blockAll || (toCompleteIdSet.find (request->mRequestId ) != toCompleteIdSet.end ()))
@@ -447,7 +444,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
447444 " Error occurred during context transfer for request %ld: %s" , request->mRequestId , e.what ());
448445 request->setState (LlmRequestState::kDISAGG_TRANS_ERROR );
449446 }
450- it = mResponderFutures .erase (it);
447+ it = mSenderFutures .erase (it);
451448 }
452449 else
453450 {
0 commit comments