@@ -279,14 +279,10 @@ def __init__(
279279 self .default_sample_size = 128
280280 self .tokenizer_max_length = 512
281281
282- def _encode_prompt ( self , prompt , image , num_images_per_prompt ):
283-
282+ def _encode_prompt ( self , prompt , image ):
284283 raw_vl_input = self .image_processor_vl (images = image ,return_tensors = "pt" )
285284 pixel_values = raw_vl_input ['pixel_values' ]
286285 image_grid_thw = raw_vl_input ['image_grid_thw' ]
287-
288- prompt = [prompt ] if isinstance (prompt , str ) else prompt
289- batch_size = len (prompt )
290286 all_tokens = []
291287 for clean_prompt_sub , matched in split_quotation (prompt [0 ]):
292288 if matched :
@@ -348,25 +344,25 @@ def _encode_prompt( self, prompt, image, num_images_per_prompt ):
348344 prompt_embeds = text_output .hidden_states [- 1 ].detach ()
349345 prompt_embeds = prompt_embeds [:,self .prompt_template_encode_start_idx : - self .prompt_template_encode_end_idx ,:]
350346
351- _ , seq_len , _ = prompt_embeds .shape
352-
353- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
354- prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
355- prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
356-
357347 return prompt_embeds
358348
359349 @torch .inference_mode ()
360350 def encode_prompt (self ,
361351 prompt : List [str ] = None ,
362352 image : Optional [torch .Tensor ] = None ,
363353 num_images_per_prompt : Optional [int ] = 1 ,
364- prompt_embeds : Optional [torch .Tensor ] = None ,):
365-
354+ prompt_embeds : Optional [torch .Tensor ] = None ):
355+ prompt = [prompt ] if isinstance (prompt , str ) else prompt
356+ batch_size = len (prompt )
366357 # If prompt_embeds is provided and prompt is None, skip encoding
367358 if prompt_embeds is None :
368- prompt_embeds = self ._encode_prompt ( prompt , image , num_images_per_prompt )
359+ prompt_embeds = self ._encode_prompt ( prompt , image )
369360
361+ _ , seq_len , _ = prompt_embeds .shape
362+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
363+ prompt_embeds = prompt_embeds .repeat (1 , num_images_per_prompt , 1 )
364+ prompt_embeds = prompt_embeds .view (batch_size * num_images_per_prompt , seq_len , - 1 )
365+
370366 text_ids = prepare_pos_ids (modality_id = 0 ,
371367 type = 'text' ,
372368 start = (0 , 0 ),
0 commit comments