@@ -361,6 +361,7 @@ def __call__(
         # If an encoder_hidden_states_mask is provided, turn it into a broadcastable attention mask.
         # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding.
         # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend).
+        joint_seq_lens = None
         if encoder_hidden_states_mask is not None and attention_mask is None:
             batch_size, image_seq_len = hidden_states.shape[:2]
             text_seq_len = encoder_hidden_states.shape[1]
@@ -385,6 +386,12 @@ def __call__(
             joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
             attention_mask = joint_attention_mask[:, None, None, :]

+        # For varlen flash attention, we need the JOINT sequence lengths (text + image), not just text.
+        if text_seq_lens is not None:
+            # text_seq_lens contains per-sample text lengths;
+            # add the image sequence length to get the total joint sequence length.
+            joint_seq_lens = text_seq_lens + image_seq_len
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -395,7 +402,7 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
-            seq_lens=text_seq_lens,
+            seq_lens=joint_seq_lens,
         )

         # Reshape back
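
For reference, here is a minimal, self-contained sketch of the shapes involved. The sizes and tensors are illustrative assumptions (not the library's actual call site), and variable names only mirror the diff: with per-sample text padding, the joint attention mask spans text + image tokens, and the per-sample joint sequence lengths are the valid text lengths plus the fixed image length.

```python
import torch

# Illustrative sizes: 4 text tokens (with padding) and 6 image tokens per sample.
batch_size, text_seq_len, image_seq_len = 2, 4, 6

# 1.0 for valid text tokens, 0.0 for padding (sample 0 has 3 valid tokens, sample 1 has 4).
encoder_hidden_states_mask = torch.tensor(
    [[1.0, 1.0, 1.0, 0.0],
     [1.0, 1.0, 1.0, 1.0]]
)

# Boolean masks: True = attend, False = mask out. Image tokens are never padded.
text_attention_mask = encoder_hidden_states_mask.bool()
image_attention_mask = torch.ones(batch_size, image_seq_len, dtype=torch.bool)

# Joint mask over [text tokens, image tokens], broadcastable over heads and query positions.
joint_attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
attention_mask = joint_attention_mask[:, None, None, :]  # (B, 1, 1, text + image)

# Per-sample joint lengths for a varlen backend: valid text tokens plus all image tokens.
text_seq_lens = encoder_hidden_states_mask.sum(dim=1).to(torch.int32)
joint_seq_lens = text_seq_lens + image_seq_len

print(attention_mask.shape)  # torch.Size([2, 1, 1, 10])
print(joint_seq_lens)        # tensor([ 9, 10], dtype=torch.int32)
```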