@@ -750,25 +750,27 @@ def __call__(
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = current_model(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = current_model(
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
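The separate `"cond"` and `"uncond"` cache contexts exist because cache-based speedups compare a step's intermediate activations against the previous step's for the *same* guidance branch; wrapping each forward pass in its own context keeps the conditional and unconditional passes from overwriting each other's cached state. A minimal usage sketch, assuming the `FirstBlockCacheConfig` / `enable_cache` API from diffusers' caching utilities and an illustrative Wan checkpoint id (neither is part of this diff):

```python
import torch
from diffusers import WanImageToVideoPipeline, FirstBlockCacheConfig

# Example checkpoint id; substitute whichever Wan image-to-video checkpoint you use.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Enable a cache on the transformer. Inside the denoising loop above,
# cache_context("cond") / cache_context("uncond") then keep the cached
# state of the two classifier-free-guidance branches separate.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```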