@@ -738,7 +738,7 @@ def __call__(
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        joint_attention_kwargs: Dict[str, Any] = {},
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
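The signature change above replaces a mutable default argument. In Python, a default like `{}` is evaluated once at function definition time, so every call that relies on the default shares a single dict object, and in-place mutations leak across calls. A minimal sketch of the pitfall, using a hypothetical `step` function rather than the pipeline itself:

    def step(extra_kwargs={}):
        # the default dict is created once and shared by every call
        extra_kwargs.setdefault("scale", 1.0)
        return extra_kwargs

    first = step()
    second = step()
    assert first is second  # same object: a mutation in one call is visible in the next

With `Optional[Dict[str, Any]] = None`, each call can instead create a fresh dict (or leave the value `None`), which is what the merging logic below relies on.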
@@ -980,22 +980,22 @@ def __call__(
                     )
 
                     image_prompt_embeds = dict(
-                        ip_hidden_states=ip_hidden_states,
-                        temb=temb
+                        ip_hidden_states=ip_hidden_states,
+                        temb=temb
                     )
-                else:
-                    image_prompt_embeds = {}
+
+                    if self.joint_attention_kwargs is None:
+                        self._joint_attention_kwargs = image_prompt_embeds
+                    else:
+                        self._joint_attention_kwargs.update(**image_prompt_embeds)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
                     return_dict=False,
-                    joint_attention_kwargs={
-                        **image_prompt_embeds,
-                        **self.joint_attention_kwargs,
-                    }
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                 )[0]
 
                 # perform guidance
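The hunk above is the core of the fix: the per-step IP-adapter tensors are folded into the pipeline's stored kwargs once, and every transformer call then reads `self.joint_attention_kwargs`. A minimal sketch of that merge, assuming (as in other diffusers pipelines) that `joint_attention_kwargs` is a property over a private `_joint_attention_kwargs` attribute; the class and method names here are illustrative, not the library's:

    from typing import Any, Dict, Optional

    class PipelineSketch:
        # stand-in for the pipeline; only models the kwargs-merging behavior
        def __init__(self, joint_attention_kwargs: Optional[Dict[str, Any]] = None):
            self._joint_attention_kwargs = joint_attention_kwargs

        @property
        def joint_attention_kwargs(self) -> Optional[Dict[str, Any]]:
            return self._joint_attention_kwargs

        def merge_image_prompt_embeds(self, image_prompt_embeds: Dict[str, Any]) -> None:
            # same branching as the patch: adopt the embeds dict when the caller
            # passed nothing, otherwise fold the embeds into the caller's kwargs
            if self.joint_attention_kwargs is None:
                self._joint_attention_kwargs = image_prompt_embeds
            else:
                self._joint_attention_kwargs.update(**image_prompt_embeds)

    pipe = PipelineSketch(joint_attention_kwargs={"scale": 0.6})
    pipe.merge_image_prompt_embeds({"ip_hidden_states": 1, "temb": 2})
    print(pipe.joint_attention_kwargs)  # {'scale': 0.6, 'ip_hidden_states': 1, 'temb': 2}

One design nuance the sketch makes visible: the old inline `{**image_prompt_embeds, **self.joint_attention_kwargs}` let caller-supplied keys win on a clash, while `update(**image_prompt_embeds)` lets the IP-adapter embeds win.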
@@ -1016,10 +1016,7 @@ def __call__(
                             timestep=timestep,
                             encoder_hidden_states=original_prompt_embeds,
                             pooled_projections=original_pooled_prompt_embeds,
-                            joint_attention_kwargs={
-                                **image_prompt_embeds,
-                                **self.joint_attention_kwargs,
-                            },
+                            joint_attention_kwargs=self.joint_attention_kwargs,
                             return_dict=False,
                             skip_layers=skip_guidance_layers,
                         )[0]
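For context, a hedged end-to-end usage sketch; the model id and prompt are placeholders, and only the `joint_attention_kwargs` argument relates to the patch:

    import torch
    from diffusers import StableDiffusion3Pipeline

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
    ).to("cuda")

    # joint_attention_kwargs may now be omitted (defaulting to None) without
    # risking state shared across calls through a single default dict
    image = pipe(
        prompt="a photo of an astronaut riding a horse",
        joint_attention_kwargs=None,
    ).images[0]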