Skip to content

Commit 3e035f2

Browse files
authored
v1.2 (#3082)
Signed-off-by: wili <wili@nvidia.com>
1 parent d9522c5 commit 3e035f2

32 files changed

+703
-316
lines changed

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ class SamplingConfig
7575
std::optional<SizeType32> const& earlyStopping = std::nullopt,
7676
std::optional<SizeType32> const& noRepeatNgramSize = std::nullopt,
7777
std::optional<SizeType32> const& numReturnSequences = std::nullopt,
78-
std::optional<FloatType> const& minP = std::nullopt);
78+
std::optional<FloatType> const& minP = std::nullopt,
79+
std::optional<std::vector<SizeType32>> const& beamWidthArray = std::nullopt);
7980

8081
bool operator==(SamplingConfig const& other) const;
8182

@@ -100,6 +101,7 @@ class SamplingConfig
100101
[[nodiscard]] std::optional<SizeType32> getNoRepeatNgramSize() const;
101102
[[nodiscard]] std::optional<SizeType32> getNumReturnSequences() const;
102103
[[nodiscard]] std::optional<FloatType> getMinP() const;
104+
[[nodiscard]] std::optional<std::vector<SizeType32>> getBeamWidthArray() const;
103105

104106
void setBeamWidth(SizeType32 beamWidth);
105107
void setTopK(std::optional<SizeType32> const& topK);
@@ -121,6 +123,7 @@ class SamplingConfig
121123
void setNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
122124
void setNumReturnSequences(std::optional<SizeType32> const& numReturnSequences);
123125
void setMinP(std::optional<FloatType> const& minP);
126+
void setBeamWidthArray(std::optional<std::vector<SizeType32>> const& beamWidthArray);
124127

125128
private:
126129
static SizeType32 checkBeamWidth(SizeType32 beamWidth);
@@ -130,15 +133,18 @@ class SamplingConfig
130133
static std::optional<TokenIdType> const& checkTopPResetIds(std::optional<TokenIdType> const& topPResetIds);
131134
static std::optional<FloatType> const& checkTopPDecay(std::optional<FloatType> const& topPDecay);
132135
static std::optional<FloatType> const& checkTemperature(std::optional<FloatType> const& temperature);
133-
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& penalty);
134136
static std::optional<SizeType32> const& checkMinTokens(std::optional<SizeType32> const& minTokens);
135-
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
136137
static std::optional<FloatType> const& checkBeamSearchDiversityRate(
137138
std::optional<FloatType> const& beamSearchDiversityRate);
139+
static std::optional<FloatType> const& checkRepetitionPenalty(std::optional<FloatType> const& repetitionpenalty);
140+
static std::optional<FloatType> const& checkLengthPenalty(std::optional<FloatType> const& lengthPenalty);
141+
static std::optional<SizeType32> const& checkEarlyStopping(std::optional<SizeType32> const& earlyStopping);
142+
static std::optional<SizeType32> const& checkNoRepeatNgramSize(std::optional<SizeType32> const& noRepeatNgramSize);
138143
static std::optional<SizeType32> const& checkNumReturnSequences(
139144
std::optional<SizeType32> const& numReturnSequences, SizeType32 beamWidth);
140145
static std::optional<FloatType> const& checkMinP(std::optional<FloatType> const& minP);
141-
146+
static std::optional<std::vector<SizeType32>> const& checkBeamWidthArray(
147+
std::optional<std::vector<SizeType32>> const& beamWidthArray, std::optional<SizeType32> const beamWidth);
142148
void updateNumReturnBeams();
143149

144150
friend class Serialization;
@@ -188,6 +194,8 @@ class SamplingConfig
188194
/// @brief Controls the min_p scaling for sampling.
189195
/// It masks every candidate x for which P_x < min_p * P_max, where P_x is the probability of candidate x. Default is 0.f
190196
std::optional<FloatType> mMinP;
197+
/// @brief Controls the beam width for each step for Variable-Beam-Width-Search.
198+
std::optional<std::vector<SizeType32>> mBeamWidthArray;
191199
};
192200

193201
/// @brief Configuration that controls the outputs of a Result

cpp/include/tensorrt_llm/layers/defaultDecodingParams.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ class DefaultDecodingParams
128128
{
129129
return 0.0f;
130130
}
131+
132+
[[nodiscard]] static std::vector<runtime::SizeType32> getBeamWidthArray()
133+
{
134+
return std::vector<runtime::SizeType32>{1};
135+
}
131136
};
132137
} // namespace layers
133138
} // namespace tensorrt_llm

cpp/include/tensorrt_llm/runtime/samplingConfig.h

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,6 @@ class SamplingConfig
7474
}
7575
}
7676

77-
template <typename T>
78-
using Vec = std::vector<T>;
79-
8077
template <typename T>
8178
bool validateVec(std::string name, OptVec<T> const& vec, T min, std::optional<T> max = std::nullopt)
8279
{
@@ -185,6 +182,9 @@ class SamplingConfig
185182
configs, [&configs](size_t ci) { return configs[ci].outputLogProbs; }, false);
186183
cumLogProbs = fuseValues<bool>(
187184
configs, [&configs](size_t ci) { return configs[ci].cumLogProbs; }, false);
185+
beamWidthArray = fuseValues<std::vector<SizeType32>>(
186+
configs, [&configs](size_t ci) { return configs[ci].beamWidthArray; },
187+
layers::DefaultDecodingParams::getBeamWidthArray());
188188
// Only used for tests.
189189
draftAcceptanceThreshold = fuseValues<FloatType>(
190190
configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold; }, 0);
@@ -193,22 +193,22 @@ class SamplingConfig
193193
}
194194

195195
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
196-
std::optional<executor::ExternalDraftTokensConfig> const& externalDraftTokensConfig)
196+
std::optional<executor::ExternalDraftTokensConfig> const& externalDraftTokensConfig = std::nullopt)
197197
: beamWidth{samplingConfig.getBeamWidth()}
198198
, numReturnSequences(samplingConfig.getNumReturnSequences())
199199
{
200200

201201
if (externalDraftTokensConfig && externalDraftTokensConfig.value().getAcceptanceThreshold())
202202
{
203203
draftAcceptanceThreshold
204-
= Vec<FloatType>{externalDraftTokensConfig.value().getAcceptanceThreshold().value()};
204+
= std::vector<FloatType>{externalDraftTokensConfig.value().getAcceptanceThreshold().value()};
205205
}
206206

207207
#define SET_FROM_OPTIONAL(varName, VarName, VarType) \
208208
\
209209
if (samplingConfig.get##VarName()) \
210210
{ \
211-
varName = Vec<VarType>{samplingConfig.get##VarName().value()}; \
211+
varName = std::vector<VarType>{samplingConfig.get##VarName().value()}; \
212212
}
213213

214214
SET_FROM_OPTIONAL(topK, TopK, SizeType32)
@@ -228,6 +228,7 @@ class SamplingConfig
228228
SET_FROM_OPTIONAL(earlyStopping, EarlyStopping, SizeType32)
229229
SET_FROM_OPTIONAL(noRepeatNgramSize, NoRepeatNgramSize, SizeType32)
230230
SET_FROM_OPTIONAL(minP, MinP, FloatType)
231+
SET_FROM_OPTIONAL(beamWidthArray, BeamWidthArray, std::vector<SizeType32>)
231232
#undef SET_FROM_OPTIONAL
232233
}
233234

@@ -266,16 +267,18 @@ class SamplingConfig
266267
valid &= validateVec("topK", topK, -1);
267268
valid &= validateVec("topP", topP, -fltEpsilon, {1.f});
268269
valid &= validateVec("topPMin", topPMin, 0.f, {1.f});
269-
valid &= validateVec("topPDecay", topPDecay, 0.f, {1.f});
270270
valid &= validateVec("topPResetIds", topPResetIds, -1);
271-
271+
valid &= validateVec("topPDecay", topPDecay, 0.f, {1.f});
272272
valid &= validateVec("temperature", temperature, -fltEpsilon);
273-
valid &= validateVec("repetitionPenalty", repetitionPenalty, 0.f);
274273
valid &= validateVec("minLength", minLength, -1);
274+
valid &= validateVec("beamSearchDiversityRate", beamSearchDiversityRate, -fltEpsilon);
275+
valid &= validateVec("repetitionPenalty", repetitionPenalty, 0.f);
276+
// TODO: checking `lengthPenalty` leads to a failure in
277+
// `test_openai_chat_example`, debug and re-enable it later.
278+
// valid &= validateVec("lengthPenalty", lengthPenalty, 0.f);
275279
valid &= validateVec("noRepeatNgramSize", noRepeatNgramSize, 0);
276280
valid &= validateVec("minP", minP, -fltEpsilon, {1.f});
277-
278-
valid &= validateVec("beamSearchDiversityRate", beamSearchDiversityRate, -fltEpsilon);
281+
// TODO: check `beamWidthArray`
279282

280283
// Detect greedy sampling and overwrite params.
281284
if (temperature)
@@ -332,38 +335,39 @@ class SamplingConfig
332335
SizeType32 beamWidth;
333336
std::optional<SizeType32> numReturnSequences;
334337

335-
// penalties
336-
OptVec<FloatType> temperature; // [1] or [batch_size] on cpu
337-
OptVec<FloatType> originalTemperature; // [1] or [batch_size] on cpu
338-
OptVec<SizeType32> minLength; // [1] or [batch_size] on cpu
339-
OptVec<FloatType> repetitionPenalty; // [1] or [batch_size] on cpu
340-
OptVec<FloatType> presencePenalty; // [1] or [batch_size] on cpu
341-
OptVec<FloatType> frequencyPenalty; // [1] or [batch_size] on cpu
342-
OptVec<SizeType32> noRepeatNgramSize; // [1] or [batch_size] on cpu
338+
// penalties, [1] for one request, [batchSize] for one batch, the same for other parameters below
339+
OptVec<FloatType> temperature; // [1] or [batchSize]
340+
OptVec<FloatType> originalTemperature; // [1] or [batchSize]
341+
OptVec<SizeType32> minLength; // [1] or [batchSize]
342+
OptVec<FloatType> repetitionPenalty; // [1] or [batchSize]
343+
OptVec<FloatType> presencePenalty; // [1] or [batchSize]
344+
OptVec<FloatType> frequencyPenalty; // [1] or [batchSize]
345+
OptVec<SizeType32> noRepeatNgramSize; // [1] or [batchSize]
343346

344347
// probs
345348
OptVec<bool> outputLogProbs;
346349
OptVec<bool> cumLogProbs;
347350

348351
// sampling layers
349-
OptVec<SizeType32> topK; // [1] or [batch_size] on cpu
350-
OptVec<FloatType> topP; // [1] or [batch_size] on cpu
351-
OptVec<uint64_t> randomSeed; // [1] or [batch_size] on cpu
352-
OptVec<FloatType> topPDecay; // [batch_size], must between [0, 1]
353-
OptVec<FloatType> topPMin; // [batch_size], must between [0, 1]
354-
OptVec<TokenIdType> topPResetIds; // [batch_size]
355-
OptVec<FloatType> minP; // [1] or [batch_size] on cpu
352+
OptVec<SizeType32> topK; // [1] or [batchSize]
353+
OptVec<FloatType> topP; // [1] or [batchSize]
354+
OptVec<uint64_t> randomSeed; // [1] or [batchSize]
355+
OptVec<FloatType> topPDecay; // [1] or [batchSize], between [0, 1]
356+
OptVec<FloatType> topPMin; // [1] or [batchSize], between [0, 1]
357+
OptVec<TokenIdType> topPResetIds; // [1] or [batchSize]
358+
OptVec<FloatType> minP; // [1] or [batchSize]
356359

357360
// beam search layer
358-
OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
359-
OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
360-
OptVec<SizeType32> earlyStopping; // [1] or [batch_size]
361+
OptVec<FloatType> beamSearchDiversityRate; // [1] or [batchSize]
362+
OptVec<FloatType> lengthPenalty; // [1] or [batchSize]
363+
OptVec<SizeType32> earlyStopping; // [1] or [batchSize]
364+
OptVec<std::vector<SizeType32>> beamWidthArray; // [maxBeamWidthArrayLength] or [batchSize, maxBeamWidthArrayLength]
361365

362366
// speculative decoding, only the first value is used (in gptDecoderBatched.cpp)
363-
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
367+
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batchSize]
364368

365369
// medusa params
366-
OptVec<std::vector<runtime::SizeType32>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
370+
OptVec<std::vector<SizeType32>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
367371

368372
std::optional<bool> normalizeLogProbs;
369373

@@ -379,7 +383,7 @@ class SamplingConfig
379383
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
380384
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
381385
&& normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
382-
&& cumLogProbs == other.cumLogProbs && minP == other.minP;
386+
&& cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other.beamWidthArray;
383387
}
384388

385389
SizeType32 getNumReturnBeams() const

cpp/tensorrt_llm/batch_manager/trtGptModelV1.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ void addToSamplingConfig(SamplingConfig& batchSamplingConfig, SamplingConfig con
120120
TLLM_CHECK(batchSamplingConfig.beamSearchDiversityRate == addSamplingConfig.beamSearchDiversityRate);
121121
TLLM_CHECK(batchSamplingConfig.lengthPenalty == addSamplingConfig.lengthPenalty);
122122
TLLM_CHECK(batchSamplingConfig.earlyStopping == addSamplingConfig.earlyStopping);
123+
TLLM_CHECK(batchSamplingConfig.beamWidthArray == addSamplingConfig.beamWidthArray);
123124

124125
auto addOptional = [](auto& batch, auto const& add, char const* name)
125126
{

0 commit comments

Comments
 (0)