     CLIPTokenizer,
     LlamaTokenizerFast,
     LlavaForConditionalGeneration,
+    LlavaProcessor,
 )

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 }


+def _expand_input_ids_with_image_tokens(
+    text_input_ids,
+    prompt_attention_mask,
+    max_sequence_length,
+    image_token_index,
+    image_emb_len,
+    image_emb_start,
+    image_emb_end,
+    pad_token_id,
+):
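+    # Replace each single image placeholder with image_emb_len consecutive
+    # positions so the Llava vision embeddings can be merged into the text
+    # sequence; all other tokens keep their ids and the tail is padded.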
+    special_image_token_mask = text_input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+    max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
+    new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+    expanded_input_ids = torch.full(
+        (text_input_ids.shape[0], max_expanded_length),
+        pad_token_id,
+        dtype=text_input_ids.dtype,
+        device=text_input_ids.device,
+    )
+    expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+    expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
+
+    expanded_attention_mask = torch.zeros(
+        (text_input_ids.shape[0], max_expanded_length),
+        dtype=prompt_attention_mask.dtype,
+        device=prompt_attention_mask.device,
+    )
+    attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
+    expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
+    expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
+    position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
+
+    return {
+        "input_ids": expanded_input_ids,
+        "attention_mask": expanded_attention_mask,
+        "position_ids": position_ids,
+    }
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
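As a quick intuition for the helper added above, here is a minimal, self-contained sketch of its cumsum trick (toy token ids and a 3-slot image, both hypothetical): every image placeholder contributes image_emb_len - 1 extra positions, and the running sum yields each original token's destination column in the expanded sequence.

import torch

image_token_index, image_emb_len = 32000, 3  # toy values, not the real config
text_input_ids = torch.tensor([[1, 32000, 7, 8]])
special = text_input_ids == image_token_index
new_token_positions = torch.cumsum(special * (image_emb_len - 1) + 1, -1) - 1
print(new_token_positions)  # tensor([[0, 3, 4, 5]]): tokens after the image shift right by 2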
@@ -231,6 +276,13 @@ def __init__(
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+        self.llava_processor = LlavaProcessor(
+            self.image_processor,
+            self.tokenizer,
+            patch_size=self.text_encoder.config.vision_config.patch_size,
+            vision_feature_select_strategy=self.text_encoder.config.vision_feature_select_strategy,
+            num_additional_image_tokens=1,
+        )

     def _get_llama_prompt_embeds(
         self,
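The processor added to __init__ is configured from the text encoder's own config so tokenization stays consistent with the vision tower; num_additional_image_tokens=1 accounts for the CLS token the CLIP vision tower prepends. Assuming the usual LLaVA ViT-L/14 tower at 336 px (an assumption, not stated in the diff), the patch grid works out to (336 / 14) ** 2 = 24 * 24 = 576 tokens, which is exactly the image_emb_len default used below.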
@@ -251,6 +303,12 @@ def _get_llama_prompt_embeds(
             prompt = [prompt_template["template"].format(p) for p in prompt]

         crop_start = prompt_template.get("crop_start", None)
+
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is None:
             prompt_template_input = self.tokenizer(
                 prompt_template["template"],
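These template defaults are mutually consistent: the image embeddings occupy the half-open span [image_emb_start, image_emb_end), and 5 + 576 = 581. A one-line sanity check one could add here (not in the original diff):

assert image_emb_start + image_emb_len == image_emb_end  # 5 + 576 == 581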
@@ -280,19 +338,25 @@ def _get_llama_prompt_embeds(

         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)

+        image_token_index = self.text_encoder.config.image_token_index
+        pad_token_id = self.text_encoder.config.pad_token_id
+        expanded_inputs = _expand_input_ids_with_image_tokens(
+            text_input_ids,
+            prompt_attention_mask,
+            max_sequence_length,
+            image_token_index,
+            image_emb_len,
+            image_emb_start,
+            image_emb_end,
+            pad_token_id,
+        )
         prompt_embeds = self.text_encoder(
-            input_ids=text_input_ids,
-            attention_mask=prompt_attention_mask,
-            pixel_values=image_embeds,
+            **expanded_inputs,
+            pixel_values=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)

-        image_emb_len = prompt_template.get("image_emb_len", 576)
-        image_emb_start = prompt_template.get("image_emb_start", 5)
-        image_emb_end = prompt_template.get("image_emb_end", 581)
-        double_return_token_id = prompt_template.get("double_return_token_id", 271)
-
         if crop_start is not None and crop_start > 0:
             text_crop_start = crop_start - 1 + image_emb_len
             batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
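The crop arithmetic follows from the expansion: crop_start was measured on the un-expanded prompt, where the image contributed a single placeholder token; after expansion that token occupies image_emb_len positions, so the boundary shifts by image_emb_len - 1, giving text_crop_start = crop_start - 1 + image_emb_len. For a hypothetical crop_start of 95 and the 576-token default, that is 95 - 1 + 576 = 670.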