@@ -170,8 +170,14 @@ def forward(
         joint_query = apply_rope1(joint_query, image_rotary_emb)
         joint_key = apply_rope1(joint_key, image_rotary_emb)

+        if encoder_hidden_states_mask is not None:
+            attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
+            attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
+        else:
+            attn_mask = None
+
         joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
-                                                         attention_mask, transformer_options=transformer_options,
+                                                         attn_mask, transformer_options=transformer_options,
                                                          skip_reshape=True)

         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
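A minimal, standalone sketch (not part of the commit) of what the added mask construction does, assuming the encoder mask arriving here is already the additive form produced in the second hunk below; the sizes and tensors are stand-ins:

    import torch

    # Stand-in sizes; in the model these come from the text and image token sequences.
    batch_size, seq_txt, seq_img = 2, 5, 16
    dtype = torch.float16

    # Additive encoder mask: 0 for real prompt tokens, -finfo.max for padding
    # (here the last two text tokens are pretend padding).
    encoder_hidden_states_mask = torch.zeros(batch_size, seq_txt, dtype=dtype)
    encoder_hidden_states_mask[:, 3:] = -torch.finfo(dtype).max

    # Joint mask as built in the hunk above: zero bias for every image token,
    # the encoder mask for the text tokens at the front of the joint sequence.
    attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=dtype)
    attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask

    # Shape (batch, 1, seq_txt + seq_img) broadcasts over heads and query positions,
    # so padded text keys end up with ~zero attention weight after softmax.
    print(attn_mask.shape)  # torch.Size([2, 1, 21])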
@@ -430,6 +436,9 @@ def _forward(
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask

+        if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
+            encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
+
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]