@@ -466,12 +466,15 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Union[str, List[str]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Union[str, List[str]] = None,
         height: int = 720,
         width: int = 1280,
         num_frames: int = 129,
         num_inference_steps: int = 50,
         sigmas: List[float] = None,
         guidance_scale: float = 6.0,
+        true_cfg_scale: float = 1.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -590,6 +593,7 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         # 3. Encode input prompt
+        do_true_cfg = true_cfg_scale > 1.0 and negative_prompt is not None
         prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
             prompt=prompt,
             prompt_2=prompt_2,
@@ -601,12 +605,29 @@ def __call__(
             device=device,
             max_sequence_length=max_sequence_length,
         )
+        if do_true_cfg:
+            negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_template=prompt_template,
+                num_videos_per_prompt=num_videos_per_prompt,
+                prompt_embeds=None,
+                pooled_prompt_embeds=None,
+                prompt_attention_mask=None,
+                device=device,
+                max_sequence_length=max_sequence_length,
+            )

         transformer_dtype = self.transformer.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
         if pooled_prompt_embeds is not None:
             pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
+        if do_true_cfg:
+            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+            negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
+            if negative_pooled_prompt_embeds is not None:
+                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)

         # 4. Prepare timesteps
         sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
@@ -658,6 +679,18 @@ def __call__(
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
+                if do_true_cfg:
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_attention_mask=negative_prompt_attention_mask,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        guidance=guidance,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
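The new `true_cfg_scale` sits alongside the existing `guidance_scale`: the latter remains the embedded (distilled) guidance value passed to the transformer, while `true_cfg_scale > 1.0` together with a `negative_prompt` triggers a second, negatively conditioned forward pass per step and blends the two predictions as `neg + true_cfg_scale * (pos - neg)`. A minimal usage sketch, assuming these changes land in `HunyuanVideoPipeline`; the checkpoint id, prompts, and sampling settings below are illustrative, not part of the diff:

```python
# Usage sketch for the added negative_prompt / true_cfg_scale arguments (assumed API).
import torch
from diffusers import HunyuanVideoPipeline

# Illustrative diffusers-format checkpoint; any HunyuanVideo checkpoint loads the same way.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")

video = pipe(
    prompt="a cat walks on the grass, realistic style",
    negative_prompt="blurry, distorted, low quality",  # only used when true_cfg_scale > 1.0
    true_cfg_scale=4.0,   # enables the extra negatively conditioned transformer pass
    guidance_scale=6.0,   # embedded guidance, unchanged by this PR
    num_frames=61,
    num_inference_steps=30,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
```

With the default `true_cfg_scale=1.0`, or when no `negative_prompt` is given, `do_true_cfg` stays `False` and the pipeline behaves exactly as before, with no extra transformer pass per step.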