@@ -174,19 +174,13 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.language_model

-    def get_image_features(
-        self,
-        pixel_values: torch.FloatTensor,
-        image_num_patches: torch.Tensor,
-    ):
+    def get_image_features(self, pixel_values: torch.FloatTensor):
         """
         Obtains image last hidden states from the vision tower and applies multimodal projection.

         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`):
                 The tensors corresponding to the input images.
-            image_num_patches (`torch.Tensor` of shape `(num_images)`)
-                Number of patches for each image.
         Returns:
             image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
             and of shape `(num_patches, image_length, embed_dim)`.
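
A minimal sketch of why the dropped argument is redundant, assuming the documented tensor layout; the concrete sizes and the shape-inference rationale are illustrative assumptions, not taken from this PR:

import torch

# Illustrative tensor matching the documented layout:
# (batch_size, num_patches, channels, height, width)
pixel_values = torch.randn(2, 4, 3, 364, 364)

# The removed image_num_patches argument is recoverable from the tensor itself,
# which is presumably the motivation for dropping it (assumption, not stated in the diff).
image_num_patches = pixel_values.shape[1]

# Vision towers typically fold patches into the batch dimension before encoding:
flat = pixel_values.flatten(0, 1)  # (batch_size * num_patches, channels, height, width)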
@@ -227,7 +221,6 @@ def forward(
         self,
         input_ids: torch.LongTensor = None,
         pixel_values: torch.FloatTensor = None,
-        image_num_patches: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
@@ -236,18 +229,14 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[tuple, Cohere2VisionModelOutputWithPast]:
-        r"""
-        image_num_patches (`torch.Tensor` of shape `(num_images,)`):
-            Number of patches per input image.
-        """
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         if inputs_embeds is None:
             inputs_embeds = self.get_input_embeddings()(input_ids)

         if pixel_values is not None:
-            image_features = self.get_image_features(pixel_values, image_num_patches=image_num_patches)
+            image_features = self.get_image_features(pixel_values)
             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
             special_image_mask = self.get_placeholder_mask(
                 input_ids, inputs_embeds=inputs_embeds, image_features=image_features
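
The pixel path above follows the usual transformers multimodal pattern: project image patches to the text embedding width, then scatter them into the token embeddings at the image-placeholder positions flagged by `get_placeholder_mask`. A self-contained sketch of that scatter step, with all shapes and the mask invented for illustration:

import torch

# Illustrative shapes: 1 sequence of 8 tokens, embed dim 16, 3 image tokens.
inputs_embeds = torch.randn(1, 8, 16)
image_features = torch.randn(3, 16)

# Boolean mask marking which token positions are image placeholders
# (in the model this comes from get_placeholder_mask).
special_image_mask = torch.zeros(1, 8, dtype=torch.bool)
special_image_mask[0, 2:5] = True
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)

# Write the projected image features into the placeholder slots.
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)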
@@ -303,15 +292,8 @@ def set_decoder(self, decoder):
     def get_decoder(self):
         return self.model.get_decoder()

-    def get_image_features(
-        self,
-        pixel_values: torch.FloatTensor,
-        image_num_patches: torch.Tensor,
-    ):
-        return self.model.get_image_features(
-            pixel_values=pixel_values,
-            image_num_patches=image_num_patches,
-        )
+    def get_image_features(self, pixel_values: torch.FloatTensor):
+        return self.model.get_image_features(pixel_values=pixel_values)

     # Make modules available through conditional class for BC
     @property
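
Since both the base model and the conditional-generation wrapper now take only `pixel_values`, caller code simplifies accordingly. A hedged end-to-end sketch; the checkpoint id is a placeholder and the processor call is the generic pattern, not verified against this PR:

from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

ckpt = "<cohere2-vision-checkpoint>"  # placeholder id, not from this PR
processor = AutoProcessor.from_pretrained(ckpt)
model = AutoModelForImageTextToText.from_pretrained(ckpt)

image = Image.open("example.jpg")  # any local test image
inputs = processor(images=image, text="Describe this image.", return_tensors="pt")

# No image_num_patches anywhere: forward/generate consume pixel_values directly.
output_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])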
@@ -332,7 +314,6 @@ def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
-        image_num_patches: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
@@ -345,8 +326,6 @@ def forward(
         **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Cohere2VisionCausalLMOutputWithPast]:
         r"""
-        image_num_patches (`torch.Tensor` of shape `(num_images,)`):
-            Number of patches per input image.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -384,7 +363,6 @@ def forward(
         outputs = self.model(
             input_ids=input_ids,
             pixel_values=pixel_values,
-            image_num_patches=image_num_patches,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,