1818#pragma once
1919
2020#include " tensorrt_llm/batch_manager/peftCacheManager.h"
21- #include " tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
2221#include " tensorrt_llm/common/assert.h"
2322#include " tensorrt_llm/common/stlUtils.h"
2423#include " tensorrt_llm/executor/executor.h"
@@ -52,23 +51,23 @@ class TrtGptModel : public executor::Model
5251 using SizeType32 = tensorrt_llm::runtime::SizeType32;
5352
5453 TrtGptModel (runtime::ModelConfig const & modelConfig, runtime::WorldConfig const & worldConfig,
55- TrtGptModelOptionalParams const & optionalParams )
56- : mMaxBatchSize {optionalParams. maxBatchSize .value_or (modelConfig.getMaxBatchSize ())}
57- , mMaxBeamWidth {optionalParams. maxBeamWidth . value_or (modelConfig. getMaxBeamWidth () )}
54+ executor::ExecutorConfig const & executorConfig )
55+ : mMaxBatchSize {executorConfig. getMaxBatchSize () .value_or (modelConfig.getMaxBatchSize ())}
56+ , mMaxBeamWidth {executorConfig. getMaxBeamWidth ()}
5857 , mMaxSequenceLen {modelConfig.getMaxSequenceLen ()}
5958 , mMaxDraftLen {modelConfig.getMaxDecodingDraftTokens ()}
6059 , mVocabSizePadded {modelConfig.getVocabSizePadded (worldConfig.getSize ())}
61- , mNormalizeLogProbs {optionalParams. normalizeLogProbs }
62- , mEnableTrtOverlap {optionalParams. enableTrtOverlap }
63- , mCudaGraphMode {optionalParams. extendedRuntimePerfKnobConfig .getCudaGraphMode ()}
60+ , mNormalizeLogProbs {executorConfig. getNormalizeLogProbs () }
61+ , mEnableTrtOverlap {executorConfig. getEnableTrtOverlap () }
62+ , mCudaGraphMode {executorConfig. getExtendedRuntimePerfKnobConfig () .getCudaGraphMode ()}
6463 {
6564 TLLM_CHECK_WITH_INFO (mMaxBeamWidth <= modelConfig.getMaxBeamWidth (),
6665 " Runtime configured max beam width (%d) must not exceed engine max beam width (%d)" , mMaxBeamWidth ,
6766 modelConfig.getMaxBeamWidth ());
6867 TLLM_CHECK_WITH_INFO (mMaxBatchSize <= modelConfig.getMaxBatchSize (),
6968 " Runtime configured max batch size (%d) must not exceed engine max batch size (%d)" , mMaxBatchSize ,
7069 modelConfig.getMaxBatchSize ());
71- if (optionalParams. enableTrtOverlap )
70+ if (executorConfig. getEnableTrtOverlap () )
7271 {
7372 if (mMaxBeamWidth > 1 )
7473 {
@@ -85,10 +84,11 @@ class TrtGptModel : public executor::Model
8584 }
8685
8786 mMaxAttentionWindow = 0 ;
88- if (optionalParams. kvCacheConfig . maxAttentionWindowVec .has_value ())
87+ if (executorConfig. getKvCacheConfig (). getMaxAttentionWindowVec () .has_value ())
8988 {
9089 bool warning = false ;
91- for (int maxAttenWin : optionalParams.kvCacheConfig .maxAttentionWindowVec .value ())
90+ auto const & maxAttentionWindowVec = executorConfig.getKvCacheConfig ().getMaxAttentionWindowVec ();
91+ for (int maxAttenWin : maxAttentionWindowVec.value ())
9292 {
9393 mMaxAttentionWindowVec .push_back (std::min (maxAttenWin, mMaxSequenceLen ));
9494 mMaxAttentionWindow = std::max (mMaxAttentionWindow , mMaxAttentionWindowVec .back ());
@@ -112,8 +112,8 @@ class TrtGptModel : public executor::Model
112112 mMaxAttentionWindow = mMaxSequenceLen ;
113113 }
114114
115- mSinkTokenLen = optionalParams. kvCacheConfig . sinkTokenLength .has_value ()
116- ? optionalParams. kvCacheConfig . sinkTokenLength .value ()
115+ mSinkTokenLen = executorConfig. getKvCacheConfig (). getSinkTokenLength () .has_value ()
116+ ? executorConfig. getKvCacheConfig (). getSinkTokenLength () .value ()
117117 : 0 ;
118118
119119 mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism ();
@@ -136,26 +136,26 @@ class TrtGptModel : public executor::Model
136136 TLLM_LOG_INFO (" TRTGptModel normalizeLogProbs: %d" , mNormalizeLogProbs );
137137
138138 mMaxNumTokens = modelConfig.getMaxNumTokens ();
139- if (optionalParams. maxNumTokens && mMaxNumTokens )
139+ if (executorConfig. getMaxNumTokens (). has_value () && mMaxNumTokens )
140140 {
141- if (optionalParams. maxNumTokens .value () > mMaxNumTokens .value ())
141+ if (executorConfig. getMaxNumTokens () .value () > mMaxNumTokens .value ())
142142 {
143143 TLLM_LOG_WARNING (
144144 " Runtime configured max num tokens (%d) is larger than model max num tokens (%d) and will be "
145145 " ignored." ,
146- optionalParams. maxNumTokens .value (), mMaxNumTokens .value ());
146+ executorConfig. getMaxNumTokens () .value (), mMaxNumTokens .value ());
147147 }
148148 else
149149 {
150- mMaxNumTokens = optionalParams. maxNumTokens ;
150+ mMaxNumTokens = executorConfig. getMaxNumTokens () ;
151151 }
152152 }
153153 if (mMaxNumTokens )
154154 {
155155 TLLM_LOG_INFO (" TRTGptModel maxNumTokens: %d" , mMaxNumTokens .value ());
156156 }
157157
158- if (optionalParams. enableChunkedContext )
158+ if (executorConfig. getEnableChunkedContext () )
159159 {
160160 mMaxInputLen = mMaxSequenceLen - 1 ;
161161 TLLM_LOG_INFO (
@@ -199,9 +199,9 @@ class TrtGptModel : public executor::Model
199199 using tensorrt_llm::common::stl_utils::toString;
200200
201201 TLLM_LOG_INFO (" Capacity Scheduler Policy: %s" ,
202- toString (optionalParams. schedulerConfig .getCapacitySchedulerPolicy ()).c_str ());
202+ toString (executorConfig. getSchedulerConfig () .getCapacitySchedulerPolicy ()).c_str ());
203203 TLLM_LOG_INFO (" Context Chunking Scheduler Policy: %s" ,
204- toString (optionalParams. schedulerConfig .getContextChunkingPolicy ()).c_str ());
204+ toString (executorConfig. getSchedulerConfig () .getContextChunkingPolicy ()).c_str ());
205205 }
206206
207207 [[nodiscard]] std::optional<SizeType32> getMaxNumTokens () const
0 commit comments