@@ -694,9 +694,6 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
-
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -773,13 +770,11 @@ def __call__(
                 if image_embeds is not None:
                     self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
 
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
                 noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
+                    hidden_states=latents,
                     timestep=timestep / 1000,
                     encoder_hidden_states=prompt_embeds,
                     txt_ids=text_ids,
@@ -791,8 +786,16 @@ def __call__(
                 if self.do_classifier_free_guidance:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=negative_text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
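
For context (this note is not part of the commit): the removed code implemented classifier-free guidance by batching the conditional and unconditional inputs into one forward pass (`torch.cat([latents] * 2)` followed by `.chunk(2)`), while the new code runs a second, separate transformer pass on the negative prompt embeddings. Both variants combine the two predictions with the same guidance formula, sketched below with hypothetical tensor names:

```python
import torch


def cfg_combine(cond_pred: torch.Tensor, uncond_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: extrapolate from the unconditional
    # prediction toward the conditional one by guidance_scale.
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)


# Example with two fake noise predictions for a (1, 4, 64, 64) latent.
cond = torch.randn(1, 4, 64, 64)
uncond = torch.randn(1, 4, 64, 64)
guided = cfg_combine(cond, uncond, guidance_scale=7.0)
assert guided.shape == cond.shape
```

The trade-off is memory versus batching: two sequential passes halve the peak activation batch size at roughly the same total compute (though they can be slower than one batched pass on hardware with headroom), and they let the negative pass supply its own `txt_ids` and IP-adapter embeddings, as the final hunk above does.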