@@ -251,13 +251,6 @@ def encode_prompt(
251251 if device is None :
252252 device = self ._execution_device
253253
254- if prompt is not None and isinstance (prompt , str ):
255- batch_size = 1
256- elif prompt is not None and isinstance (prompt , list ):
257- batch_size = len (prompt )
258- else :
259- batch_size = prompt_embeds .shape [0 ]
260-
261254 # See Section 3.1. of the paper.
262255 max_length = max_sequence_length
263256
@@ -302,12 +295,12 @@ def encode_prompt(
302295 # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
303296 prompt_embeds = prompt_embeds .repeat (1 , num_videos_per_prompt , 1 )
304297 prompt_embeds = prompt_embeds .view (bs_embed * num_videos_per_prompt , seq_len , - 1 )
305- prompt_attention_mask = prompt_attention_mask .view ( bs_embed , - 1 )
306- prompt_attention_mask = prompt_attention_mask .repeat ( num_videos_per_prompt , 1 )
298+ prompt_attention_mask = prompt_attention_mask .repeat ( 1 , num_videos_per_prompt )
299+ prompt_attention_mask = prompt_attention_mask .view ( bs_embed * num_videos_per_prompt , - 1 )
307300
308301 # get unconditional embeddings for classifier free guidance
309302 if do_classifier_free_guidance and negative_prompt_embeds is None :
310- uncond_tokens = [negative_prompt ] * batch_size if isinstance (negative_prompt , str ) else negative_prompt
303+ uncond_tokens = [negative_prompt ] * bs_embed if isinstance (negative_prompt , str ) else negative_prompt
311304 uncond_tokens = self ._text_preprocessing (uncond_tokens , clean_caption = clean_caption )
312305 max_length = prompt_embeds .shape [1 ]
313306 uncond_input = self .tokenizer (
@@ -334,10 +327,10 @@ def encode_prompt(
334327 negative_prompt_embeds = negative_prompt_embeds .to (dtype = dtype , device = device )
335328
336329 negative_prompt_embeds = negative_prompt_embeds .repeat (1 , num_videos_per_prompt , 1 )
337- negative_prompt_embeds = negative_prompt_embeds .view (batch_size * num_videos_per_prompt , seq_len , - 1 )
330+ negative_prompt_embeds = negative_prompt_embeds .view (bs_embed * num_videos_per_prompt , seq_len , - 1 )
338331
339- negative_prompt_attention_mask = negative_prompt_attention_mask .view ( bs_embed , - 1 )
340- negative_prompt_attention_mask = negative_prompt_attention_mask .repeat ( num_videos_per_prompt , 1 )
332+ negative_prompt_attention_mask = negative_prompt_attention_mask .repeat ( 1 , num_videos_per_prompt )
333+ negative_prompt_attention_mask = negative_prompt_attention_mask .view ( bs_embed * num_videos_per_prompt , - 1 )
341334 else :
342335 negative_prompt_embeds = None
343336 negative_prompt_attention_mask = None
0 commit comments