@@ -261,40 +261,41 @@ def _get_t5_prompt_embeds(
261261 prompt_embeds = prompt_embeds .repeat (1 , num_videos_per_prompt , 1 )
262262 prompt_embeds = prompt_embeds .view (batch_size * num_videos_per_prompt , seq_len , - 1 )
263263
264- return prompt_embeds , prompt_attention_mask
264+ return prompt_embeds
265265
266266 def encode_prompt (
267267 self ,
268268 prompt : Union [str , List [str ]],
269+ negative_prompt : Optional [Union [str , List [str ]]] = None ,
269270 device : Optional [torch .device ] = None ,
270271 num_videos_per_prompt : int = 1 ,
271272 prompt_embeds : Optional [torch .FloatTensor ] = None ,
273+ negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
272274 max_sequence_length : int = 512 ,
275+ do_classifier_free_guidance = True ,
273276 lora_scale : Optional [float ] = None ,
274277 ):
275278 r"""
276279
277280 Args:
278281 prompt (`str` or `List[str]`, *optional*):
279282 prompt to be encoded
280- prompt_2 (`str` or `List[str]`, *optional*):
281- The prompt or prompts to be sent to the `tokenizer` and `text_encoder`. If not defined, `prompt` is
282- used in all text-encoders
283283 device: (`torch.device`):
284284 torch device
285285 num_videos_per_prompt (`int`):
286286 number of images that should be generated per prompt
287287 prompt_embeds (`torch.FloatTensor`, *optional*):
288288 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
289289 provided, text embeddings will be generated from `prompt` input argument.
290- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
291- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
292- If not provided, pooled text embeddings will be generated from `prompt` input argument.
293290 lora_scale (`float`, *optional*):
294291 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
295292 """
296293 device = device or self ._execution_device
297294 prompt = [prompt ] if isinstance (prompt , str ) else prompt
295+ if prompt is not None :
296+ batch_size = len (prompt )
297+ else :
298+ batch_size = prompt_embeds .shape [0 ]
298299
299300 if prompt_embeds is None :
300301 prompt_embeds = self ._get_t5_prompt_embeds (
@@ -307,8 +308,32 @@ def encode_prompt(
307308 dtype = self .text_encoder .dtype if self .text_encoder is not None else self .transformer .dtype
308309
309310 # TODO: Add negative prompts back
311+ if do_classifier_free_guidance and negative_prompt_embeds is None :
312+ negative_prompt = negative_prompt or ""
313+ # normalize str to list
314+ negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
310316
311- return prompt_embeds
317+ if prompt is not None and type (prompt ) is not type (negative_prompt ):
318+ raise TypeError (
319+ f"`negative_prompt` should be the same type to `prompt`, but got { type (negative_prompt )} !="
320+ f" { type (prompt )} ."
321+ )
322+ elif batch_size != len (negative_prompt ):
323+ raise ValueError (
324+ f"`negative_prompt`: { negative_prompt } has batch size { len (negative_prompt )} , but `prompt`:"
325+ f" { prompt } has batch size { batch_size } . Please make sure that passed `negative_prompt` matches"
326+ " the batch size of `prompt`."
327+ )
328+
329+ negative_prompt_embeds = self ._get_t5_prompt_embeds (
330+ prompt = negative_prompt ,
331+ num_videos_per_prompt = num_videos_per_prompt ,
332+ max_sequence_length = max_sequence_length ,
333+ device = device ,
334+ )
335+
336+ return prompt_embeds , negative_prompt_embeds
312337
313338 def check_inputs (
314339 self ,
@@ -541,7 +566,7 @@ def __call__(
541566 lora_scale = (
542567 self .joint_attention_kwargs .get ("scale" , None ) if self .joint_attention_kwargs is not None else None
543568 )
544- (prompt_embeds ) = self .encode_prompt (
569+ (prompt_embeds , negative_prompt_embeds ) = self .encode_prompt (
545570 prompt = prompt ,
546571 prompt_embeds = prompt_embeds ,
547572 device = device ,
@@ -589,12 +614,8 @@ def __call__(
589614 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
590615 self ._num_timesteps = len (timesteps )
591616
592- # handle guidance
593- if self .transformer .config .guidance_embeds :
594- guidance = torch .full ([1 ], guidance_scale , device = device , dtype = torch .float32 )
595- guidance = guidance .expand (latents .shape [0 ])
596- else :
597- guidance = None
617+ if self .do_classifier_free_guidance :
618+ prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
598619
599620 # 6. Denoising loop
600621 with self .progress_bar (total = num_inference_steps ) as progress_bar :
0 commit comments