Skip to content

Commit 11f6a08

Browse files
authored
Add support for min_length and min_new_tokens generation parameters (#308)
* Add support for `MinNewTokensLengthLogitsProcessor` * Add support for `MinLengthLogitsProcessor` * Fix `generation_config` defaults * Fix `input_ids_seq_length` * Add unit tests for generation * Fix generation parameters test case * Allow specification of multiple `eos_token_ids`
1 parent ef27100 commit 11f6a08

File tree

3 files changed

+271
-30
lines changed

3 files changed

+271
-30
lines changed

src/models.js

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ import {
6464
WhisperTimeStampLogitsProcessor,
6565
NoRepeatNGramLogitsProcessor,
6666
RepetitionPenaltyLogitsProcessor,
67+
MinLengthLogitsProcessor,
68+
MinNewTokensLengthLogitsProcessor,
6769

6870
Sampler,
6971
} from './utils/generation.js';
@@ -678,6 +680,7 @@ export class PreTrainedModel extends Callable {
678680
info = await Promise.all([
679681
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
680682
constructSession(pretrained_model_name_or_path, options.model_file_name ?? 'decoder_model_merged', options),
683+
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
681684
]);
682685

683686
} else if (modelType === MODEL_TYPES.Seq2Seq || modelType === MODEL_TYPES.Vision2Seq) {
@@ -782,17 +785,17 @@ export class PreTrainedModel extends Callable {
782785
// processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
783786
// }
784787

785-
// if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
786-
// processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
787-
// }
788+
if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
789+
processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
790+
}
788791

789-
// if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
790-
// processors.push(new MinNewTokensLengthLogitsProcessor(
791-
// input_ids_seq_length,
792-
// generation_config.min_new_tokens,
793-
// generation_config.eos_token_id
794-
// ));
795-
// }
792+
if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
793+
processors.push(new MinNewTokensLengthLogitsProcessor(
794+
input_ids_seq_length,
795+
generation_config.min_new_tokens,
796+
generation_config.eos_token_id
797+
));
798+
}
796799

797800
// if (prefix_allowed_tokens_fn !== null) {
798801
// processors.push(new PrefixConstrainedLogitsProcessor(
@@ -866,7 +869,8 @@ export class PreTrainedModel extends Callable {
866869
*/
867870
_get_generation_config(generation_config) {
868871
// Create empty generation config (contains defaults)
869-
let gen_config = new GenerationConfig();
872+
// We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
873+
let gen_config = new GenerationConfig(this.config);
870874

871875
// Apply model's generation config, if it exists
872876
if ('generation_config' in this) {
@@ -928,7 +932,7 @@ export class PreTrainedModel extends Callable {
928932
input_ids_seq_length = 0;
929933

930934
} else {
931-
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims[0] : inputs.length;
935+
input_ids_seq_length = inputs instanceof Tensor ? inputs.dims.at(-1) : inputs.length;
932936

933937
// decoder-only
934938
if (input_ids_seq_length === 0) {
@@ -948,6 +952,12 @@ export class PreTrainedModel extends Callable {
948952
logits_processor
949953
)
950954

955+
/** @type {number[]} */
956+
let eos_token_ids = generation_config.eos_token_id;
957+
if (eos_token_ids !== null && !Array.isArray(eos_token_ids)) {
958+
eos_token_ids = [eos_token_ids];
959+
}
960+
951961
// TODO implement early_stopping
952962
// https://huggingface.co/blog/how-to-generate
953963

@@ -1007,7 +1017,7 @@ export class PreTrainedModel extends Callable {
10071017

10081018
newBeam.score += logProb;
10091019

1010-
if (newTokenId === this.config.eos_token_id) {
1020+
if (eos_token_ids && eos_token_ids.includes(newTokenId)) {
10111021
newBeam.done = true;
10121022
}
10131023

@@ -2476,10 +2486,12 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
24762486
* @param {Object} config The configuration object specifying the hyperparameters and other model settings.
24772487
* @param {Object} session The ONNX session containing the encoder model.
24782488
* @param {any} decoder_merged_session The ONNX session containing the merged decoder model.
2489+
* @param {Object} generation_config Configuration object for the generation process.
24792490
*/
2480-
constructor(config, session, decoder_merged_session) {
2491+
constructor(config, session, decoder_merged_session, generation_config) {
24812492
super(config, session);
24822493
this.decoder_merged_session = decoder_merged_session;
2494+
this.generation_config = generation_config;
24832495

24842496
this.num_layers = this.config.decoder.n_layer;
24852497
this.num_heads = this.config.decoder.n_head;
@@ -2617,9 +2629,11 @@ export class GPT2PreTrainedModel extends PreTrainedModel {
26172629
* Creates a new instance of the `GPT2PreTrainedModel` class.
26182630
* @param {Object} config The configuration of the model.
26192631
* @param {any} session The ONNX session containing the model weights.
2632+
* @param {GenerationConfig} generation_config The generation configuration.
26202633
*/
2621-
constructor(config, session) {
2634+
constructor(config, session, generation_config) {
26222635
super(config, session);
2636+
this.generation_config = generation_config;
26232637

26242638
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
26252639
this.config.pad_token_id = this.config.eos_token_id
@@ -2649,9 +2663,11 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel {
26492663
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
26502664
* @param {Object} config The configuration of the model.
26512665
* @param {any} session The ONNX session containing the model weights.
2666+
* @param {GenerationConfig} generation_config The generation configuration.
26522667
*/
2653-
constructor(config, session) {
2668+
constructor(config, session, generation_config) {
26542669
super(config, session);
2670+
this.generation_config = generation_config;
26552671

26562672
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
26572673
this.config.pad_token_id = this.config.eos_token_id
@@ -2673,9 +2689,11 @@ export class GPTNeoXPreTrainedModel extends PreTrainedModel {
26732689
* Creates a new instance of the `GPTNeoXPreTrainedModel` class.
26742690
* @param {Object} config The configuration of the model.
26752691
* @param {any} session The ONNX session containing the model weights.
2692+
* @param {GenerationConfig} generation_config The generation configuration.
26762693
*/
2677-
constructor(config, session) {
2694+
constructor(config, session, generation_config) {
26782695
super(config, session);
2696+
this.generation_config = generation_config;
26792697

26802698
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
26812699
this.config.pad_token_id = this.config.eos_token_id
@@ -2698,9 +2716,11 @@ export class GPTJPreTrainedModel extends PreTrainedModel {
26982716
* Creates a new instance of the `GPTJPreTrainedModel` class.
26992717
* @param {Object} config The configuration of the model.
27002718
* @param {any} session The ONNX session containing the model weights.
2719+
* @param {GenerationConfig} generation_config The generation configuration.
27012720
*/
2702-
constructor(config, session) {
2721+
constructor(config, session, generation_config) {
27032722
super(config, session);
2723+
this.generation_config = generation_config;
27042724

27052725
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
27062726
this.config.pad_token_id = this.config.eos_token_id
@@ -2724,9 +2744,11 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel {
27242744
* Creates a new instance of the `GPTBigCodePreTrainedModel` class.
27252745
* @param {Object} config The configuration of the model.
27262746
* @param {any} session The ONNX session containing the model weights.
2747+
* @param {GenerationConfig} generation_config The generation configuration.
27272748
*/
2728-
constructor(config, session) {
2749+
constructor(config, session, generation_config) {
27292750
super(config, session);
2751+
this.generation_config = generation_config;
27302752

27312753
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
27322754
this.config.pad_token_id = this.config.eos_token_id
@@ -2747,11 +2769,13 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
27472769
export class CodeGenPreTrainedModel extends PreTrainedModel {
27482770
/**
27492771
* Creates a new instance of the `CodeGenPreTrainedModel` class.
2750-
* @param {Object} config The model configuration object.
2751-
* @param {Object} session The ONNX session object.
2752-
*/
2753-
constructor(config, session) {
2772+
* @param {Object} config The model configuration object.
2773+
* @param {Object} session The ONNX session object.
2774+
* @param {GenerationConfig} generation_config The generation configuration.
2775+
*/
2776+
constructor(config, session, generation_config) {
27542777
super(config, session);
2778+
this.generation_config = generation_config;
27552779

27562780
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
27572781
this.config.pad_token_id = this.config.eos_token_id
@@ -2785,11 +2809,13 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
27852809
export class LlamaPreTrainedModel extends PreTrainedModel {
27862810
/**
27872811
* Creates a new instance of the `LlamaPreTrainedModel` class.
2788-
* @param {Object} config The model configuration object.
2789-
* @param {Object} session The ONNX session object.
2790-
*/
2791-
constructor(config, session) {
2812+
* @param {Object} config The model configuration object.
2813+
* @param {Object} session The ONNX session object.
2814+
* @param {GenerationConfig} generation_config The generation configuration.
2815+
*/
2816+
constructor(config, session, generation_config) {
27922817
super(config, session);
2818+
this.generation_config = generation_config;
27932819

27942820
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
27952821
this.config.pad_token_id = this.config.eos_token_id
@@ -2817,9 +2843,11 @@ export class BloomPreTrainedModel extends PreTrainedModel {
28172843
* Creates a new instance of the `BloomPreTrainedModel` class.
28182844
* @param {Object} config The configuration of the model.
28192845
* @param {any} session The ONNX session containing the model weights.
2846+
* @param {GenerationConfig} generation_config The generation configuration.
28202847
*/
2821-
constructor(config, session) {
2848+
constructor(config, session, generation_config) {
28222849
super(config, session);
2850+
this.generation_config = generation_config;
28232851

28242852
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
28252853
this.config.pad_token_id = this.config.eos_token_id
@@ -2848,9 +2876,11 @@ export class MptPreTrainedModel extends PreTrainedModel {
28482876
* Creates a new instance of the `MptPreTrainedModel` class.
28492877
* @param {Object} config The model configuration object.
28502878
* @param {Object} session The ONNX session object.
2879+
* @param {GenerationConfig} generation_config The generation configuration.
28512880
*/
2852-
constructor(config, session) {
2881+
constructor(config, session, generation_config) {
28532882
super(config, session);
2883+
this.generation_config = generation_config;
28542884

28552885
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
28562886
this.config.pad_token_id = this.config.eos_token_id
@@ -2880,9 +2910,11 @@ export class OPTPreTrainedModel extends PreTrainedModel {
28802910
* Creates a new instance of the `OPTPreTrainedModel` class.
28812911
* @param {Object} config The model configuration object.
28822912
* @param {Object} session The ONNX session object.
2913+
* @param {GenerationConfig} generation_config The generation configuration.
28832914
*/
2884-
constructor(config, session) {
2915+
constructor(config, session, generation_config) {
28852916
super(config, session);
2917+
this.generation_config = generation_config;
28862918

28872919
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
28882920
this.config.pad_token_id = this.config.eos_token_id

src/utils/generation.js

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,77 @@ export class RepetitionPenaltyLogitsProcessor extends LogitsProcessor {
420420
}
421421
}
422422

423+
/**
424+
* A logits processor that enforces a minimum number of tokens.
425+
*
426+
* @extends LogitsProcessor
427+
*/
428+
export class MinLengthLogitsProcessor extends LogitsProcessor {
429+
/**
430+
* Create a MinLengthLogitsProcessor.
431+
* @param {number} min_length The minimum length below which the score of `eos_token_id` is set to negative infinity.
432+
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
433+
*/
434+
constructor(min_length, eos_token_id) {
435+
super();
436+
this.min_length = min_length;
437+
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
438+
}
439+
440+
/**
441+
* Apply logit processor.
442+
* @param {Array} input_ids The input IDs.
443+
* @param {Object} logits The logits.
444+
* @returns {Object} The processed logits.
445+
*/
446+
_call(input_ids, logits) {
447+
if (input_ids.length < this.min_length) {
448+
for (const eos_token of this.eos_token_id) {
449+
logits.data[eos_token] = -Infinity;
450+
}
451+
}
452+
453+
return logits
454+
}
455+
}
456+
457+
/**
458+
* A logits processor that enforces a minimum number of new tokens.
459+
*
460+
* @extends LogitsProcessor
461+
*/
462+
export class MinNewTokensLengthLogitsProcessor extends LogitsProcessor {
463+
/**
464+
* Create a MinNewTokensLengthLogitsProcessor.
465+
* @param {number} prompt_length_to_skip The input tokens length.
466+
* @param {number} min_new_tokens The minimum *new* tokens length below which the score of `eos_token_id` is set to negative infinity.
467+
* @param {number|number[]} eos_token_id The ID/IDs of the end-of-sequence token.
468+
*/
469+
constructor(prompt_length_to_skip, min_new_tokens, eos_token_id) {
470+
super();
471+
this.prompt_length_to_skip = prompt_length_to_skip;
472+
this.min_new_tokens = min_new_tokens;
473+
this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
474+
}
475+
476+
/**
477+
* Apply logit processor.
478+
* @param {Array} input_ids The input IDs.
479+
* @param {Object} logits The logits.
480+
* @returns {Object} The processed logits.
481+
*/
482+
_call(input_ids, logits) {
483+
const new_tokens_length = input_ids.length - this.prompt_length_to_skip;
484+
if (new_tokens_length < this.min_new_tokens) {
485+
for (const eos_token of this.eos_token_id) {
486+
logits.data[eos_token] = -Infinity;
487+
}
488+
}
489+
490+
return logits
491+
}
492+
}
493+
423494
/**
424495
* Class that holds a configuration for a generation task.
425496
*/

0 commit comments

Comments
 (0)