
Commit 4383175 (parent 5aed1d3)

Reverted joint_attention_kwargs default for consistency

1 file changed: 10 additions, 13 deletions

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py (+10, -13)
```diff
@@ -738,7 +738,7 @@ def __call__(
         ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        joint_attention_kwargs: Dict[str, Any] = {},
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
```
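Beyond consistency with the other diffusers pipelines, the revert avoids a classic Python pitfall: a mutable default like `Dict[str, Any] = {}` is created once at function-definition time and shared across calls, so any in-place mutation leaks into later calls that rely on the default. A minimal standalone sketch of the difference (function names here are illustrative, not pipeline code):

```python
from typing import Any, Dict, Optional

def risky(kwargs: Dict[str, Any] = {}) -> Dict[str, Any]:
    # The default {} is created once and shared by every call that omits kwargs.
    kwargs["calls"] = kwargs.get("calls", 0) + 1
    return kwargs

def safe(kwargs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # None default: allocate a fresh dict on every call instead.
    if kwargs is None:
        kwargs = {}
    kwargs["calls"] = kwargs.get("calls", 0) + 1
    return kwargs

print(risky())  # {'calls': 1}
print(risky())  # {'calls': 2}  <- state leaked from the previous call
print(safe())   # {'calls': 1}
print(safe())   # {'calls': 1}  <- each call gets a fresh dict
```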
```diff
@@ -980,22 +980,22 @@ def __call__(
                     )
 
                     image_prompt_embeds = dict(
-                        ip_hidden_states = ip_hidden_states,
-                        temb = temb
+                        ip_hidden_states=ip_hidden_states,
+                        temb=temb
                     )
-                else:
-                    image_prompt_embeds = {}
+
+                    if self.joint_attention_kwargs is None:
+                        self._joint_attention_kwargs = image_prompt_embeds
+                    else:
+                        self._joint_attention_kwargs.update(**image_prompt_embeds)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
                     return_dict=False,
-                    joint_attention_kwargs={
-                        **image_prompt_embeds,
-                        **self.joint_attention_kwargs,
-                    }
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                 )[0]
 
                 # perform guidance
```
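The new branch folds the per-step IP-Adapter tensors into the pipeline-level kwargs once, so the transformer calls below can reuse `self.joint_attention_kwargs` directly instead of re-merging dicts at each call site. A standalone sketch of the merge semantics (the class and placeholder values are illustrative, not diffusers code):

```python
from typing import Any, Dict, Optional

class MergeSketch:
    """Mimics the pipeline's joint_attention_kwargs bookkeeping."""

    def __init__(self, joint_attention_kwargs: Optional[Dict[str, Any]] = None):
        self._joint_attention_kwargs = joint_attention_kwargs

    @property
    def joint_attention_kwargs(self) -> Optional[Dict[str, Any]]:
        return self._joint_attention_kwargs

    def merge_ip_adapter(self, image_prompt_embeds: Dict[str, Any]) -> None:
        # Same logic as the added lines: adopt the dict when the user passed
        # nothing, otherwise extend the user's kwargs in place.
        if self.joint_attention_kwargs is None:
            self._joint_attention_kwargs = image_prompt_embeds
        else:
            self._joint_attention_kwargs.update(**image_prompt_embeds)

pipe = MergeSketch(joint_attention_kwargs={"scale": 0.7})  # hypothetical user kwarg
pipe.merge_ip_adapter({"ip_hidden_states": "step-0 states", "temb": "step-0 temb"})
print(pipe.joint_attention_kwargs)
# {'scale': 0.7, 'ip_hidden_states': 'step-0 states', 'temb': 'step-0 temb'}
```

Because the merge runs inside the denoising loop, each step's `ip_hidden_states` and `temb` simply overwrite the previous step's entries under the same keys.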
```diff
@@ -1016,10 +1016,7 @@ def __call__(
                     timestep=timestep,
                     encoder_hidden_states=original_prompt_embeds,
                     pooled_projections=original_pooled_prompt_embeds,
-                    joint_attention_kwargs={
-                        **image_prompt_embeds,
-                        **self.joint_attention_kwargs,
-                    },
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                     skip_layers=skip_guidance_layers,
                 )[0]
```
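For context, a hypothetical call after this change: with the reverted default, callers may omit `joint_attention_kwargs` entirely, and any dict they do pass is extended in place with the IP-Adapter entries (the prompt and output path are placeholders):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# joint_attention_kwargs now defaults to None, so it can simply be omitted.
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("astronaut.png")
```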
