@@ -83,14 +83,16 @@ def forward(
8383 hidden_states : torch .FloatTensor ,
8484 temb : torch .FloatTensor ,
8585 image_rotary_emb = None ,
86+ joint_attention_kwargs = None ,
8687 ):
8788 residual = hidden_states
8889 norm_hidden_states , gate = self .norm (hidden_states , emb = temb )
8990 mlp_hidden_states = self .act_mlp (self .proj_mlp (norm_hidden_states ))
90-
91+ joint_attention_kwargs = joint_attention_kwargs or {}
9192 attn_output = self .attn (
9293 hidden_states = norm_hidden_states ,
9394 image_rotary_emb = image_rotary_emb ,
95+ ** joint_attention_kwargs ,
9496 )
9597
9698 hidden_states = torch .cat ([attn_output , mlp_hidden_states ], dim = 2 )
@@ -161,18 +163,20 @@ def forward(
161163 encoder_hidden_states : torch .FloatTensor ,
162164 temb : torch .FloatTensor ,
163165 image_rotary_emb = None ,
166+ joint_attention_kwargs = None ,
164167 ):
165168 norm_hidden_states , gate_msa , shift_mlp , scale_mlp , gate_mlp = self .norm1 (hidden_states , emb = temb )
166169
167170 norm_encoder_hidden_states , c_gate_msa , c_shift_mlp , c_scale_mlp , c_gate_mlp = self .norm1_context (
168171 encoder_hidden_states , emb = temb
169172 )
170-
173+ joint_attention_kwargs = joint_attention_kwargs or {}
171174 # Attention.
172175 attn_output , context_attn_output = self .attn (
173176 hidden_states = norm_hidden_states ,
174177 encoder_hidden_states = norm_encoder_hidden_states ,
175178 image_rotary_emb = image_rotary_emb ,
179+ ** joint_attention_kwargs ,
176180 )
177181
178182 # Process attention outputs for the `hidden_states`.
@@ -497,6 +501,7 @@ def custom_forward(*inputs):
497501 encoder_hidden_states = encoder_hidden_states ,
498502 temb = temb ,
499503 image_rotary_emb = image_rotary_emb ,
504+ joint_attention_kwargs = joint_attention_kwargs ,
500505 )
501506
502507 # controlnet residual
@@ -533,6 +538,7 @@ def custom_forward(*inputs):
533538 hidden_states = hidden_states ,
534539 temb = temb ,
535540 image_rotary_emb = image_rotary_emb ,
541+ joint_attention_kwargs = joint_attention_kwargs ,
536542 )
537543
538544 # controlnet residual
0 commit comments