

if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image


logger = logging.getLogger(__name__)
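The import change above matters only for type checking: `PIL.Image` is a module, while the `Image` class lives inside it, so annotations that name the class need `from PIL.Image import Image`. A minimal sketch of the distinction (the `load_thumbnail` helper is hypothetical, purely for illustration):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from PIL.Image import Image  # the class, usable directly in annotations

    def load_thumbnail(path: str) -> "Image":
        # imported lazily at runtime so PIL stays an optional dependency
        from PIL import Image as PILImage  # the module

        img = PILImage.open(path)
        img.thumbnail((128, 128))
        return img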
@@ -2100,6 +2100,8 @@ def __init__(
            quantization_config=quantization_config,
            **kwargs,
        )
+        self.rope_deltas = None  # cache rope_deltas here
+
        if is_transformers_version(">=", "4.45.0"):
            from transformers.models.qwen2_vl.modeling_qwen2_vl import (
                Qwen2VLForConditionalGeneration,
@@ -2197,6 +2199,7 @@ def get_multimodal_embeddings(
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
+        cache_position=None,
        **kwargs,
    ):
        inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
@@ -2209,6 +2212,26 @@ def get_multimodal_embeddings(
            video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
            video_mask = input_ids == self.config.video_token_id
            inputs_embeds[video_mask] = video_embeds
+
+        # if we get a 4D attention mask we cannot calculate rope deltas anymore.
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # otherwise, use the previously calculated rope_deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `delta` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
        return inputs_embeds, attention_mask, position_ids

    def forward(
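The new `rope_deltas` cache follows Qwen2-VL's prefill/decode split for multimodal RoPE: 3D position ids (temporal, height, width) are computed once via `get_rope_index` during prefill, and at each decode step the new token's position is just the cache offset plus the stored delta. A standalone sketch of that decode-step arithmetic, with made-up shapes and stub values in place of what a real `get_rope_index` would have returned:

    import torch

    # pretend prefill already ran: rope_deltas has shape (batch_size, 1)
    batch_size, seq_length = 2, 1           # a decode step processes one new token
    rope_deltas = torch.tensor([[3], [5]])  # stub values standing in for get_rope_index output
    cache_position = torch.tensor([17])     # index of the new token in the KV cache

    delta = cache_position[0] + rope_deltas                        # (batch_size, 1)
    position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
    position_ids = position_ids.add(delta)                         # (batch_size, seq_length)
    position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)     # (3, batch_size, seq_length)
    print(position_ids.shape)  # torch.Size([3, 2, 1])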