@@ -64,6 +64,8 @@ import {
64
64
WhisperTimeStampLogitsProcessor ,
65
65
NoRepeatNGramLogitsProcessor ,
66
66
RepetitionPenaltyLogitsProcessor ,
67
+ MinLengthLogitsProcessor ,
68
+ MinNewTokensLengthLogitsProcessor ,
67
69
68
70
Sampler ,
69
71
} from './utils/generation.js' ;
@@ -678,6 +680,7 @@ export class PreTrainedModel extends Callable {
678
680
info = await Promise . all ( [
679
681
AutoConfig . from_pretrained ( pretrained_model_name_or_path , options ) ,
680
682
constructSession ( pretrained_model_name_or_path , options . model_file_name ?? 'decoder_model_merged' , options ) ,
683
+ getModelJSON ( pretrained_model_name_or_path , 'generation_config.json' , false , options ) ,
681
684
] ) ;
682
685
683
686
} else if ( modelType === MODEL_TYPES . Seq2Seq || modelType === MODEL_TYPES . Vision2Seq ) {
@@ -782,17 +785,17 @@ export class PreTrainedModel extends Callable {
782
785
// processors.push(new NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id));
783
786
// }
784
787
785
- // if (generation_config.min_length !== null && generation_config.eos_token_id !== null && generation_config.min_length > 0) {
786
- // processors.push(new MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id));
787
- // }
788
+ if ( generation_config . min_length !== null && generation_config . eos_token_id !== null && generation_config . min_length > 0 ) {
789
+ processors . push ( new MinLengthLogitsProcessor ( generation_config . min_length , generation_config . eos_token_id ) ) ;
790
+ }
788
791
789
- // if (generation_config.min_new_tokens !== null && generation_config.eos_token_id !== null && generation_config.min_new_tokens > 0) {
790
- // processors.push(new MinNewTokensLengthLogitsProcessor(
791
- // input_ids_seq_length,
792
- // generation_config.min_new_tokens,
793
- // generation_config.eos_token_id
794
- // ));
795
- // }
792
+ if ( generation_config . min_new_tokens !== null && generation_config . eos_token_id !== null && generation_config . min_new_tokens > 0 ) {
793
+ processors . push ( new MinNewTokensLengthLogitsProcessor (
794
+ input_ids_seq_length ,
795
+ generation_config . min_new_tokens ,
796
+ generation_config . eos_token_id
797
+ ) ) ;
798
+ }
796
799
797
800
// if (prefix_allowed_tokens_fn !== null) {
798
801
// processors.push(new PrefixConstrainedLogitsProcessor(
@@ -866,7 +869,8 @@ export class PreTrainedModel extends Callable {
866
869
*/
867
870
_get_generation_config ( generation_config ) {
868
871
// Create empty generation config (contains defaults)
869
- let gen_config = new GenerationConfig ( ) ;
872
+ // We pass `this.config` so that if `eos_token_id` or `bos_token_id` exist in the model's config, we will use them
873
+ let gen_config = new GenerationConfig ( this . config ) ;
870
874
871
875
// Apply model's generation config, if it exists
872
876
if ( 'generation_config' in this ) {
@@ -928,7 +932,7 @@ export class PreTrainedModel extends Callable {
928
932
input_ids_seq_length = 0 ;
929
933
930
934
} else {
931
- input_ids_seq_length = inputs instanceof Tensor ? inputs . dims [ 0 ] : inputs . length ;
935
+ input_ids_seq_length = inputs instanceof Tensor ? inputs . dims . at ( - 1 ) : inputs . length ;
932
936
933
937
// decoder-only
934
938
if ( input_ids_seq_length === 0 ) {
@@ -948,6 +952,12 @@ export class PreTrainedModel extends Callable {
948
952
logits_processor
949
953
)
950
954
955
+ /** @type {number[] } */
956
+ let eos_token_ids = generation_config . eos_token_id ;
957
+ if ( eos_token_ids !== null && ! Array . isArray ( eos_token_ids ) ) {
958
+ eos_token_ids = [ eos_token_ids ] ;
959
+ }
960
+
951
961
// TODO implement early_stopping
952
962
// https://huggingface.co/blog/how-to-generate
953
963
@@ -1007,7 +1017,7 @@ export class PreTrainedModel extends Callable {
1007
1017
1008
1018
newBeam . score += logProb ;
1009
1019
1010
- if ( newTokenId === this . config . eos_token_id ) {
1020
+ if ( eos_token_ids && eos_token_ids . includes ( newTokenId ) ) {
1011
1021
newBeam . done = true ;
1012
1022
}
1013
1023
@@ -2476,10 +2486,12 @@ export class VisionEncoderDecoderModel extends PreTrainedModel {
2476
2486
* @param {Object } config The configuration object specifying the hyperparameters and other model settings.
2477
2487
* @param {Object } session The ONNX session containing the encoder model.
2478
2488
* @param {any } decoder_merged_session The ONNX session containing the merged decoder model.
2489
+ * @param {Object } generation_config Configuration object for the generation process.
2479
2490
*/
2480
- constructor ( config , session , decoder_merged_session ) {
2491
+ constructor ( config , session , decoder_merged_session , generation_config ) {
2481
2492
super ( config , session ) ;
2482
2493
this . decoder_merged_session = decoder_merged_session ;
2494
+ this . generation_config = generation_config ;
2483
2495
2484
2496
this . num_layers = this . config . decoder . n_layer ;
2485
2497
this . num_heads = this . config . decoder . n_head ;
@@ -2617,9 +2629,11 @@ export class GPT2PreTrainedModel extends PreTrainedModel {
2617
2629
* Creates a new instance of the `GPT2PreTrainedModel` class.
2618
2630
* @param {Object } config The configuration of the model.
2619
2631
* @param {any } session The ONNX session containing the model weights.
2632
+ * @param {GenerationConfig } generation_config The generation configuration.
2620
2633
*/
2621
- constructor ( config , session ) {
2634
+ constructor ( config , session , generation_config ) {
2622
2635
super ( config , session ) ;
2636
+ this . generation_config = generation_config ;
2623
2637
2624
2638
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2625
2639
this . config . pad_token_id = this . config . eos_token_id
@@ -2649,9 +2663,11 @@ export class GPTNeoPreTrainedModel extends PreTrainedModel {
2649
2663
* Creates a new instance of the `GPTNeoPreTrainedModel` class.
2650
2664
* @param {Object } config The configuration of the model.
2651
2665
* @param {any } session The ONNX session containing the model weights.
2666
+ * @param {GenerationConfig } generation_config The generation configuration.
2652
2667
*/
2653
- constructor ( config , session ) {
2668
+ constructor ( config , session , generation_config ) {
2654
2669
super ( config , session ) ;
2670
+ this . generation_config = generation_config ;
2655
2671
2656
2672
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2657
2673
this . config . pad_token_id = this . config . eos_token_id
@@ -2673,9 +2689,11 @@ export class GPTNeoXPreTrainedModel extends PreTrainedModel {
2673
2689
* Creates a new instance of the `GPTNeoXPreTrainedModel` class.
2674
2690
* @param {Object } config The configuration of the model.
2675
2691
* @param {any } session The ONNX session containing the model weights.
2692
+ * @param {GenerationConfig } generation_config The generation configuration.
2676
2693
*/
2677
- constructor ( config , session ) {
2694
+ constructor ( config , session , generation_config ) {
2678
2695
super ( config , session ) ;
2696
+ this . generation_config = generation_config ;
2679
2697
2680
2698
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2681
2699
this . config . pad_token_id = this . config . eos_token_id
@@ -2698,9 +2716,11 @@ export class GPTJPreTrainedModel extends PreTrainedModel {
2698
2716
* Creates a new instance of the `GPTJPreTrainedModel` class.
2699
2717
* @param {Object } config The configuration of the model.
2700
2718
* @param {any } session The ONNX session containing the model weights.
2719
+ * @param {GenerationConfig } generation_config The generation configuration.
2701
2720
*/
2702
- constructor ( config , session ) {
2721
+ constructor ( config , session , generation_config ) {
2703
2722
super ( config , session ) ;
2723
+ this . generation_config = generation_config ;
2704
2724
2705
2725
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2706
2726
this . config . pad_token_id = this . config . eos_token_id
@@ -2724,9 +2744,11 @@ export class GPTBigCodePreTrainedModel extends PreTrainedModel {
2724
2744
* Creates a new instance of the `GPTBigCodePreTrainedModel` class.
2725
2745
* @param {Object } config The configuration of the model.
2726
2746
* @param {any } session The ONNX session containing the model weights.
2747
+ * @param {GenerationConfig } generation_config The generation configuration.
2727
2748
*/
2728
- constructor ( config , session ) {
2749
+ constructor ( config , session , generation_config ) {
2729
2750
super ( config , session ) ;
2751
+ this . generation_config = generation_config ;
2730
2752
2731
2753
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2732
2754
this . config . pad_token_id = this . config . eos_token_id
@@ -2747,11 +2769,13 @@ export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { }
2747
2769
export class CodeGenPreTrainedModel extends PreTrainedModel {
2748
2770
/**
2749
2771
* Creates a new instance of the `CodeGenPreTrainedModel` class.
2750
- * @param {Object } config The model configuration object.
2751
- * @param {Object } session The ONNX session object.
2752
- */
2753
- constructor ( config , session ) {
2772
+ * @param {Object } config The model configuration object.
2773
+ * @param {Object } session The ONNX session object.
2774
+ * @param {GenerationConfig } generation_config The generation configuration.
2775
+ */
2776
+ constructor ( config , session , generation_config ) {
2754
2777
super ( config , session ) ;
2778
+ this . generation_config = generation_config ;
2755
2779
2756
2780
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2757
2781
this . config . pad_token_id = this . config . eos_token_id
@@ -2785,11 +2809,13 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { }
2785
2809
export class LlamaPreTrainedModel extends PreTrainedModel {
2786
2810
/**
2787
2811
* Creates a new instance of the `LlamaPreTrainedModel` class.
2788
- * @param {Object } config The model configuration object.
2789
- * @param {Object } session The ONNX session object.
2790
- */
2791
- constructor ( config , session ) {
2812
+ * @param {Object } config The model configuration object.
2813
+ * @param {Object } session The ONNX session object.
2814
+ * @param {GenerationConfig } generation_config The generation configuration.
2815
+ */
2816
+ constructor ( config , session , generation_config ) {
2792
2817
super ( config , session ) ;
2818
+ this . generation_config = generation_config ;
2793
2819
2794
2820
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2795
2821
this . config . pad_token_id = this . config . eos_token_id
@@ -2817,9 +2843,11 @@ export class BloomPreTrainedModel extends PreTrainedModel {
2817
2843
* Creates a new instance of the `BloomPreTrainedModel` class.
2818
2844
* @param {Object } config The configuration of the model.
2819
2845
* @param {any } session The ONNX session containing the model weights.
2846
+ * @param {GenerationConfig } generation_config The generation configuration.
2820
2847
*/
2821
- constructor ( config , session ) {
2848
+ constructor ( config , session , generation_config ) {
2822
2849
super ( config , session ) ;
2850
+ this . generation_config = generation_config ;
2823
2851
2824
2852
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2825
2853
this . config . pad_token_id = this . config . eos_token_id
@@ -2848,9 +2876,11 @@ export class MptPreTrainedModel extends PreTrainedModel {
2848
2876
* Creates a new instance of the `MptPreTrainedModel` class.
2849
2877
* @param {Object } config The model configuration object.
2850
2878
* @param {Object } session The ONNX session object.
2879
+ * @param {GenerationConfig } generation_config The generation configuration.
2851
2880
*/
2852
- constructor ( config , session ) {
2881
+ constructor ( config , session , generation_config ) {
2853
2882
super ( config , session ) ;
2883
+ this . generation_config = generation_config ;
2854
2884
2855
2885
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2856
2886
this . config . pad_token_id = this . config . eos_token_id
@@ -2880,9 +2910,11 @@ export class OPTPreTrainedModel extends PreTrainedModel {
2880
2910
* Creates a new instance of the `OPTPreTrainedModel` class.
2881
2911
* @param {Object } config The model configuration object.
2882
2912
* @param {Object } session The ONNX session object.
2913
+ * @param {GenerationConfig } generation_config The generation configuration.
2883
2914
*/
2884
- constructor ( config , session ) {
2915
+ constructor ( config , session , generation_config ) {
2885
2916
super ( config , session ) ;
2917
+ this . generation_config = generation_config ;
2886
2918
2887
2919
// config doesn't contain pad_token_id, so we assume it is the eos_token_id
2888
2920
this . config . pad_token_id = this . config . eos_token_id
0 commit comments