@@ -101,6 +101,49 @@ def __init__(self, config):
def get_model(self):
    """Return the wrapped backbone model held on this instance."""
    return self.model
103103
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    second_per_grid_ts=None,
    **kwargs,
):
    """Assemble per-step generation inputs, keeping vision tensors attached.

    Overridden from the parent generation mixin for two reasons:
    1. Qwen2.5-VL derives ``position_ids`` from ``rope_deltas`` inside
       ``forward``, so whatever the parent prepared is discarded here.
    2. The image/video pixel tensors are re-attached unconditionally so
       they remain available on kv-cache decoding steps, where the parent
       implementation may drop them from the inputs.
    """
    # Let the parent handle the standard bookkeeping (cache slicing,
    # attention-mask trimming, embeds-vs-ids selection, ...).
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        position_ids=position_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        use_cache=use_cache,
        **kwargs,
    )

    # Drop the parent's position ids (recomputed with rope_deltas in
    # forward) and force the vision tensors back in for QwenVL kv cache.
    model_inputs.update(
        {
            "position_ids": None,
            "pixel_values": pixel_values,
            "pixel_values_videos": pixel_values_videos,
        }
    )

    return model_inputs
145+
146+
104147 def forward (
105148 self ,
106149 input_ids : Optional [torch .LongTensor ] = None ,
@@ -121,6 +164,7 @@ def forward(
121164 rope_deltas : Optional [torch .LongTensor ] = None ,
122165 cache_position : Optional [torch .LongTensor ] = None ,
123166 second_per_grid_ts : Optional [torch .Tensor ] = None ,
167+ raw_input_ids : Optional [torch .LongTensor ] = None ,
124168 ) -> Union [Tuple , CausalLMOutputWithPast ]:
125169 r"""
126170 labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -169,10 +213,11 @@ def forward(
169213
170214 if inputs_embeds is None :
171215 inputs_embeds = self .model .embed_tokens (input_ids )
172- if pixel_values is not None :
216+ n_image_tokens = (input_ids == self .config .image_token_id ).sum ().item ()
217+ if pixel_values is not None and n_image_tokens > 0 :
173218 pixel_values = pixel_values .type (self .visual .dtype )
174219 image_embeds = self .visual (pixel_values , grid_thw = image_grid_thw )
175- n_image_tokens = ( input_ids == self . config . image_token_id ). sum (). item ()
220+ image_embeds = image_embeds [ - n_image_tokens :]
176221 n_image_features = image_embeds .shape [0 ]
177222 if n_image_tokens != n_image_features :
178223 raise ValueError (
@@ -232,6 +277,22 @@ def forward(
232277 attention_mask ,
233278 )
234279 self .rope_deltas = rope_deltas
280+ elif n_image_tokens > 0 : # using only for kv cache
281+ attention_mask = attention_mask [:, :raw_input_ids .shape [1 ]]
282+ position_ids , rope_deltas = self .get_rope_index (
283+ raw_input_ids ,
284+ image_grid_thw ,
285+ video_grid_thw ,
286+ second_per_grid_ts ,
287+ attention_mask ,
288+ )
289+ delta = (
290+ (cache_position [0 ] + self .rope_deltas ).to (inputs_embeds .device )
291+ if cache_position is not None
292+ else 0
293+ )
294+ position_ids = position_ids [:, :,- input_ids .shape [1 ]:]
295+ self .rope_deltas = rope_deltas
235296 # then use the prev pre-calculated rope-deltas to get the correct position ids
236297 else :
237298 batch_size , seq_length , _ = inputs_embeds .shape
0 commit comments