@@ -120,8 +120,10 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}

         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )

         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -498,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -506,6 +510,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )

         if not self.config.use_rotary_positional_embeddings:
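For reference, here is a minimal, self-contained sketch of the pattern this change introduces: the block accepts an optional `attention_kwargs` dict, defaults it to `{}`, and unpacks it into the attention call. The `ToyAttention`/`ToyBlock` classes and the `scale` keyword below are illustrative assumptions, not the modules touched by this diff.

```python
# Minimal sketch (not library code) of the attention_kwargs plumbing shown above.
from typing import Any, Dict, Optional

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Stand-in attention module that happens to accept an extra `scale` kwarg."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return hidden_states * scale


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn1 = ToyAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        attention_kwargs = attention_kwargs or {}  # same default as in the diff
        # Any extra kwargs are forwarded untouched to the attention module.
        return hidden_states + self.attn1(hidden_states, **attention_kwargs)


block = ToyBlock()
x = torch.randn(1, 4, 8)
y_default = block(x)                                  # no extra kwargs
y_scaled = block(x, attention_kwargs={"scale": 0.5})  # kwargs reach attn1
```

Note that in the gradient-checkpointing branch the new argument is passed positionally, because `custom_forward(*inputs)` only forwards positional arguments to the block, while the non-checkpointed branch passes it by keyword.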