@@ -906,30 +906,33 @@ def __call__(
906906 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
907907 timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
908908
909- noise_pred = self .transformer (
910- hidden_states = latents ,
911- timestep = timestep / 1000 ,
912- encoder_hidden_states = prompt_embeds ,
913- txt_ids = text_ids ,
914- img_ids = latent_image_ids ,
915- attention_mask = attention_mask ,
916- joint_attention_kwargs = self .joint_attention_kwargs ,
917- return_dict = False ,
918- )[0 ]
919-
920- if self .do_classifier_free_guidance :
921- if negative_image_embeds is not None :
922- self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
923- neg_noise_pred = self .transformer (
909+ with self .transformer .cache_context ("cond" ):
910+ noise_pred = self .transformer (
924911 hidden_states = latents ,
925912 timestep = timestep / 1000 ,
926- encoder_hidden_states = negative_prompt_embeds ,
927- txt_ids = negative_text_ids ,
913+ encoder_hidden_states = prompt_embeds ,
914+ txt_ids = text_ids ,
928915 img_ids = latent_image_ids ,
929- attention_mask = negative_attention_mask ,
916+ attention_mask = attention_mask ,
930917 joint_attention_kwargs = self .joint_attention_kwargs ,
931918 return_dict = False ,
932919 )[0 ]
920+
921+ if self .do_classifier_free_guidance :
922+ if negative_image_embeds is not None :
923+ self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
924+
925+ with self .transformer .cache_context ("uncond" ):
926+ neg_noise_pred = self .transformer (
927+ hidden_states = latents ,
928+ timestep = timestep / 1000 ,
929+ encoder_hidden_states = negative_prompt_embeds ,
930+ txt_ids = negative_text_ids ,
931+ img_ids = latent_image_ids ,
932+ attention_mask = negative_attention_mask ,
933+ joint_attention_kwargs = self .joint_attention_kwargs ,
934+ return_dict = False ,
935+ )[0 ]
933936 noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred )
934937
935938 # compute the previous noisy sample x_t -> x_t-1
0 commit comments