@@ -750,27 +750,25 @@ def __call__(
750750 latent_model_input = torch .cat ([latents , condition ], dim = 1 ).to (transformer_dtype )
751751 timestep = t .expand (latents .shape [0 ])
752752
753- with current_model .cache_context ("cond" ):
754- noise_pred = current_model (
753+
754+ noise_pred = current_model (
755+ hidden_states = latent_model_input ,
756+ timestep = timestep ,
757+ encoder_hidden_states = prompt_embeds ,
758+ encoder_hidden_states_image = image_embeds ,
759+ attention_kwargs = attention_kwargs ,
760+ return_dict = False ,)[0 ]
761+
762+ if self .do_classifier_free_guidance :
763+ noise_uncond = current_model (
755764 hidden_states = latent_model_input ,
756765 timestep = timestep ,
757- encoder_hidden_states = prompt_embeds ,
766+ encoder_hidden_states = negative_prompt_embeds ,
758767 encoder_hidden_states_image = image_embeds ,
759768 attention_kwargs = attention_kwargs ,
760769 return_dict = False ,
761770 )[0 ]
762-
763- if self .do_classifier_free_guidance :
764- with current_model .cache_context ("uncond" ):
765- noise_uncond = current_model (
766- hidden_states = latent_model_input ,
767- timestep = timestep ,
768- encoder_hidden_states = negative_prompt_embeds ,
769- encoder_hidden_states_image = image_embeds ,
770- attention_kwargs = attention_kwargs ,
771- return_dict = False ,
772- )[0 ]
773- noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond )
771+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond )
774772
775773 # compute the previous noisy sample x_t -> x_t-1
776774 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
0 commit comments