@@ -74,9 +74,6 @@ class SamplingConfig
7474 }
7575 }
7676
77- template <typename T>
78- using Vec = std::vector<T>;
79-
8077 template <typename T>
8178 bool validateVec (std::string name, OptVec<T> const & vec, T min, std::optional<T> max = std::nullopt )
8279 {
@@ -185,6 +182,9 @@ class SamplingConfig
185182 configs, [&configs](size_t ci) { return configs[ci].outputLogProbs ; }, false );
186183 cumLogProbs = fuseValues<bool >(
187184 configs, [&configs](size_t ci) { return configs[ci].cumLogProbs ; }, false );
185+ beamWidthArray = fuseValues<std::vector<SizeType32>>(
186+ configs, [&configs](size_t ci) { return configs[ci].beamWidthArray ; },
187+ layers::DefaultDecodingParams::getBeamWidthArray ());
188188 // Only used for tests.
189189 draftAcceptanceThreshold = fuseValues<FloatType>(
190190 configs, [&configs](size_t ci) { return configs[ci].draftAcceptanceThreshold ; }, 0 );
@@ -193,22 +193,22 @@ class SamplingConfig
193193 }
194194
195195 explicit SamplingConfig (executor::SamplingConfig const & samplingConfig,
196- std::optional<executor::ExternalDraftTokensConfig> const & externalDraftTokensConfig)
196+ std::optional<executor::ExternalDraftTokensConfig> const & externalDraftTokensConfig = std:: nullopt )
197197 : beamWidth{samplingConfig.getBeamWidth ()}
198198 , numReturnSequences(samplingConfig.getNumReturnSequences())
199199 {
200200
201201 if (externalDraftTokensConfig && externalDraftTokensConfig.value ().getAcceptanceThreshold ())
202202 {
203203 draftAcceptanceThreshold
204- = Vec <FloatType>{externalDraftTokensConfig.value ().getAcceptanceThreshold ().value ()};
204+ = std::vector <FloatType>{externalDraftTokensConfig.value ().getAcceptanceThreshold ().value ()};
205205 }
206206
207207#define SET_FROM_OPTIONAL (varName, VarName, VarType ) \
208208 \
209209 if (samplingConfig.get ##VarName ()) \
210210 { \
211- varName = Vec <VarType>{samplingConfig.get ##VarName ().value ()}; \
211+ varName = std::vector <VarType>{samplingConfig.get ##VarName ().value ()}; \
212212 }
213213
214214 SET_FROM_OPTIONAL (topK, TopK, SizeType32)
@@ -228,6 +228,7 @@ class SamplingConfig
228228 SET_FROM_OPTIONAL (earlyStopping, EarlyStopping, SizeType32)
229229 SET_FROM_OPTIONAL (noRepeatNgramSize, NoRepeatNgramSize, SizeType32)
230230 SET_FROM_OPTIONAL (minP, MinP, FloatType)
231+ SET_FROM_OPTIONAL (beamWidthArray, BeamWidthArray, std::vector<SizeType32>)
231232#undef SET_FROM_OPTIONAL
232233 }
233234
@@ -266,16 +267,18 @@ class SamplingConfig
266267 valid &= validateVec (" topK" , topK, -1 );
267268 valid &= validateVec (" topP" , topP, -fltEpsilon, {1 .f });
268269 valid &= validateVec (" topPMin" , topPMin, 0 .f , {1 .f });
269- valid &= validateVec (" topPDecay" , topPDecay, 0 .f , {1 .f });
270270 valid &= validateVec (" topPResetIds" , topPResetIds, -1 );
271-
271+ valid &= validateVec ( " topPDecay " , topPDecay, 0 . f , { 1 . f });
272272 valid &= validateVec (" temperature" , temperature, -fltEpsilon);
273- valid &= validateVec (" repetitionPenalty" , repetitionPenalty, 0 .f );
274273 valid &= validateVec (" minLength" , minLength, -1 );
274+ valid &= validateVec (" beamSearchDiversityRate" , beamSearchDiversityRate, -fltEpsilon);
275+ valid &= validateVec (" repetitionPenalty" , repetitionPenalty, 0 .f );
276+ // TODO: checking `lengthPenalty`leads to a failure in
277+ // `test_openai_chat_example`, debug and re-enable it later.
278+ // valid &= validateVec("lengthPenalty", lengthPenalty, 0.f);
275279 valid &= validateVec (" noRepeatNgramSize" , noRepeatNgramSize, 0 );
276280 valid &= validateVec (" minP" , minP, -fltEpsilon, {1 .f });
277-
278- valid &= validateVec (" beamSearchDiversityRate" , beamSearchDiversityRate, -fltEpsilon);
281+ // TODO: check `beamWidthArray`
279282
280283 // Detect greedy sampling and overwrite params.
281284 if (temperature)
@@ -332,38 +335,39 @@ class SamplingConfig
332335 SizeType32 beamWidth;
333336 std::optional<SizeType32> numReturnSequences;
334337
335- // penalties
336- OptVec<FloatType> temperature; // [1] or [batch_size] on cpu
337- OptVec<FloatType> originalTemperature; // [1] or [batch_size] on cpu
338- OptVec<SizeType32> minLength; // [1] or [batch_size] on cpu
339- OptVec<FloatType> repetitionPenalty; // [1] or [batch_size] on cpu
340- OptVec<FloatType> presencePenalty; // [1] or [batch_size] on cpu
341- OptVec<FloatType> frequencyPenalty; // [1] or [batch_size] on cpu
342- OptVec<SizeType32> noRepeatNgramSize; // [1] or [batch_size] on cpu
338+ // penalties, [1] for one request, [batchSize] for one batch, the same for other parameters below
339+ OptVec<FloatType> temperature; // [1] or [batchSize]
340+ OptVec<FloatType> originalTemperature; // [1] or [batchSize]
341+ OptVec<SizeType32> minLength; // [1] or [batchSize]
342+ OptVec<FloatType> repetitionPenalty; // [1] or [batchSize]
343+ OptVec<FloatType> presencePenalty; // [1] or [batchSize]
344+ OptVec<FloatType> frequencyPenalty; // [1] or [batchSize]
345+ OptVec<SizeType32> noRepeatNgramSize; // [1] or [batchSize]
343346
344347 // probs
345348 OptVec<bool > outputLogProbs;
346349 OptVec<bool > cumLogProbs;
347350
348351 // sampling layers
349- OptVec<SizeType32> topK; // [1] or [batch_size] on cpu
350- OptVec<FloatType> topP; // [1] or [batch_size] on cpu
351- OptVec<uint64_t > randomSeed; // [1] or [batch_size] on cpu
352- OptVec<FloatType> topPDecay; // [batch_size], must between [0, 1]
353- OptVec<FloatType> topPMin; // [batch_size], must between [0, 1]
354- OptVec<TokenIdType> topPResetIds; // [batch_size ]
355- OptVec<FloatType> minP; // [1] or [batch_size] on cpu
352+ OptVec<SizeType32> topK; // [1] or [batchSize]
353+ OptVec<FloatType> topP; // [1] or [batchSize]
354+ OptVec<uint64_t > randomSeed; // [1] or [batchSize]
355+ OptVec<FloatType> topPDecay; // [1] or [batchSize], between [0, 1]
356+ OptVec<FloatType> topPMin; // [1] or [batchSize], between [0, 1]
357+ OptVec<TokenIdType> topPResetIds; // [1] or [batchSize ]
358+ OptVec<FloatType> minP; // [1] or [batchSize]
356359
357360 // beam search layer
358- OptVec<FloatType> beamSearchDiversityRate; // [1] or [batch_size]
359- OptVec<FloatType> lengthPenalty; // [1] or [batch_size]
360- OptVec<SizeType32> earlyStopping; // [1] or [batch_size]
361+ OptVec<FloatType> beamSearchDiversityRate; // [1] or [batchSize]
362+ OptVec<FloatType> lengthPenalty; // [1] or [batchSize]
363+ OptVec<SizeType32> earlyStopping; // [1] or [batchSize]
364+ OptVec<std::vector<SizeType32>> beamWidthArray; // [maxBeamWidthArrayLength] or [batchSize, maxBeamWidthArrayLength]
361365
362366 // speculative decoding, only the first value is used (in gptDecoderBatched.cpp)
363- OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size ]
367+ OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batchSize ]
364368
365369 // medusa params
366- OptVec<std::vector<runtime:: SizeType32>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
370+ OptVec<std::vector<SizeType32>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
367371
368372 std::optional<bool > normalizeLogProbs;
369373
@@ -379,7 +383,7 @@ class SamplingConfig
379383 && lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
380384 && draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
381385 && normalizeLogProbs == other.normalizeLogProbs && outputLogProbs == other.outputLogProbs
382- && cumLogProbs == other.cumLogProbs && minP == other.minP ;
386+ && cumLogProbs == other.cumLogProbs && minP == other.minP && beamWidthArray == other. beamWidthArray ;
383387 }
384388
385389 SizeType32 getNumReturnBeams () const
0 commit comments