@@ -345,7 +345,7 @@ def encode_prompt(
345345            )
346346
347347        if  pooled_prompt_embeds  is  None :
348-             if  prompt_2  is  None   and   pooled_prompt_embeds   is   None :
348+             if  prompt_2  is  None :
349349                prompt_2  =  prompt 
350350            pooled_prompt_embeds  =  self ._get_clip_prompt_embeds (
351351                prompt ,
@@ -424,13 +424,16 @@ def prepare_latents(
424424                f" size of { batch_size }  . Make sure the batch size matches the length of the generators." 
425425            )
426426
427+         image  =  image .unsqueeze (2 )  # [B, C, 1, H, W] 
427428        if  isinstance (generator , list ):
428429            image_latents  =  [
429430                retrieve_latents (self .vae .encode (image [i ].unsqueeze (0 )), generator [i ]) for  i  in  range (batch_size )
430431            ]
431432        else :
432433            image_latents  =  [retrieve_latents (self .vae .encode (img .unsqueeze (0 )), generator ) for  img  in  image ]
433434
435+         image_latents  =  torch .cat (image_latents , dim = 0 ).to (dtype )
436+ 
434437        num_latent_frames  =  (num_frames  -  1 ) //  self .vae_scale_factor_temporal  +  1 
435438        latent_height , latent_width  =  height  //  self .vae_scale_factor_spatial , width  //  self .vae_scale_factor_spatial 
436439        shape  =  (batch_size , num_channels_latents , num_latent_frames , latent_height , latent_width )
@@ -502,18 +505,24 @@ def __call__(
502505        image : PipelineImageInput ,
503506        prompt : Union [str , List [str ]] =  None ,
504507        prompt_2 : Union [str , List [str ]] =  None ,
508+         negative_prompt : Union [str , List [str ]] =  None ,
509+         negative_prompt_2 : Union [str , List [str ]] =  None ,
505510        height : int  =  544 ,
506511        width : int  =  960 ,
507512        num_frames : int  =  97 ,
508513        num_inference_steps : int  =  50 ,
509514        sigmas : List [float ] =  None ,
510-         guidance_scale : float  =  6.0 ,
515+         true_cfg_scale : float  =  6.0 ,
516+         guidance_scale : float  =  1.0 ,
511517        num_videos_per_prompt : Optional [int ] =  1 ,
512518        generator : Optional [Union [torch .Generator , List [torch .Generator ]]] =  None ,
513519        latents : Optional [torch .Tensor ] =  None ,
514520        prompt_embeds : Optional [torch .Tensor ] =  None ,
515521        pooled_prompt_embeds : Optional [torch .Tensor ] =  None ,
516522        prompt_attention_mask : Optional [torch .Tensor ] =  None ,
523+         negative_prompt_embeds : Optional [torch .Tensor ] =  None ,
524+         negative_pooled_prompt_embeds : Optional [torch .Tensor ] =  None ,
525+         negative_prompt_attention_mask : Optional [torch .Tensor ] =  None ,
517526        output_type : Optional [str ] =  "pil" ,
518527        return_dict : bool  =  True ,
519528        attention_kwargs : Optional [Dict [str , Any ]] =  None ,
@@ -534,6 +543,13 @@ def __call__(
534543            prompt_2 (`str` or `List[str]`, *optional*): 
535544                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
536545                will be used instead.
546+             negative_prompt (`str` or `List[str]`, *optional*): 
547+                 The prompt or prompts not to guide the image generation. If not defined, one has to pass 
548+                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is 
549+                 not greater than `1`). 
550+             negative_prompt_2 (`str` or `List[str]`, *optional*): 
551+                 The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 
552+                 `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. 
537553            height (`int`, defaults to `544`):
538554                The height in pixels of the generated image. 
539555            width (`int`, defaults to `960`):
@@ -547,6 +563,8 @@ def __call__(
547563                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 
548564                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 
549565                will be used. 
566+             true_cfg_scale (`float`, *optional*, defaults to `6.0`):
567+                 Enables true classifier-free guidance when greater than 1.0 and a `negative_prompt` is provided.
550568            guidance_scale (`float`, defaults to `1.0`):
551569                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 
552570                `guidance_scale` is defined as `w` of equation 2. of [Imagen 
@@ -567,6 +585,17 @@ def __call__(
567585            prompt_embeds (`torch.Tensor`, *optional*): 
568586                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 
569587                provided, text embeddings are generated from the `prompt` input argument. 
588+             pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 
589+                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
590+                 If not provided, pooled text embeddings will be generated from `prompt` input argument. 
591+             negative_prompt_embeds (`torch.FloatTensor`, *optional*): 
592+                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 
593+                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 
594+                 argument. 
595+             negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 
596+                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 
597+                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 
598+                 input argument. 
570599            output_type (`str`, *optional*, defaults to `"pil"`): 
571600                The output format of the generated image. Choose between `PIL.Image` or `np.array`. 
572601            return_dict (`bool`, *optional*, defaults to `True`): 
@@ -611,6 +640,11 @@ def __call__(
611640            prompt_template ,
612641        )
613642
643+         has_neg_prompt  =  negative_prompt  is  not   None  or  (
644+             negative_prompt_embeds  is  not   None  and  negative_pooled_prompt_embeds  is  not   None 
645+         )
646+         do_true_cfg  =  true_cfg_scale  >  1  and  has_neg_prompt 
647+ 
614648        self ._guidance_scale  =  guidance_scale 
615649        self ._attention_kwargs  =  attention_kwargs 
616650        self ._current_timestep  =  None 
@@ -627,6 +661,7 @@ def __call__(
627661            batch_size  =  prompt_embeds .shape [0 ]
628662
629663        # 3. Encode input prompt 
664+         transformer_dtype  =  self .transformer .dtype 
630665        prompt_embeds , pooled_prompt_embeds , prompt_attention_mask  =  self .encode_prompt (
631666            prompt = prompt ,
632667            prompt_2 = prompt_2 ,
@@ -638,21 +673,29 @@ def __call__(
638673            device = device ,
639674            max_sequence_length = max_sequence_length ,
640675        )
641- 
642-         transformer_dtype  =  self .transformer .dtype 
643676        prompt_embeds  =  prompt_embeds .to (transformer_dtype )
644677        prompt_attention_mask  =  prompt_attention_mask .to (transformer_dtype )
645-         if  pooled_prompt_embeds  is  not   None :
646-             pooled_prompt_embeds  =  pooled_prompt_embeds .to (transformer_dtype )
678+         pooled_prompt_embeds  =  pooled_prompt_embeds .to (transformer_dtype )
679+ 
680+         if  do_true_cfg :
681+             negative_prompt_embeds , negative_pooled_prompt_embeds , negative_prompt_attention_mask  =  self .encode_prompt (
682+                 prompt = negative_prompt ,
683+                 prompt_2 = negative_prompt_2 ,
684+                 prompt_template = prompt_template ,
685+                 num_videos_per_prompt = num_videos_per_prompt ,
686+                 prompt_embeds = negative_prompt_embeds ,
687+                 pooled_prompt_embeds = negative_pooled_prompt_embeds ,
688+                 prompt_attention_mask = negative_prompt_attention_mask ,
689+                 device = device ,
690+                 max_sequence_length = max_sequence_length ,
691+             )
692+             negative_prompt_embeds  =  negative_prompt_embeds .to (transformer_dtype )
693+             negative_prompt_attention_mask  =  negative_prompt_attention_mask .to (transformer_dtype )
694+             negative_pooled_prompt_embeds  =  negative_pooled_prompt_embeds .to (transformer_dtype )
647695
648696        # 4. Prepare timesteps 
649697        sigmas  =  np .linspace (1.0 , 0.0 , num_inference_steps  +  1 )[:- 1 ] if  sigmas  is  None  else  sigmas 
650-         timesteps , num_inference_steps  =  retrieve_timesteps (
651-             self .scheduler ,
652-             num_inference_steps ,
653-             device ,
654-             sigmas = sigmas ,
655-         )
698+         timesteps , num_inference_steps  =  retrieve_timesteps (self .scheduler , num_inference_steps , device , sigmas = sigmas )
656699
657700        # 5. Prepare latent variables 
658701        vae_dtype  =  self .vae .dtype 
@@ -702,6 +745,19 @@ def __call__(
702745                    return_dict = False ,
703746                )[0 ]
704747
748+                 if  do_true_cfg :
749+                     neg_noise_pred  =  self .transformer (
750+                         hidden_states = latent_model_input ,
751+                         timestep = timestep ,
752+                         encoder_hidden_states = negative_prompt_embeds ,
753+                         encoder_attention_mask = negative_prompt_attention_mask ,
754+                         pooled_projections = negative_pooled_prompt_embeds ,
755+                         guidance = guidance ,
756+                         attention_kwargs = attention_kwargs ,
757+                         return_dict = False ,
758+                     )[0 ]
759+                     noise_pred  =  neg_noise_pred  +  true_cfg_scale  *  (noise_pred  -  neg_noise_pred )
760+ 
705761                # compute the previous noisy sample x_t -> x_t-1 
706762                latents  =  self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
707763
0 commit comments