@@ -242,6 +242,7 @@ def encode_prompt(
242242 self ,
243243 prompt : Union [str , List [str ]],
244244 negative_prompt : Optional [Union [str , List [str ]]] = None ,
245+ do_classifier_free_guidance : bool = False ,
245246 num_videos_per_prompt : int = 1 ,
246247 prompt_embeds : Optional [torch .Tensor ] = None ,
247248 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
@@ -259,6 +260,8 @@ def encode_prompt(
259260 The prompt or prompts not to guide the image generation. If not defined, one has to pass
260261 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
261262 less than `1`).
263+ do_classifier_free_guidance (`bool`, *optional*, defaults to `False`):
264+ Whether to use classifier free guidance or not.
262265 num_videos_per_prompt (`int`, *optional*, defaults to 1):
263266 Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
264267 prompt_embeds (`torch.Tensor`, *optional*):
@@ -290,7 +293,7 @@ def encode_prompt(
290293 dtype = dtype ,
291294 )
292295
293- if negative_prompt_embeds is None :
296+ if do_classifier_free_guidance and negative_prompt_embeds is None :
294297 negative_prompt = negative_prompt or ""
295298 negative_prompt = batch_size * [negative_prompt ] if isinstance (negative_prompt , str ) else negative_prompt
296299
@@ -439,6 +442,10 @@ def prepare_latents(
439442 def guidance_scale (self ):
440443 return self ._guidance_scale
441444
445+ @property
446+ def do_classifier_free_guidance (self ):
447+ return self ._guidance_scale > 1
448+
442449 @property
443450 def num_timesteps (self ):
444451 return self ._num_timesteps
@@ -468,15 +475,13 @@ def __call__(
468475 latents : Optional [torch .Tensor ] = None ,
469476 prompt_embeds : Optional [torch .Tensor ] = None ,
470477 negative_prompt_embeds : Optional [torch .Tensor ] = None ,
471- prompt_attention_mask : Optional [torch .Tensor ] = None ,
472478 output_type : Optional [str ] = "np" ,
473479 return_dict : bool = True ,
474480 callback_on_step_end : Optional [
475481 Union [Callable [[int , int , Dict ], None ], PipelineCallback , MultiPipelineCallbacks ]
476482 ] = None ,
477483 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
478484 max_sequence_length : int = 512 ,
479- autocast_dtype : torch .dtype = torch .bfloat16 ,
480485 ):
481486 r"""
482487 The call function to the pipeline for generation.
@@ -571,20 +576,22 @@ def __call__(
571576 prompt_embeds , negative_prompt_embeds = self .encode_prompt (
572577 prompt = prompt ,
573578 negative_prompt = negative_prompt ,
579+ do_classifier_free_guidance = self .do_classifier_free_guidance ,
574580 num_videos_per_prompt = num_videos_per_prompt ,
575581 prompt_embeds = prompt_embeds ,
576582 negative_prompt_embeds = negative_prompt_embeds ,
577583 max_sequence_length = max_sequence_length ,
578584 device = device ,
579- dtype = autocast_dtype ,
580585 )
581- # encode image embedding
586+
587+ # Encode image embedding
582588 image_embeds = self .encode_image (image )
583589 image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
584590
585- prompt_embeds = prompt_embeds .to (autocast_dtype )
586- negative_prompt_embeds = negative_prompt_embeds .to (autocast_dtype )
587- image_embeds = image_embeds .to (autocast_dtype )
591+ transformer_dtype = self .transformer .dtype
592+ prompt_embeds = prompt_embeds .to (transformer_dtype )
593+ if negative_prompt_embeds is not None :
594+ negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
595+ image_embeds = image_embeds .to (transformer_dtype )
588595
589596 # 4. Prepare timesteps
590597 self .scheduler .flow_shift = flow_shift
@@ -596,6 +603,7 @@ def __call__(
596603 height , width = image .shape [- 2 :]
597604 else :
598605 width , height = image .size
606+
599607 # 5. Prepare latent variables
600608 num_channels_latents = self .vae .config .z_dim
601609 num_latent_frames = (num_frames - 1 ) // self .vae_scale_factor_temporal + 1
@@ -618,37 +626,32 @@ def __call__(
618626 num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
619627 self ._num_timesteps = len (timesteps )
620628
621- with (
622- self .progress_bar (total = num_inference_steps ) as progress_bar ,
623- amp .autocast ('cuda' , dtype = autocast_dtype , cache_enabled = False )
624- ):
629+ with self .progress_bar (total = num_inference_steps ) as progress_bar :
625630 for i , t in enumerate (timesteps ):
626631 if self .interrupt :
627632 continue
628633
629634 self ._current_timestep = t
630- latent_model_input = latents
631- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
632- timestep = t .expand (latents .shape [0 ])
635+ latent_model_input = torch .cat ([latents , condition ], dim = 1 ).to (transformer_dtype )
636+ timestep = t .expand (latents .shape [0 ]).to (transformer_dtype )
633637
634- noise_pred = self .transformer (
635- hidden_states = torch . concat ([ latent_model_input , condition ], dim = 1 ) ,
638+ noise_cond = self .transformer (
639+ hidden_states = latent_model_input ,
636640 timestep = timestep ,
637641 encoder_hidden_states = prompt_embeds ,
638642 encoder_hidden_states_image = image_embeds ,
639643 return_dict = False ,
640644 )[0 ]
641645
642- noise_pred_negative = self .transformer (
643- hidden_states = torch .concat ([latent_model_input , condition ], dim = 1 ),
644- timestep = timestep ,
645- encoder_hidden_states = negative_prompt_embeds ,
646- encoder_hidden_states_image = image_embeds ,
647- return_dict = False ,
648- )[0 ]
649-
650- noise_pred = noise_pred_negative + guidance_scale * (
651- noise_pred - noise_pred_negative )
646+ if self .do_classifier_free_guidance :
647+ noise_uncond = self .transformer (
648+ hidden_states = latent_model_input ,
649+ timestep = timestep ,
650+ encoder_hidden_states = negative_prompt_embeds ,
651+ encoder_hidden_states_image = image_embeds ,
652+ return_dict = False ,
653+ )[0 ]
654+ noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond )
655+ else :
656+ noise_pred = noise_cond
652655
653656 # compute the previous noisy sample x_t -> x_t-1
654657 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
0 commit comments