src/transformers/models/cohere2_vision/modeling_cohere2_vision.py (30 changes: 4 additions & 26 deletions)
@@ -174,19 +174,13 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.language_model

def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_num_patches: torch.Tensor,
):
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_num_patches (`torch.Tensor` of shape `(num_images)`)
Number of patches for each image.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
@@ -227,7 +221,6 @@ def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_num_patches: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@@ -236,18 +229,14 @@
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, Cohere2VisionModelOutputWithPast]:
r"""
image_num_patches (`torch.Tensor` of shape `(num_images,)`):
Number of patches per input image.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_features = self.get_image_features(pixel_values, image_num_patches=image_num_patches)
image_features = self.get_image_features(pixel_values)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
special_image_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
@@ -303,15 +292,8 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model.get_decoder()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_num_patches: torch.Tensor,
):
return self.model.get_image_features(
pixel_values=pixel_values,
image_num_patches=image_num_patches,
)
def get_image_features(self, pixel_values: torch.FloatTensor):
return self.model.get_image_features(pixel_values=pixel_values)

# Make modules available through conditional class for BC
@property
@@ -332,7 +314,6 @@ def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_num_patches: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@@ -345,8 +326,6 @@
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Cohere2VisionCausalLMOutputWithPast]:
r"""
image_num_patches (`torch.Tensor` of shape `(num_images,)`):
Number of patches per input image.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -384,7 +363,6 @@ def forward(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
image_num_patches=image_num_patches,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
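Taken together, the hunks above drop `image_num_patches` from both `get_image_features` and `forward`: callers now pass only `pixel_values`, and the patch count comes from the tensor's leading dimension. Below is a minimal sketch of the updated call pattern under stated assumptions: `Cohere2VisionConfig` follows the library's usual naming, the sub-config keys are the standard composite-config kwargs, and the tiny sizes are illustrative rather than the real checkpoint's values; none of this is specified in the PR itself.

```python
import torch
from transformers import Cohere2VisionConfig, Cohere2VisionForConditionalGeneration

# Tiny illustrative config; real use would load a pretrained checkpoint with from_pretrained instead.
config = Cohere2VisionConfig(
    text_config={
        "hidden_size": 32,
        "intermediate_size": 64,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "vocab_size": 128,
    },
    vision_config={
        "hidden_size": 32,
        "intermediate_size": 64,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "image_size": 64,
        "patch_size": 16,
        "num_channels": 3,
    },
)
model = Cohere2VisionForConditionalGeneration(config).eval()

# pixel_values: (num_patches, channels, height, width). The patch count is read from the
# tensor itself, so no separate image_num_patches tensor is passed anymore.
pixel_values = torch.randn(2, 3, 64, 64)

with torch.no_grad():
    image_features = model.get_image_features(pixel_values=pixel_values)

print(image_features.shape)
```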
src/transformers/models/cohere2_vision/modular_cohere2_vision.py (30 changes: 4 additions & 26 deletions)
@@ -94,19 +94,13 @@ class Cohere2VisionCausalLMOutputWithPast(AyaVisionCausalLMOutputWithPast):
class Cohere2VisionModel(AyaVisionModel):
_checkpoint_conversion_mapping = {}

def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_num_patches: torch.Tensor,
):
def get_image_features(self, pixel_values: torch.FloatTensor):
"""
Obtains image last hidden states from the vision tower and applies multimodal projection.

Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, channels, height, width)`)
The tensors corresponding to the input images.
image_num_patches (`torch.Tensor` of shape `(num_images)`)
Number of patches for each image.
Returns:
image_features (List[`torch.Tensor`]): List of image feature tensors, each containing all the visual features of all patches
and of shape `(num_patches, image_length, embed_dim)`.
@@ -123,7 +117,6 @@ def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
image_num_patches: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@@ -132,18 +125,14 @@
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Union[tuple, Cohere2VisionModelOutputWithPast]:
r"""
image_num_patches (`torch.Tensor` of shape `(num_images,)`):
Number of patches per input image.
"""
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

if pixel_values is not None:
image_features = self.get_image_features(pixel_values, image_num_patches=image_num_patches)
image_features = self.get_image_features(pixel_values)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
special_image_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
@@ -172,23 +161,15 @@ def forward(
class Cohere2VisionForConditionalGeneration(AyaVisionForConditionalGeneration):
_checkpoint_conversion_mapping = {}

def get_image_features(
self,
pixel_values: torch.FloatTensor,
image_num_patches: torch.Tensor,
):
return self.model.get_image_features(
pixel_values=pixel_values,
image_num_patches=image_num_patches,
)
def get_image_features(self, pixel_values: torch.FloatTensor):
return self.model.get_image_features(pixel_values=pixel_values)

@check_model_inputs
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_num_patches: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
@@ -201,8 +182,6 @@
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Cohere2VisionCausalLMOutputWithPast]:
r"""
image_num_patches (`torch.Tensor` of shape `(num_images,)`):
Number of patches per input image.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
@@ -240,7 +219,6 @@ def forward(
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
image_num_patches=image_num_patches,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
tests/models/cohere2_vision/test_modeling_cohere2_vision.py (6 changes: 2 additions & 4 deletions)
@@ -120,13 +120,12 @@ def get_config(self):
def prepare_config_and_inputs(self):
config = self.get_config()
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
image_num_patches = torch.tensor([1] * self.batch_size).to(torch_device)

return config, pixel_values, image_num_patches
return config, pixel_values

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, image_num_patches = config_and_inputs
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
input_ids[input_ids == self.image_token_id] = self.pad_token_id
@@ -136,7 +135,6 @@ def prepare_config_and_inputs_for_common(self):
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"image_num_patches": image_num_patches,
}
return config, inputs_dict

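The test fixture change is the mechanical counterpart: the helper stops building a per-image patch-count tensor and ships only `pixel_values`. A minimal sketch of the resulting inputs dict, with illustrative sizes rather than the suite's constants:

```python
import torch

batch_size, num_channels, image_size = 2, 3, 64  # illustrative values, not the test suite's constants

# The inputs dict previously also carried image_num_patches = torch.tensor([1] * batch_size);
# now pixel_values alone is enough, and its leading dimension plays that role.
pixel_values = torch.rand(batch_size, num_channels, image_size, image_size)
inputs = {"pixel_values": pixel_values}
```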