Skip to content

Commit b3045c4

Browse files
authored
refactor: remove TrtGptModelOptionalParams (#5165)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
1 parent 4f0f17a commit b3045c4

File tree

16 files changed

+334
-518
lines changed

16 files changed

+334
-518
lines changed

cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h

Lines changed: 0 additions & 140 deletions
This file was deleted.

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
2525
#include "tensorrt_llm/runtime/decoderState.h"
2626
#include "tensorrt_llm/runtime/decodingInput.h"
2727
#include "tensorrt_llm/runtime/decodingOutput.h"
28-
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
2928
#include "tensorrt_llm/runtime/runtimeKernels.h"
3029
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
3130
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
#include "tensorrt_llm/common/assert.h"
2323
#include "tensorrt_llm/common/logger.h"
2424
#include "tensorrt_llm/common/nvtxUtils.h"
25+
#include "tensorrt_llm/executor/executor.h"
2526
#include "tensorrt_llm/runtime/iTensor.h"
2627
#include "tensorrt_llm/runtime/tllmLogger.h"
2728
#include "tensorrt_llm/runtime/tllmRuntime.h"
@@ -39,14 +40,14 @@ namespace tensorrt_llm::batch_manager
3940

4041
TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig,
4142
runtime::RawEngine const& rawEngine, std::shared_ptr<nvinfer1::ILogger> logger,
42-
TrtGptModelOptionalParams const& optionalParams)
43-
: TrtGptModel(modelConfig, worldConfig, optionalParams)
43+
executor::ExecutorConfig const& executorConfig)
44+
: TrtGptModel(modelConfig, worldConfig, executorConfig)
4445
, mModelConfig{modelConfig}
4546
, mWorldConfig{worldConfig}
4647
, mDevice{runtime::utils::initDevice(worldConfig)}
4748
, mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
4849
, mRuntime{std::make_shared<TllmRuntime>(
49-
rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage, optionalParams.gpuWeightsPercent)}
50+
rawEngine, mLogger.get(), executorConfig.getUseGpuDirectStorage(), executorConfig.getGpuWeightsPercent())}
5051
, mNumMicroBatches{1}
5152
, mNumBuffers{mNumMicroBatches}
5253
, mCopyBufferManager{std::make_shared<CudaStream>()}
@@ -75,8 +76,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
7576
// handling of maximizing utilization or pause/evict
7677
// TODO: finer control on encoder requests scheduling
7778
mCapacityScheduler = std::make_unique<tensorrt_llm::batch_manager::CapacityScheduler>(
78-
getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, false,
79-
LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
79+
getMaxBatchSize() * mNumMicroBatches, executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), false,
80+
false, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
8081

8182
mMicroBatchScheduler = std::make_unique<tensorrt_llm::batch_manager::MicroBatchScheduler>(
8283
std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);

cpp/tensorrt_llm/batch_manager/trtEncoderModel.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717

1818
#pragma once
1919

20-
#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
2120
#include "tensorrt_llm/runtime/rawEngine.h"
2221
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
2322
#include "trtGptModel.h"
@@ -47,7 +46,7 @@ class TrtEncoderModel : public TrtGptModel
4746

4847
TrtEncoderModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
4948
runtime::RawEngine const& rawEngine, std::shared_ptr<nvinfer1::ILogger> logger,
50-
TrtGptModelOptionalParams const& optionalParams);
49+
executor::ExecutorConfig const& executorConfig);
5150

5251
~TrtEncoderModel() override;
5352

cpp/tensorrt_llm/batch_manager/trtGptModel.h

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
#pragma once
1919

2020
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
21-
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
2221
#include "tensorrt_llm/common/assert.h"
2322
#include "tensorrt_llm/common/stlUtils.h"
2423
#include "tensorrt_llm/executor/executor.h"
@@ -52,23 +51,23 @@ class TrtGptModel : public executor::Model
5251
using SizeType32 = tensorrt_llm::runtime::SizeType32;
5352

5453
TrtGptModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
55-
TrtGptModelOptionalParams const& optionalParams)
56-
: mMaxBatchSize{optionalParams.maxBatchSize.value_or(modelConfig.getMaxBatchSize())}
57-
, mMaxBeamWidth{optionalParams.maxBeamWidth.value_or(modelConfig.getMaxBeamWidth())}
54+
executor::ExecutorConfig const& executorConfig)
55+
: mMaxBatchSize{executorConfig.getMaxBatchSize().value_or(modelConfig.getMaxBatchSize())}
56+
, mMaxBeamWidth{executorConfig.getMaxBeamWidth()}
5857
, mMaxSequenceLen{modelConfig.getMaxSequenceLen()}
5958
, mMaxDraftLen{modelConfig.getMaxDecodingDraftTokens()}
6059
, mVocabSizePadded{modelConfig.getVocabSizePadded(worldConfig.getSize())}
61-
, mNormalizeLogProbs{optionalParams.normalizeLogProbs}
62-
, mEnableTrtOverlap{optionalParams.enableTrtOverlap}
63-
, mCudaGraphMode{optionalParams.extendedRuntimePerfKnobConfig.getCudaGraphMode()}
60+
, mNormalizeLogProbs{executorConfig.getNormalizeLogProbs()}
61+
, mEnableTrtOverlap{executorConfig.getEnableTrtOverlap()}
62+
, mCudaGraphMode{executorConfig.getExtendedRuntimePerfKnobConfig().getCudaGraphMode()}
6463
{
6564
TLLM_CHECK_WITH_INFO(mMaxBeamWidth <= modelConfig.getMaxBeamWidth(),
6665
"Runtime configured max beam width (%d) must not exceed engine max beam width (%d)", mMaxBeamWidth,
6766
modelConfig.getMaxBeamWidth());
6867
TLLM_CHECK_WITH_INFO(mMaxBatchSize <= modelConfig.getMaxBatchSize(),
6968
"Runtime configured max batch size (%d) must not exceed engine max batch size (%d)", mMaxBatchSize,
7069
modelConfig.getMaxBatchSize());
71-
if (optionalParams.enableTrtOverlap)
70+
if (executorConfig.getEnableTrtOverlap())
7271
{
7372
if (mMaxBeamWidth > 1)
7473
{
@@ -85,10 +84,11 @@ class TrtGptModel : public executor::Model
8584
}
8685

8786
mMaxAttentionWindow = 0;
88-
if (optionalParams.kvCacheConfig.maxAttentionWindowVec.has_value())
87+
if (executorConfig.getKvCacheConfig().getMaxAttentionWindowVec().has_value())
8988
{
9089
bool warning = false;
91-
for (int maxAttenWin : optionalParams.kvCacheConfig.maxAttentionWindowVec.value())
90+
auto const& maxAttentionWindowVec = executorConfig.getKvCacheConfig().getMaxAttentionWindowVec();
91+
for (int maxAttenWin : maxAttentionWindowVec.value())
9292
{
9393
mMaxAttentionWindowVec.push_back(std::min(maxAttenWin, mMaxSequenceLen));
9494
mMaxAttentionWindow = std::max(mMaxAttentionWindow, mMaxAttentionWindowVec.back());
@@ -112,8 +112,8 @@ class TrtGptModel : public executor::Model
112112
mMaxAttentionWindow = mMaxSequenceLen;
113113
}
114114

115-
mSinkTokenLen = optionalParams.kvCacheConfig.sinkTokenLength.has_value()
116-
? optionalParams.kvCacheConfig.sinkTokenLength.value()
115+
mSinkTokenLen = executorConfig.getKvCacheConfig().getSinkTokenLength().has_value()
116+
? executorConfig.getKvCacheConfig().getSinkTokenLength().value()
117117
: 0;
118118

119119
mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism();
@@ -136,26 +136,26 @@ class TrtGptModel : public executor::Model
136136
TLLM_LOG_INFO("TRTGptModel normalizeLogProbs: %d", mNormalizeLogProbs);
137137

138138
mMaxNumTokens = modelConfig.getMaxNumTokens();
139-
if (optionalParams.maxNumTokens && mMaxNumTokens)
139+
if (executorConfig.getMaxNumTokens().has_value() && mMaxNumTokens)
140140
{
141-
if (optionalParams.maxNumTokens.value() > mMaxNumTokens.value())
141+
if (executorConfig.getMaxNumTokens().value() > mMaxNumTokens.value())
142142
{
143143
TLLM_LOG_WARNING(
144144
"Runtime configured max num tokens (%d) is larger than model max num tokens (%d) and will be "
145145
"ignored.",
146-
optionalParams.maxNumTokens.value(), mMaxNumTokens.value());
146+
executorConfig.getMaxNumTokens().value(), mMaxNumTokens.value());
147147
}
148148
else
149149
{
150-
mMaxNumTokens = optionalParams.maxNumTokens;
150+
mMaxNumTokens = executorConfig.getMaxNumTokens();
151151
}
152152
}
153153
if (mMaxNumTokens)
154154
{
155155
TLLM_LOG_INFO("TRTGptModel maxNumTokens: %d", mMaxNumTokens.value());
156156
}
157157

158-
if (optionalParams.enableChunkedContext)
158+
if (executorConfig.getEnableChunkedContext())
159159
{
160160
mMaxInputLen = mMaxSequenceLen - 1;
161161
TLLM_LOG_INFO(
@@ -199,9 +199,9 @@ class TrtGptModel : public executor::Model
199199
using tensorrt_llm::common::stl_utils::toString;
200200

201201
TLLM_LOG_INFO("Capacity Scheduler Policy: %s",
202-
toString(optionalParams.schedulerConfig.getCapacitySchedulerPolicy()).c_str());
202+
toString(executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy()).c_str());
203203
TLLM_LOG_INFO("Context Chunking Scheduler Policy: %s",
204-
toString(optionalParams.schedulerConfig.getContextChunkingPolicy()).c_str());
204+
toString(executorConfig.getSchedulerConfig().getContextChunkingPolicy()).c_str());
205205
}
206206

207207
[[nodiscard]] std::optional<SizeType32> getMaxNumTokens() const

0 commit comments

Comments (0)