Skip to content

Commit ee44fa0

Browse files
authored
chore: rename IOFormatter to BaseCacheFormatter (NVIDIA#5068)
Signed-off-by: Zheng Duan <[email protected]>
1 parent ad99a08 commit ee44fa0

File tree

8 files changed

+87
-114
lines changed

8 files changed

+87
-114
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
#include "cacheFormatter.h"
19+
#include "mlaCacheFormatter.h"
1920

2021
#include "tensorrt_llm/batch_manager/contextProgress.h"
2122
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
@@ -751,4 +752,15 @@ void CacheFormatter::formatInput(LlmRequest const& llmRequest,
751752
}
752753
return true;
753754
}
755+
756+
std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
757+
BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA)
758+
{
759+
if (isMLA)
760+
{
761+
return std::make_unique<MLACacheFormatter>(cacheManager, cacheTransBufferManager);
762+
}
763+
return std::make_unique<CacheFormatter>(cacheManager, cacheTransBufferManager);
764+
}
765+
754766
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717

1818
#pragma once
1919

20+
#include "cacheTransBuffer.h"
2021
#include "dataTransceiver.h"
21-
#include "tensorrt_llm/batch_manager/cacheTransBuffer.h"
2222
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
2323
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
2424
#include "tensorrt_llm/common/envUtils.h"
2525
#include "tensorrt_llm/common/logger.h"
26+
#include "tensorrt_llm/executor/cacheCommunicator.h"
2627
#include "tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
2728
#include "tensorrt_llm/executor/dataTransceiverState.h"
2829
#include "tensorrt_llm/runtime/bufferManager.h"
@@ -60,13 +61,54 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
6061

6162
BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest);
6263

63-
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
64-
// parallel topology is completely identical, making it the preferred method.
65-
class CacheFormatter final : public IOFormatter
64+
// Used to support the cache transmission with different layouts and different protocols.
65+
class BaseCacheFormatter
6666
{
6767
public:
68+
using SizeType32 = tensorrt_llm::runtime::SizeType32;
6869
using CacheState = executor::kv_cache::CacheState;
6970

71+
virtual void formatOutput(LlmRequest const& llmRequest,
72+
std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
73+
SizeType32 selfIdx, CacheState const& destConfig, runtime::BufferManager const& bufferManager)
74+
= 0;
75+
76+
virtual void formatInput(LlmRequest const& llmRequest,
77+
std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
78+
SizeType32 selfIdx, CacheState const& destConfig, runtime::BufferManager const& bufferManager)
79+
= 0;
80+
81+
/// @brief Determine whether the sender is applicable to the source and target.
82+
/// @param selfConfig Source data arrangement.
83+
/// @param destConfig Target data arrangement.
84+
/// @return Whether the sender is applicable to the source and target.
85+
[[nodiscard]] virtual bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const = 0;
86+
87+
/// @brief Obtain the indies of the counterparts that need to be actually communicated with.
88+
/// @param selfConfig Source data arrangement.
89+
/// @param selfIdx The sequential index of the current executor process within the entire parallel group.
90+
/// @param destConfig Target data arrangement.
91+
/// @return The indies of the counterparts.
92+
[[nodiscard]] virtual std::vector<SizeType32> getCounterparts(
93+
CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const
94+
= 0;
95+
96+
[[nodiscard]] virtual BaseKVCacheManager* getCacheManager() const noexcept = 0;
97+
98+
[[nodiscard]] virtual std::vector<executor::kv_cache::Connection const*> pickRecvConnections(
99+
std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
100+
SizeType32 selfIdx, CacheState const& destConfig) const
101+
= 0;
102+
103+
/// @brief Destructor.
104+
virtual ~BaseCacheFormatter() = default;
105+
};
106+
107+
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
108+
// parallel topology is completely identical, making it the preferred method.
109+
class CacheFormatter final : public BaseCacheFormatter
110+
{
111+
public:
70112
CacheFormatter(BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager)
71113
: mCacheManager{cacheManager}
72114
, mCacheTransBufferManager{cacheTransBufferManager}
@@ -91,7 +133,7 @@ class CacheFormatter final : public IOFormatter
91133
return executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx).mIRanks;
92134
}
93135

94-
BaseKVCacheManager* getCacheManager() const noexcept
136+
[[nodiscard]] BaseKVCacheManager* getCacheManager() const noexcept override
95137
{
96138
return mCacheManager;
97139
}
@@ -102,11 +144,12 @@ class CacheFormatter final : public IOFormatter
102144
SizeType32 selfIdx, CacheState const& destConfig) const override;
103145

104146
private:
105-
BaseKVCacheManager* mCacheManager{};
106-
147+
BaseKVCacheManager* mCacheManager;
107148
CacheTransBufferManager* mCacheTransBufferManager;
108-
109149
KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()};
110150
};
111151

152+
std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
153+
BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA = false);
154+
112155
} // namespace tensorrt_llm::batch_manager::kv_cache_manager

cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,8 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa
179179
}
180180

181181
using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter;
182-
auto makeFormatter = [cacheManager, isMLA, this]() -> std::unique_ptr<IOFormatter>
183-
{
184-
return isMLA ? std::unique_ptr<IOFormatter>(
185-
std::make_unique<MLACacheFormatter>(cacheManager, this->mCacheTransBufferManager.get()))
186-
: std::unique_ptr<IOFormatter>(
187-
std::make_unique<CacheFormatter>(cacheManager, this->mCacheTransBufferManager.get()));
188-
};
182+
auto makeFormatter = [cacheManager, isMLA, this]()
183+
{ return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); };
189184

190185
mDataResponder = std::make_unique<DataResponder>(
191186
std::make_unique<DataSenderImpl>(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter()));

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,47 +34,6 @@
3434
namespace tensorrt_llm::batch_manager
3535
{
3636

37-
// Used to support the data transmission with different layouts and different protocols.
38-
class IOFormatter
39-
{
40-
public:
41-
using SizeType32 = tensorrt_llm::runtime::SizeType32;
42-
using CacheState = executor::kv_cache::CacheState;
43-
44-
virtual void formatOutput(LlmRequest const& llmRequest,
45-
std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
46-
SizeType32 selfIdx, CacheState const& destConfig, runtime::BufferManager const& bufferManager)
47-
= 0;
48-
49-
virtual void formatInput(LlmRequest const& llmRequest,
50-
std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
51-
SizeType32 selfIdx, CacheState const& destConfig, runtime::BufferManager const& bufferManager)
52-
= 0;
53-
54-
/// @brief Determine whether the sender is applicable to the source and target.
55-
/// @param selfConfig Source data arrangement.
56-
/// @param destConfig Target data arrangement.
57-
/// @return Whether the sender is applicable to the source and target.
58-
[[nodiscard]] virtual bool inquireSupport(CacheState const& selfConfig, CacheState const& destConfig) const = 0;
59-
60-
/// @brief Obtain the indies of the counterparts that need to be actually communicated with.
61-
/// @param selfConfig Source data arrangement.
62-
/// @param selfIdx The sequential index of the current executor process within the entire parallel group.
63-
/// @param destConfig Target data arrangement.
64-
/// @return The indies of the counterparts.
65-
[[nodiscard]] virtual std::vector<SizeType32> getCounterparts(
66-
CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const
67-
= 0;
68-
69-
[[nodiscard]] virtual std::vector<executor::kv_cache::Connection const*> pickRecvConnections(
70-
std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
71-
SizeType32 selfIdx, CacheState const& destConfig) const
72-
= 0;
73-
74-
/// @brief Destructor.
75-
virtual ~IOFormatter() = default;
76-
};
77-
7837
// Used to store the information that needs to be sent to the context executor to ensure the generation
7938
// executor smoothly receives the data.
8039
class RequestInfo

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@
1515
* limitations under the License.
1616
*/
1717

18-
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
19-
#include "tensorrt_llm/batch_manager/cacheFormatter.h"
20-
#include "tensorrt_llm/batch_manager/dataTransceiverImpl.h"
21-
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
22-
#include "tensorrt_llm/batch_manager/mlaCacheFormatter.h"
18+
#include "dataTransceiverImpl.h"
19+
2320
#include "tensorrt_llm/common/envUtils.h"
2421
#include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h"
2522
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
@@ -28,7 +25,7 @@ namespace tensorrt_llm::batch_manager
2825
{
2926

3027
DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager,
31-
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<IOFormatter> formatter)
28+
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
3229
: mManager{manager}
3330
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
3431
, mFormatter(std::move(formatter))
@@ -133,7 +130,7 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId)
133130
}
134131

135132
DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manager,
136-
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<IOFormatter> formatter)
133+
executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter)
137134
: mManager{manager}
138135
, mSelfState{std::move(selfCacheState), executor::kv_cache::CommState{manager->getCommState()}}
139136
, mFormatter(std::move(formatter))
@@ -156,23 +153,10 @@ void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
156153

157154
if (!common::getEnvDisableSelectiveCacheTransfer())
158155
{
159-
// TODO: remove IOFormatter and make CacheFormatter new base class
160-
auto* cacheFormatter = dynamic_cast<kv_cache_manager::CacheFormatter const*>(mFormatter.get());
161-
auto* mlaCacheFormatter = dynamic_cast<kv_cache_manager::MLACacheFormatter const*>(mFormatter.get());
162-
if (cacheFormatter != nullptr)
163-
{
164-
auto* cacheManager = cacheFormatter->getCacheManager();
165-
auto blockRange
166-
= kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
167-
requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
168-
}
169-
else if (mlaCacheFormatter != nullptr)
170-
{
171-
auto* cacheManager = mlaCacheFormatter->getCacheManager();
172-
auto blockRange
173-
= kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
174-
requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
175-
}
156+
auto* cacheManager = mFormatter->getCacheManager();
157+
auto blockRange
158+
= kv_cache_manager::BlockRange::fromNewlyAllocatedBlockIds(*cacheManager, llmRequest.mRequestId);
159+
requestInfo = RequestInfo(requestId, blockRange.getBlockHashes(), mSelfState);
176160
}
177161

178162
auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717

1818
#pragma once
1919

20-
#include "tensorrt_llm/batch_manager/cacheTransBuffer.h"
21-
#include "tensorrt_llm/batch_manager/dataTransceiver.h"
22-
#include "tensorrt_llm/common/envUtils.h"
23-
#include "tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
20+
#include "cacheFormatter.h"
21+
#include "dataTransceiver.h"
2422

2523
namespace tensorrt_llm::batch_manager
2624
{
@@ -37,6 +35,8 @@ struct TransceiverTag
3735
static constexpr int32_t kINFO_TAG{32};
3836
};
3937

38+
using BaseCacheFormatter = kv_cache_manager::BaseCacheFormatter;
39+
4040
class DataSenderImpl : public DataSender, public TransceiverTag
4141
{
4242
public:
@@ -45,7 +45,7 @@ class DataSenderImpl : public DataSender, public TransceiverTag
4545
= std::vector<std::pair<executor::kv_cache::Connection const*, executor::DataTransceiverState>>;
4646

4747
DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
48-
SizeType32 selfIndex, std::unique_ptr<IOFormatter> formatter);
48+
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
4949

5050
[[nodiscard]] RequestInfo recvRequestInfo() override;
5151

@@ -63,7 +63,7 @@ class DataSenderImpl : public DataSender, public TransceiverTag
6363
executor::kv_cache::ConnectionManager* mManager;
6464
std::map<LlmRequest::RequestIdType, RequestMapInfo> mRequestToComms;
6565
executor::DataTransceiverState mSelfState;
66-
std::unique_ptr<IOFormatter> mFormatter;
66+
std::unique_ptr<BaseCacheFormatter> mFormatter;
6767
std::mutex mMtxForMap;
6868
runtime::BufferManager mBufferManager;
6969
};
@@ -74,7 +74,7 @@ class DataReceiverImpl : public DataReceiver, public TransceiverTag
7474
using SizeType32 = tensorrt_llm::runtime::SizeType32;
7575

7676
DataReceiverImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState,
77-
SizeType32 selfIndex, std::unique_ptr<IOFormatter> formatter);
77+
SizeType32 selfIndex, std::unique_ptr<BaseCacheFormatter> formatter);
7878

7979
void sendRequestInfo(LlmRequest const& llmRequest) override;
8080

@@ -99,7 +99,7 @@ class DataReceiverImpl : public DataReceiver, public TransceiverTag
9999

100100
executor::kv_cache::ConnectionManager* mManager;
101101
executor::DataTransceiverState mSelfState;
102-
std::unique_ptr<IOFormatter> mFormatter;
102+
std::unique_ptr<BaseCacheFormatter> mFormatter;
103103
std::unordered_map<std::string, std::unique_ptr<ReceiveCacheResource>> mProcessToResources;
104104
std::mutex mProcessIoResouceMutex;
105105
};

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,31 +17,16 @@
1717

1818
#pragma once
1919

20-
#include "dataTransceiver.h"
21-
#include "tensorrt_llm/batch_manager/cacheTransBuffer.h"
22-
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
23-
#include "tensorrt_llm/batch_manager/kvCacheUtils.h"
24-
#include "tensorrt_llm/common/logger.h"
25-
#include "tensorrt_llm/executor/cache_transmission/cacheConcatenate.h"
26-
#include "tensorrt_llm/executor/dataTransceiverState.h"
27-
#include "tensorrt_llm/runtime/bufferManager.h"
28-
#include "tensorrt_llm/runtime/iTensor.h"
29-
#include <NvInferRuntimeBase.h>
30-
#include <condition_variable>
31-
#include <cstddef>
32-
#include <cstdint>
33-
#include <iterator>
20+
#include "cacheFormatter.h"
3421

3522
namespace tensorrt_llm::batch_manager::kv_cache_manager
3623
{
3724

3825
// Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the
3926
// parallel topology is completely identical, making it the preferred method.
40-
class MLACacheFormatter final : public IOFormatter
27+
class MLACacheFormatter final : public BaseCacheFormatter
4128
{
4229
public:
43-
using CacheState = executor::kv_cache::CacheState;
44-
4530
MLACacheFormatter(BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager)
4631
: mCacheManager{cacheManager}
4732
, mCacheTransBufferManager{cacheTransBufferManager}
@@ -66,7 +51,7 @@ class MLACacheFormatter final : public IOFormatter
6651
return executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx).mIRanks;
6752
}
6853

69-
[[nodiscard]] BaseKVCacheManager* getCacheManager() const
54+
[[nodiscard]] BaseKVCacheManager* getCacheManager() const noexcept override
7055
{
7156
return mCacheManager;
7257
}
@@ -77,7 +62,7 @@ class MLACacheFormatter final : public IOFormatter
7762
SizeType32 selfIdx, CacheState const& destConfig) const override;
7863

7964
private:
80-
BaseKVCacheManager* mCacheManager{};
65+
BaseKVCacheManager* mCacheManager;
8166
CacheTransBufferManager* mCacheTransBufferManager;
8267
};
8368

cpp/tests/batch_manager/cacheTransceiverTest.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -734,13 +734,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
734734
mConnectionManager = std::make_unique<texec::kv_cache::MpiConnectionManager>(mComm);
735735
}
736736

737-
auto makeFormatter = [this]()
738-
{
739-
return mIsMLA ? std::unique_ptr<IOFormatter>(
740-
std::make_unique<MLACacheFormatter>(mManager.get(), mCacheTransBufferManager.get()))
741-
: std::unique_ptr<IOFormatter>(
742-
std::make_unique<CacheFormatter>(mManager.get(), mCacheTransBufferManager.get()));
743-
};
737+
auto makeFormatter
738+
= [this]() { return createCacheFormatter(mManager.get(), mCacheTransBufferManager.get(), mIsMLA); };
744739

745740
if (mIsContext)
746741
{

0 commit comments

Comments
 (0)