@@ -887,8 +887,26 @@ function createPositionIds(model_inputs, past_key_values = null, start_index = 0) {
 }
 
 function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
+    const past_length = model_inputs.past_key_values
+        ? Object.values(model_inputs.past_key_values)[0].dims.at(-2)
+        : 0;
+
+    if (!model_inputs.attention_mask) {
+        // If the attention mask is not provided, we attempt to infer based on provided inputs
+        let dims;
+        for (const key of ['input_ids', 'inputs_embeds', 'position_ids']) {
+            if (model_inputs[key]) {
+                dims = model_inputs[key].dims;
+                break;
+            }
+        }
+        if (!dims) {
+            throw new Error("attention_mask is not provided, and unable to infer its shape from model inputs.");
+        }
+        model_inputs.attention_mask = ones([dims[0], past_length + dims[1]]);
+    }
+
     if (model_inputs.past_key_values) {
-        const past_length = Object.values(model_inputs.past_key_values)[0].dims.at(-2);
         const { input_ids, attention_mask } = model_inputs;
 
         // Keep only the unprocessed tokens:
@@ -909,24 +927,7 @@ function decoder_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
         }
         // 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
         else {
-            if (
-                // NOTE: Only used by VLMs (!= so that null matches undefined)
-                self.config.image_token_index != null &&
-                // Equivalent to `self.config.image_token_index in input_ids` (== so that int matches bigint)
-                input_ids.data.some(x => x == self.config.image_token_index)
-            ) {
-                // TODO: Support multiple image tokens
-                const num_image_tokens = self.config.num_image_tokens;
-                if (!num_image_tokens) {
-                    throw new Error('`num_image_tokens` is missing in the model configuration.');
-                }
-
-                const num_new_tokens = input_ids.dims[1] - (past_length - num_image_tokens);
-                model_inputs.input_ids = input_ids.slice(null, [-num_new_tokens, null]);
 
-                // TODO: The attention mask should be formed from the attention mask passed in model_inputs
-                model_inputs.attention_mask = ones([1, past_length + num_new_tokens]);
-            }
         }
     }
 
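Taken together, these two hunks hoist the past_length computation out of the past_key_values branch (defaulting to 0 when no cache is present) and synthesize a missing attention_mask as an all-ones tensor of shape [batch_size, past_length + seq_length], inferred from whichever of input_ids, inputs_embeds, or position_ids is available. This replaces the deleted VLM-specific branch that rebuilt the mask from num_image_tokens. A minimal standalone sketch of the shape inference, using plain objects with a dims array in place of transformers.js Tensors (names here are illustrative, not the library's):

// Sketch: infer the attention-mask shape the way the new code path does.
// `modelInputs` stands in for model_inputs; each entry exposes a `dims`
// array ([batch_size, seq_length, ...]) like a transformers.js Tensor would.
function inferAttentionMaskShape(modelInputs, pastLength) {
    for (const key of ['input_ids', 'inputs_embeds', 'position_ids']) {
        if (modelInputs[key]) {
            const [batchSize, seqLength] = modelInputs[key].dims;
            return [batchSize, pastLength + seqLength];
        }
    }
    throw new Error('attention_mask is not provided, and unable to infer its shape from model inputs.');
}

// Decoding one new token against a 12-entry KV cache:
console.log(inferAttentionMaskShape({ input_ids: { dims: [1, 1] } }, 12)); // [1, 13]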
@@ -2016,17 +2017,7 @@ export class PreTrainedModel extends Callable {
 
     async encode_image({ pixel_values }) {
         // image_inputs === { pixel_values }
-        const features = (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
-        // @ts-expect-error TS2339
-        if (!this.config.num_image_tokens) {
-            console.warn(
-                'The number of image tokens was not set in the model configuration. ' +
-                `Setting it to the number of features detected by the vision encoder (${features.dims[1]}).`
-            )
-            // @ts-expect-error TS2339
-            this.config.num_image_tokens = features.dims[1];
-        }
-        return features;
+        return (await sessionRun(this.sessions['vision_encoder'], { pixel_values })).image_features;
     }
 
     async encode_text({ input_ids }) {
@@ -3640,65 +3631,16 @@ export class LlavaPreTrainedModel extends PreTrainedModel {
  * The LLAVA model which consists of a vision backbone and a language model.
  */
 export class LlavaForConditionalGeneration extends LlavaPreTrainedModel {
+    _merge_input_ids_with_image_features(kwargs) {
+        const vision_hidden_size = kwargs.image_features.dims.at(-1);
+        const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);
 
-    _merge_input_ids_with_image_features({
-        inputs_embeds,
-        image_features,
-        input_ids,
-        attention_mask,
-    }) {
-
-        // @ts-expect-error TS2339
-        const image_token_index = this.config.image_token_index;
-
-        const idsList = input_ids.tolist();
-
-        // NOTE: we use .findIndex instead of .indexOf to perform weak comparison (==) between BigInt and Number
-        const indexOfImage = idsList.map(x => x.findIndex(x => x == image_token_index));
-
-        const noImages = indexOfImage.every(x => x === -1);
-        const allImages = indexOfImage.every(x => x !== -1);
-        if (!noImages && !allImages) {
-            // Check for padding reasons
-            throw new Error('Every input should contain either 0 or 1 image token.');
-        }
-
-        if (noImages) {
-            return {
-                inputs_embeds,
-                attention_mask,
-            }
-        }
-
-        const stacked = [];
-        const stacked_attention_mask = [];
-        for (let i = 0; i < indexOfImage.length; ++i) {
-            const index = indexOfImage[i];
-
-            const e = inputs_embeds[i];
-            const im = image_features[i];
-            const am = attention_mask[i];
-            stacked.push(
-                cat([
-                    e.slice([0, index]),
-                    im,
-                    e.slice([index + 1, e.dims[0]]),
-                ], 0)
-            );
-
-            stacked_attention_mask.push(
-                cat([
-                    am.slice([0, index]),
-                    ones([im.dims[0]]),
-                    am.slice([index + 1, am.dims[0]])
-                ], 0)
-            )
-        }
-
-        return {
-            inputs_embeds: stack(stacked, 0),
-            attention_mask: stack(stacked_attention_mask, 0),
-        }
+        return default_merge_input_ids_with_image_features({
+            // @ts-ignore
+            image_token_id: this.config.image_token_index,
+            ...kwargs,
+            image_features: reshaped_image_hidden_states,
+        })
     }
 }
 //////////////////////////////////////////////////
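The deleted per-sequence slice/concat implementation supported only a single image token per sequence and rebuilt the attention mask by hand; the refactor delegates to the shared default_merge_input_ids_with_image_features helper, which writes one row of the flattened image features into each embedding position occupied by an image token. A self-contained sketch of that merge semantics on plain nested arrays (illustrative only; the real helper operates on Tensors):

// Sketch: every position holding image_token_id has its embedding replaced by
// the next row of image_features. This mirrors, but does not reproduce,
// default_merge_input_ids_with_image_features.
function mergeImageFeatures(input_ids, inputs_embeds, image_features, image_token_id) {
    let next = 0; // index into the flattened [num_image_tokens, hidden_size] features
    return inputs_embeds.map((row, i) =>
        row.map((embedding, j) =>
            input_ids[i][j] === image_token_id ? image_features[next++] : embedding
        )
    );
}

// One sequence of 4 tokens (token 32000 assumed as the image token), hidden size 2:
const merged = mergeImageFeatures(
    [[1, 32000, 32000, 2]],
    [[[0.1, 0.1], [0, 0], [0, 0], [0.2, 0.2]]],
    [[9, 9], [8, 8]], // two image-token embeddings from the vision encoder
    32000,
);
console.log(merged); // [[[0.1,0.1],[9,9],[8,8],[0.2,0.2]]]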
@@ -3839,6 +3781,20 @@ export class PaliGemmaForConditionalGeneration extends PaliGemmaPreTrainedModel {
     }
 }
 
+export class LlavaQwen2ForCausalLM extends LlavaPreTrainedModel {
+    _merge_input_ids_with_image_features(kwargs) {
+        const vision_hidden_size = kwargs.image_features.dims.at(-1);
+        const reshaped_image_hidden_states = kwargs.image_features.view(-1, vision_hidden_size);
+
+        return default_merge_input_ids_with_image_features({
+            // @ts-ignore
+            image_token_id: this.config.image_token_index,
+            ...kwargs,
+            image_features: reshaped_image_hidden_states,
+        })
+    }
+}
+
 //////////////////////////////////////////////////
 // Idefics3 Models
 export class Idefics3PreTrainedModel extends PreTrainedModel {
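The new LlavaQwen2ForCausalLM reuses the exact merge logic of LlavaForConditionalGeneration above: both flatten the vision-encoder output from [num_images, tokens_per_image, hidden_size] to [num_images * tokens_per_image, hidden_size] with view(-1, vision_hidden_size), so the merge consumes one feature row per image token. The shape arithmetic, with assumed example dimensions:

// view(-1, h) collapses all leading dimensions into one row axis.
// The dims below are assumptions for illustration, not from the source.
const dims = [2, 576, 1024]; // [num_images, tokens_per_image, hidden_size]
const vision_hidden_size = dims.at(-1);
const num_rows = dims.slice(0, -1).reduce((a, b) => a * b, 1);
console.log([num_rows, vision_hidden_size]); // [1152, 1024]: one row per image token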
@@ -7842,6 +7798,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
     ['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]],
     ['smolvlm', ['SmolVLMForConditionalGeneration', SmolVLMForConditionalGeneration]],
     ['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]],
+    ['llava_qwen2', ['LlavaQwen2ForCausalLM', LlavaQwen2ForCausalLM]],
 ]);
 
 const MODEL_FOR_AUDIO_TEXT_TO_TEXT_MAPPING_NAMES = new Map([
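With this mapping entry, a checkpoint whose config.json declares "model_type": "llava_qwen2" resolves to LlavaQwen2ForCausalLM through the image-text-to-text auto-mapping. A hypothetical end-to-end sketch: the model id is a placeholder, and the processor/generate calls mirror other transformers.js VLM examples rather than a tested llava_qwen2 pipeline.

import { AutoProcessor, LlavaQwen2ForCausalLM, RawImage } from '@huggingface/transformers';

// Placeholder model id; assumes an ONNX llava_qwen2 checkpoint exists on the Hub.
const model_id = 'your-org/llava-qwen2-onnx';
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await LlavaQwen2ForCausalLM.from_pretrained(model_id);

const image = await RawImage.fromURL('https://example.com/cat.png');
const prompt = '<image>\nWhat is in this picture?';
const inputs = await processor(image, prompt); // assumed combined image+text call

const output = await model.generate({ ...inputs, max_new_tokens: 64 });
console.log(processor.tokenizer.batch_decode(output, { skip_special_tokens: true }));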