@@ -750,25 +750,27 @@ def __call__(
750750 latent_model_input = torch .cat ([latents , condition ], dim = 1 ).to (transformer_dtype )
751751 timestep = t .expand (latents .shape [0 ])
752752
753- noise_pred = current_model (
754- hidden_states = latent_model_input ,
755- timestep = timestep ,
756- encoder_hidden_states = prompt_embeds ,
757- encoder_hidden_states_image = image_embeds ,
758- attention_kwargs = attention_kwargs ,
759- return_dict = False ,
760- )[0 ]
761-
762- if self .do_classifier_free_guidance :
763- noise_uncond = current_model (
753+ with current_model .cache_context ("cond" ):
754+ noise_pred = current_model (
764755 hidden_states = latent_model_input ,
765756 timestep = timestep ,
766- encoder_hidden_states = negative_prompt_embeds ,
757+ encoder_hidden_states = prompt_embeds ,
767758 encoder_hidden_states_image = image_embeds ,
768759 attention_kwargs = attention_kwargs ,
769760 return_dict = False ,
770761 )[0 ]
771- noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond )
762+
763+ if self .do_classifier_free_guidance :
764+ with current_model .cache_context ("uncond" ):
765+ noise_uncond = current_model (
766+ hidden_states = latent_model_input ,
767+ timestep = timestep ,
768+ encoder_hidden_states = negative_prompt_embeds ,
769+ encoder_hidden_states_image = image_embeds ,
770+ attention_kwargs = attention_kwargs ,
771+ return_dict = False ,
772+ )[0 ]
773+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond )
772774
773775 # compute the previous noisy sample x_t -> x_t-1
774776 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
0 commit comments