diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 02ed1f965abf..41c3f50c92d4 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ def forward(
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
 
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 79452bb85176..79c4069e9a37 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -411,11 +411,15 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
         # controlnet residual
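For reference, here is a minimal sketch of how the new parameter could be exercised on a standalone `JointTransformerBlock` after this change. The tensor sizes are toy values chosen only to keep the example small, and an empty `joint_attention_kwargs` dict is shown because the set of accepted keys depends on the attention processor installed on `block.attn`; passing an empty dict (or omitting the argument) preserves the previous behaviour.

```python
import torch

from diffusers.models.attention import JointTransformerBlock

# Toy dimensions, chosen only to keep the example small.
dim, num_heads, head_dim = 64, 4, 16
block = JointTransformerBlock(
    dim=dim,
    num_attention_heads=num_heads,
    attention_head_dim=head_dim,
    context_pre_only=False,
)

batch, image_tokens, text_tokens = 1, 16, 8
hidden_states = torch.randn(batch, image_tokens, dim)
encoder_hidden_states = torch.randn(batch, text_tokens, dim)
temb = torch.randn(batch, dim)

# New in this diff: extra kwargs are forwarded to the block's attention call.
# An empty dict expands to nothing, so the call matches the old behaviour.
encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
    joint_attention_kwargs={},
)
```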