Commit c7b2d28

Update modeling_visual_language.py
1 parent a2c7350 commit c7b2d28

File tree

1 file changed: +17 -7 lines changed

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 17 additions & 7 deletions
@@ -3908,18 +3908,28 @@ def get_multimodal_embeddings(
             deepstack_visual_embeds = deepstack_video_embeds
 
         if position_ids is None:
+            attention_mask_tensor = (
+                attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
+            )
+            if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
+                attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
+                # Only apply conversion for floating point tensors (inverted masks)
+                if attention_mask_tensor.dtype.is_floating_point:
+                    attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
+                    attention_mask_tensor = (1.0 - attention_mask_tensor).int()
 
             # Calculate RoPE index once per generation in the pre-fill stage only.
             # When compiling, we can't check tensor values thus we check only input length
             # It is safe to assume that `length!=1` means we're in pre-fill because compiled
             # models currently cannot do assisted decoding
-            if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
-                # calculate RoPE index once per generation in the pre-fill stage only
-                if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                    position_ids, rope_deltas = self.get_rope_index(
-                        input_ids, image_grid_thw, video_grid_thw, attention_mask
-                    )
-                    self.rope_deltas = rope_deltas
+            if self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    attention_mask=attention_mask_tensor,
+                )
+                self.rope_deltas = rope_deltas
             # then use the prev pre-calculated rope-deltas to get the correct position ids
         else:
             batch_size, seq_length, _ = inputs_embeds.shape
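
For context, the first half of the hunk normalizes the attention mask before it reaches get_rope_index. Below is a minimal standalone sketch (not part of the commit) of what that branch does: it collapses a 4D additive mask to a per-token 2D mask and, for floating-point ("inverted") masks, maps it back to 0/1 integers. The shapes, the variable names, and the toy mask are illustrative assumptions, not values taken from the model.

import torch

batch, seq = 2, 5

# Hypothetical 4D additive mask of shape (batch, heads, query, key), as many
# HF models build internally: 0.0 where attention is allowed, dtype.min where
# it is masked out (an "inverted" float mask).
mask = torch.zeros(batch, 1, seq, seq)
mask[:, :, :, -2:] = torch.finfo(mask.dtype).min  # mask the last two positions

if mask is not None and mask.ndim == 4:
    # Per-token diagonal of the (query, key) plane -> shape (batch, seq)
    mask = torch.diagonal(mask[:, 0], dim1=1, dim2=2)
    # Only floating-point masks are inverted; map {0.0, dtype.min} back to {1, 0}
    if mask.dtype.is_floating_point:
        mask = mask / torch.finfo(mask.dtype).min  # 0.0 -> 0.0, dtype.min -> 1.0
        mask = (1.0 - mask).int()                  # 1 = attend, 0 = masked

print(mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 0, 0]], dtype=torch.int32)

The point of the division-then-flip idiom is that a prepared 4D mask is additive and inverted, while get_rope_index wants a plain 2D "1 = attend" mask; dividing by torch.finfo(...).min and computing 1.0 - x undoes the inversion without comparing against specific sentinel values.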
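The second half of the hunk drops the pre-fill heuristics and keys the computation purely on self.rope_deltas being unset. A toy sketch of that caching pattern follows, using a hypothetical text-only get_rope_index stand-in (the real method also consumes image_grid_thw and video_grid_thw); the class and method names here are illustrative, not the model's API.

import torch

class RopeDeltaCache:
    """Toy illustration of the rope_deltas caching flow; not the model's actual API."""

    def __init__(self):
        self.rope_deltas = None

    def get_rope_index(self, input_ids, attention_mask):
        # Hypothetical text-only stand-in: positions count attended tokens;
        # deltas record max(position) + 1 - seq_len (zero for plain text).
        position_ids = (attention_mask.long().cumsum(-1) - 1).clamp(min=0)
        rope_deltas = position_ids.max(dim=-1).values + 1 - attention_mask.shape[-1]
        return position_ids, rope_deltas

    def get_position_ids(self, input_ids, attention_mask, cache_position):
        if self.rope_deltas is None:
            # Pre-fill: compute once per generation and cache the deltas.
            position_ids, self.rope_deltas = self.get_rope_index(input_ids, attention_mask)
            return position_ids
        # Decode: derive positions from cache_position plus the cached deltas.
        return cache_position.view(1, -1) + self.rope_deltas.view(-1, 1)

cache = RopeDeltaCache()
ids = torch.tensor([[11, 12, 13, 14]])
mask = torch.ones_like(ids)
print(cache.get_position_ids(ids, mask, torch.tensor([0])))          # tensor([[0, 1, 2, 3]])
print(cache.get_position_ids(ids[:, -1:], mask, torch.tensor([4])))  # tensor([[4]])

Note that the new condition no longer consults cache_position, so self.rope_deltas presumably has to be cleared between generations elsewhere in the class for the pre-fill branch to run again.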
