@@ -318,6 +318,7 @@ def prepare_latents(
         height: int = 704,
         width: int = 1280,
         num_frames: int = 121,
+        do_classifier_free_guidance: bool = True,
         input_frames_guidance: bool = False,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
@@ -331,11 +332,12 @@ def prepare_latents(
             )
 
         num_cond_frames = video.size(2)
-        num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
         if num_cond_frames >= num_frames:
             # Take the last `num_frames` frames for conditioning
+            num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
             video = video[:, :, -num_frames:]
         else:
+            num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
             num_padding_frames = num_frames - num_cond_frames
             padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
             video = torch.cat([video, padding], dim=2)
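
Note on the relocated num_cond_latent_frames computation: when the input video is at least num_frames long, only the last num_frames frames are kept, so the latent-frame count must be derived from num_frames rather than from the full input length, which the old single-site computation overcounted. A worked example of the arithmetic (the temporal scale factor of 8 and the frame counts are illustrative):

    vae_scale_factor_temporal = 8  # assumed VAE temporal compression
    num_frames = 121               # frames the pipeline will generate
    num_cond_frames = 200          # frames supplied as conditioning

    (num_cond_frames - 1) // vae_scale_factor_temporal + 1  # old: 25 latent frames
    (num_frames - 1) // vae_scale_factor_temporal + 1       # new: 16, matching the kept video[:, :, -num_frames:]
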
@@ -374,22 +376,25 @@ def prepare_latents(
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            latents = latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
+            latents = latents.to(device=device, dtype=dtype)
 
         latents = latents * self.scheduler.config.sigma_max
 
-        cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
-        uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
-        cond_indicator[:, :, :num_cond_latent_frames] = 1.0
-        uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
-
         padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
         ones_padding = latents.new_ones(padding_shape)
         zeros_padding = latents.new_zeros(padding_shape)
+
+        cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+        cond_indicator[:, :, :num_cond_latent_frames] = 1.0
         cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
-        uncond_mask = zeros_padding
-        if input_frames_guidance:
-            uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
+
+        uncond_indicator = uncond_mask = None
+        if do_classifier_free_guidance:
+            uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
+            uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
+            uncond_mask = zeros_padding
+            if not input_frames_guidance:
+                uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
 
         return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
 
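
The indicator/mask pattern above reads as follows: cond_indicator is a (1, 1, T, 1, 1) temporal flag set to 1.0 over the first num_cond_latent_frames, and multiplying it against the ones/zeros paddings broadcasts it to a full (batch_size, 1, T, H, W) condition mask; the uncond variants are now built only when CFG is active. A minimal self-contained sketch with made-up sizes:

    import torch

    batch_size, T, H, W = 1, 16, 88, 160  # illustrative latent dimensions
    num_cond_latent_frames = 3

    ones_padding = torch.ones(batch_size, 1, T, H, W)
    zeros_padding = torch.zeros(batch_size, 1, T, H, W)

    # Temporal flag: 1.0 on conditioning frames, 0.0 elsewhere
    cond_indicator = torch.zeros(1, 1, T, 1, 1)
    cond_indicator[:, :, :num_cond_latent_frames] = 1.0

    # Broadcasting expands the flag to the full mask shape
    cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
    assert cond_mask.shape == (batch_size, 1, T, H, W)
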
@@ -599,24 +604,24 @@ def __call__(
             height,
             width,
             num_frames,
+            self.do_classifier_free_guidance,
             input_frames_guidance,
             torch.float32,
             device,
             generator,
             latents,
         )
-        uncond_mask = uncond_mask.to(transformer_dtype)
         cond_mask = cond_mask.to(transformer_dtype)
+        if self.do_classifier_free_guidance:
+            uncond_mask = uncond_mask.to(transformer_dtype)
+
         augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
         padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
 
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
-        if not guidance_scale > 1.0:
-            raise ValueError("Running inference without CFG is not yet supported. Please set `guidance_scale > 1`.")
-
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
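
With the guidance_scale > 1 guard deleted, CFG-free inference becomes a supported path: when do_classifier_free_guidance is False the loop below runs a single conditional forward pass per step. A hypothetical invocation (assuming pipe is an already-loaded Cosmos video-to-world pipeline and that do_classifier_free_guidance is derived as guidance_scale > 1.0, as elsewhere in diffusers):

    video = pipe(
        prompt=prompt,
        video=input_video,
        guidance_scale=1.0,  # disables CFG, halving the transformer calls per step
    ).frames[0]
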
@@ -628,31 +633,14 @@ def __call__(
                 current_sigma = self.scheduler.sigmas[i]
                 is_augment_sigma_greater = augment_sigma >= current_sigma
 
-                current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
-                uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
-                uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
-                uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
-                uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
-
                 current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
                 cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
                 cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
-                cond_latent = self.scheduler.scale_model_input(cond_latent, t)
                 cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
-
-                uncond_latent = uncond_latent.to(transformer_dtype)
+                cond_latent = self.scheduler.scale_model_input(cond_latent, t)
                 cond_latent = cond_latent.to(transformer_dtype)
 
-                noise_pred_uncond = self.transformer(
-                    hidden_states=uncond_latent,
-                    timestep=timestep,
-                    encoder_hidden_states=negative_prompt_embeds,
-                    fps=fps,
-                    condition_mask=uncond_mask,
-                    padding_mask=padding_mask,
-                    return_dict=False,
-                )[0]
-                noise_pred_cond = self.transformer(
+                noise_pred = self.transformer(
                     hidden_states=cond_latent,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
@@ -662,18 +650,48 @@ def __call__(
                     fps=fps,
                     condition_mask=cond_mask,
                     padding_mask=padding_mask,
                     return_dict=False,
                 )[0]
 
-                noise_pred = torch.cat([noise_pred_uncond, noise_pred_cond], dim=0)
-                noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
-
-                noise_pred_cond = (
-                    current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
-                )
-                noise_pred_uncond = (
-                    current_uncond_indicator * conditioning_latents
-                    + (1 - current_uncond_indicator) * noise_pred_uncond
-                )
-                latents = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                if self.do_classifier_free_guidance:
+                    current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
+                    uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
+                    uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
+                    uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
+                    uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
+                    uncond_latent = uncond_latent.to(transformer_dtype)
+
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=uncond_latent,
+                        timestep=timestep,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        fps=fps,
+                        condition_mask=uncond_mask,
+                        padding_mask=padding_mask,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = torch.cat([noise_pred_uncond, noise_pred])
+
+                # pred_original_sample (x0)
+                noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[1]
+                self.scheduler._step_index -= 1
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
+                    noise_pred_uncond = (
+                        current_uncond_indicator * conditioning_latents
+                        + (1 - current_uncond_indicator) * noise_pred_uncond
+                    )
+                    noise_pred_cond = (
+                        current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
+                    )
+                    noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                else:
+                    noise_pred = (
+                        current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
+                    )
+
+                # pred_sample (eps)
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
+                )[0]
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
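
The restructured loop splits the scheduler update into two phases: the first scheduler.step(...) call is used only to read pred_original_sample (the x0 prediction, tuple index 1), and _step_index is rewound so the scheduler's internal state is unchanged; conditioning-frame clamping and guidance are then applied in x0 space, and a second scheduler.step(...) call consumes the corrected prediction via the pred_original_sample override to produce the next latents. A minimal sketch of the pattern (assuming, as this diff does, that step returns (prev_sample, pred_original_sample) and accepts a pred_original_sample keyword):

    # Phase 1: peek at the x0 prediction without advancing scheduler state.
    x0 = scheduler.step(model_out, t, latents, return_dict=False)[1]
    scheduler._step_index -= 1  # undo the counter increment from step()

    # Phase 2: correct x0: clamp conditioning frames back to the clean latents,
    # then (under CFG) combine as x0_cond + s * (x0_cond - x0_uncond).
    x0 = cond_indicator * conditioning_latents + (1 - cond_indicator) * x0

    # Phase 3: take the real step with the corrected x0 substituted in.
    latents = scheduler.step(model_out, t, latents, return_dict=False, pred_original_sample=x0)[0]
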
@@ -683,6 +701,7 @@ def __call__(
 
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):