Commit 27d574f

Handle None joint_attention_kwargs in JointTransformerBlock
1 parent 89c4e63 commit 27d574f

1 file changed: +5 −1 lines changed

src/diffusers/models/attention.py

Lines changed: 5 additions & 1 deletion

@@ -192,7 +192,7 @@ def forward(
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
-        joint_attention_kwargs: Dict[str, Any] = {},
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -208,6 +208,10 @@ def forward(
             encoder_hidden_states, emb=temb
         )

+        # Empty dict if None is passed
+        if joint_attention_kwargs is None:
+            joint_attention_kwargs = {}
+
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
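The change replaces a mutable default argument (`= {}`), which Python evaluates once at definition time and shares across calls, with `Optional[...] = None` that is coalesced to a fresh empty dict inside the function body, and it lets callers pass `None` explicitly. A minimal standalone sketch of the same pattern; the `run_attention` function and its return value here are illustrative placeholders, not diffusers APIs:

from typing import Any, Dict, Optional

def run_attention(hidden_states, joint_attention_kwargs: Optional[Dict[str, Any]] = None):
    # Coalesce None to a new empty dict on every call, instead of relying on a
    # shared mutable `= {}` default that is created once at definition time.
    if joint_attention_kwargs is None:
        joint_attention_kwargs = {}
    # The extra keyword arguments would then be forwarded to the attention call,
    # e.g. attn_output = self.attn(hidden_states=..., **joint_attention_kwargs)
    return joint_attention_kwargs

# Callers may omit the argument or pass None explicitly:
#   run_attention(x)                    -> {}
#   run_attention(x, None)              -> {}
#   run_attention(x, {"scale": 0.5})    -> {"scale": 0.5}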
