@@ -3908,18 +3908,28 @@ def get_multimodal_embeddings(
39083908 deepstack_visual_embeds = deepstack_video_embeds
39093909
39103910 if position_ids is None :
3911+ attention_mask_tensor = (
3912+ attention_mask if not isinstance (attention_mask , dict ) else attention_mask ["full_attention" ]
3913+ )
3914+ if attention_mask_tensor is not None and attention_mask_tensor .ndim == 4 :
3915+ attention_mask_tensor = torch .diagonal (attention_mask_tensor [:, 0 ], dim1 = 1 , dim2 = 2 )
3916+ # Only apply conversion for floating point tensors (inverted masks)
3917+ if attention_mask_tensor .dtype .is_floating_point :
3918+ attention_mask_tensor = attention_mask_tensor / torch .finfo (attention_mask_tensor .dtype ).min
3919+ attention_mask_tensor = (1.0 - attention_mask_tensor ).int ()
39113920
39123921 # Calculate RoPE index once per generation in the pre-fill stage only.
39133922 # When compiling, we can't check tensor values thus we check only input length
39143923 # It is safe to assume that `length!=1` means we're in pre-fill because compiled
39153924 # models currently cannot do assisted decoding
3916- if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask .ndim == 2 ):
3917- # calculate RoPE index once per generation in the pre-fill stage only
3918- if (cache_position is not None and cache_position [0 ] == 0 ) or self .rope_deltas is None :
3919- position_ids , rope_deltas = self .get_rope_index (
3920- input_ids , image_grid_thw , video_grid_thw , attention_mask
3921- )
3922- self .rope_deltas = rope_deltas
3925+ if self .rope_deltas is None :
3926+ position_ids , rope_deltas = self .get_rope_index (
3927+ input_ids ,
3928+ image_grid_thw ,
3929+ video_grid_thw ,
3930+ attention_mask = attention_mask_tensor ,
3931+ )
3932+ self .rope_deltas = rope_deltas
39233933 # then use the prev pre-calculated rope-deltas to get the correct position ids
39243934 else :
39253935 batch_size , seq_length , _ = inputs_embeds .shape
0 commit comments