Skip to content

Commit 6ce0624

Browse files
authored
[TRTLLM-8044][refactor] Rename data -> cache for cacheTransceiver (#7659)
1 parent 8226ef2 commit 6ce0624

File tree

18 files changed

+614
-809
lines changed

18 files changed

+614
-809
lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,14 @@ namespace tensorrt_llm::batch_manager
3434

3535
class ContextProgress;
3636
class BaseCacheTransceiver;
37-
class DataResponder;
38-
class DataRequester;
37+
38+
namespace kv_cache_manager
39+
{
40+
class BaseKVCacheManager;
41+
} // namespace kv_cache_manager
42+
43+
class CacheSender;
44+
class CacheReceiver;
3945

4046
class CacheTransceiverFactory
4147
{
@@ -110,9 +116,9 @@ class CacheTransceiver : public BaseCacheTransceiver
110116

111117
void setContextState(LlmRequest* llmRequest);
112118

113-
std::unique_ptr<DataResponder> mDataResponder;
114-
std::unique_ptr<DataRequester> mDataRequester;
115-
std::vector<std::pair<LlmRequest*, std::future<void>>> mResponderFutures;
119+
std::unique_ptr<CacheSender> mCacheSender;
120+
std::unique_ptr<CacheReceiver> mCacheReceiver;
121+
std::vector<std::pair<LlmRequest*, std::future<void>>> mSenderFutures;
116122
std::vector<std::pair<LlmRequest*, std::future<void>>> mRequesterFutures;
117123
mpi::MpiComm const *mMpiGroupComm{nullptr}, *mMpiWorldComm{nullptr};
118124
std::shared_ptr<mpi::MpiComm> mMpiGroupTensorParaComm, mMpiGroupPipeParaComm, mMpiGroupDataComm,

cpp/tensorrt_llm/batch_manager/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@ set(SRCS
2424
createNewDecoderRequests.cpp
2525
contextProgress.cpp
2626
dataTransceiver.cpp
27-
dataTransceiverImpl.cpp
2827
decoderBuffers.cpp
2928
encoderBuffers.cpp
3029
guidedDecoder.cpp

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
#include "mlaCacheFormatter.h"
2020

2121
#include "tensorrt_llm/batch_manager/contextProgress.h"
22+
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
2223
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
2324
#include "tensorrt_llm/common/assert.h"
2425
#include "tensorrt_llm/common/cudaUtils.h"
@@ -154,7 +155,7 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
154155
return ret;
155156
}
156157

157-
void CacheFormatter::format(TransferSession& session)
158+
void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& session)
158159
{
159160
NVTX3_SCOPED_RANGE(CacheFormatter_format);
160161
auto const& llmRequest = session.getLlmRequest();
@@ -468,7 +469,7 @@ void CacheFormatter::format(TransferSession& session)
468469
mpi::MpiComm::world().getRank(), "End the sending of KV cache for the request ID:%ld ", llmRequest.mRequestId);
469470
}
470471

471-
void CacheFormatter::unformat(TransferSession& session)
472+
void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& session)
472473
{
473474
NVTX3_SCOPED_RANGE(CacheFormatter_unformat);
474475
auto const& llmRequest = session.getLlmRequest();

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -18,22 +18,38 @@
1818
#pragma once
1919

2020
#include "cacheTransBuffer.h"
21-
#include "dataTransceiver.h"
2221
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
2322
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
23+
#include "tensorrt_llm/common/assert.h"
2424
#include "tensorrt_llm/common/envUtils.h"
2525
#include "tensorrt_llm/common/logger.h"
26+
#include "tensorrt_llm/executor/cacheCommunicator.h"
2627
#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h"
2728
#include "tensorrt_llm/executor/dataTransceiverState.h"
2829
#include "tensorrt_llm/runtime/bufferManager.h"
2930
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
3031
#include <NvInferRuntimeBase.h>
3132
#include <cstddef>
3233
#include <cstdint>
34+
#include <fstream>
35+
#include <vector>
36+
37+
// Forward declare TransferSession in the correct global namespace scope
38+
namespace tensorrt_llm::batch_manager
39+
{
40+
class TransferSession;
41+
}
3342

3443
namespace tensorrt_llm::batch_manager::kv_cache_manager
3544
{
3645

46+
using DataContext = tensorrt_llm::executor::kv_cache::DataContext;
47+
using Connection = tensorrt_llm::executor::kv_cache::Connection;
48+
using SizeType32 = tensorrt_llm::runtime::SizeType32;
49+
using BaseKVCacheManager = kv_cache_manager::BaseKVCacheManager;
50+
using CacheTransBufferManager = kv_cache_manager::CacheTransBufferManager;
51+
using BlockRange = kv_cache_manager::BlockRange;
52+
3753
BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
3854

3955
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
@@ -42,16 +58,15 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
4258
class BaseCacheFormatter
4359
{
4460
public:
45-
using SizeType32 = tensorrt_llm::runtime::SizeType32;
4661
using CacheState = executor::kv_cache::CacheState;
4762

4863
/// @brief Format the cache data into bytes for sending.
4964
/// @param session The transfer session.
50-
virtual void format(TransferSession& session) = 0;
65+
virtual void format(tensorrt_llm::batch_manager::TransferSession& session) = 0;
5166

5267
/// @brief Unformat the cache data from received bytes.
5368
/// @param session The transfer session.
54-
virtual void unformat(TransferSession& session) = 0;
69+
virtual void unformat(tensorrt_llm::batch_manager::TransferSession& session) = 0;
5570

5671
/// @brief Determine whether the sender is applicable to the source and target.
5772
/// @param selfConfig Source data arrangement.
@@ -91,9 +106,9 @@ class CacheFormatter final : public BaseCacheFormatter
91106
TLLM_CHECK(mCacheTransBufferManager);
92107
}
93108

94-
void format(TransferSession& session) override;
109+
void format(tensorrt_llm::batch_manager::TransferSession& session) override;
95110

96-
void unformat(TransferSession& session) override;
111+
void unformat(tensorrt_llm::batch_manager::TransferSession& session) override;
97112

98113
[[nodiscard]] bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const override;
99114

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 17 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,9 @@
3737
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
3838
#include "tensorrt_llm/batch_manager/cacheTransceiver.h"
3939
#include "tensorrt_llm/batch_manager/contextProgress.h"
40-
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
4140
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
41+
#include "tensorrt_llm/batch_manager/kvCacheType.h"
42+
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
4243
#include "tensorrt_llm/batch_manager/llmRequest.h"
4344
#include "tensorrt_llm/batch_manager/mlaCacheFormatter.h"
4445
#include "tensorrt_llm/common/envUtils.h"
@@ -116,7 +117,6 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
116117
: mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session()))
117118
, mCacheTransceiverConfig{cacheTransceiverConfig}
118119
{
119-
using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter;
120120
if (worldConfig.isPipelineParallel())
121121
{
122122
mMpiGroupPipeParaComm = std::make_shared<tensorrt_llm::mpi::MpiComm>(
@@ -200,14 +200,12 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
200200
TLLM_THROW("Unsupported cache transceiver backend type ");
201201
}
202202

203-
using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
204203
auto makeFormatter = [cacheManager, isMLA, this]()
205204
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
206205

207-
mDataResponder = std::make_unique<DataResponder>(
208-
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
209-
mDataRequester = std::make_unique<DataRequester>(
210-
std::make_unique<DataReceiverImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));
206+
mCacheSender = std::make_unique<CacheSender>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
207+
mCacheReceiver
208+
= std::make_unique<CacheReceiver>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter());
211209

212210
initializeCommState();
213211
}
@@ -223,7 +221,7 @@ CacheTransceiver::~CacheTransceiver()
223221

224222
void CacheTransceiver::initializeCommState()
225223
{
226-
mCommState = std::addressof(mDataResponder->getCommState());
224+
mCommState = std::addressof(mCacheSender->getCommState());
227225
}
228226

229227
void CacheTransceiver::setContextState(LlmRequest* llmRequest)
@@ -259,8 +257,8 @@ void CacheTransceiver::respondAndSendAsync(LlmRequest* llmRequest)
259257
return;
260258
}
261259
setContextState(llmRequest);
262-
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
263-
mResponderFutures.emplace_back(llmRequest, std::move(future));
260+
auto future = mCacheSender->sendAsync(*llmRequest);
261+
mSenderFutures.emplace_back(llmRequest, std::move(future));
264262
}
265263

266264
void CacheTransceiver::respondAndSendLayerWise(
@@ -275,16 +273,16 @@ void CacheTransceiver::respondAndSendLayerWise(
275273

276274
llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
277275
setContextState(llmRequest.get());
278-
auto future = mDataResponder->respondAndSendAsync(*llmRequest);
279-
mResponderFutures.emplace_back(llmRequest.get(), std::move(future));
276+
auto future = mCacheSender->sendAsync(*llmRequest);
277+
mSenderFutures.emplace_back(llmRequest.get(), std::move(future));
280278
}
281279
}
282280

283281
void CacheTransceiver::requestAndReceiveSync(LlmRequest* llmRequest)
284282
{
285283
TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
286284
{
287-
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
285+
auto future = mCacheReceiver->receiveAsync(*llmRequest);
288286
future.get();
289287
}
290288
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
@@ -302,7 +300,7 @@ void CacheTransceiver::requestAndReceiveAsync(LlmRequest* llmRequest)
302300
return;
303301
}
304302

305-
auto future = mDataRequester->requestAndReceiveAsync(*llmRequest);
303+
auto future = mCacheReceiver->receiveAsync(*llmRequest);
306304
mRequesterFutures.emplace_back(llmRequest, std::move(future));
307305
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
308306
}
@@ -382,7 +380,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
382380
bool blockAll = !atLeastRequestNum.has_value();
383381
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mMpiGroupTPInDPComm : mMpiGroupTensorParaComm;
384382
std::vector<LlmRequest::RequestIdType> contextCompleteRequestIds;
385-
for (auto&& [request, future] : mResponderFutures)
383+
for (auto&& [request, future] : mSenderFutures)
386384
{
387385
if (future.wait_for(std::chrono::milliseconds(0)) == std::future_status::ready)
388386
{
@@ -422,16 +420,15 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
422420

423421
// Make sure there are at least atLeastRequestNum requests in toCompleteIdSet.
424422
// This will preserve the order of insertion for KVCache transfer requests.
425-
for (auto it = mResponderFutures.begin();
426-
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mResponderFutures.end();
427-
++it)
423+
for (auto it = mSenderFutures.begin();
424+
atLeastRequestNum.value_or(0) > static_cast<int>(toCompleteIdSet.size()) && it != mSenderFutures.end(); ++it)
428425
{
429426
auto& [request, future] = *it;
430427
toCompleteIdSet.insert(request->mRequestId);
431428
}
432429

433430
// Complete all the requests in toCompleteIdSet
434-
for (auto it = mResponderFutures.begin(); it != mResponderFutures.end();)
431+
for (auto it = mSenderFutures.begin(); it != mSenderFutures.end();)
435432
{
436433
auto& [request, future] = *it;
437434
if (blockAll || (toCompleteIdSet.find(request->mRequestId) != toCompleteIdSet.end()))
@@ -447,7 +444,7 @@ void CacheTransceiver::checkContextTransferStatus(std::optional<int> const& atLe
447444
"Error occurred during context transfer for request %ld: %s", request->mRequestId, e.what());
448445
request->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
449446
}
450-
it = mResponderFutures.erase(it);
447+
it = mSenderFutures.erase(it);
451448
}
452449
else
453450
{

0 commit comments

Comments
 (0)