@@ -150,7 +150,7 @@ def compute_text_seq_len_from_mask(
     """
     batch_size, text_seq_len = encoder_hidden_states.shape[:2]
     if encoder_hidden_states_mask is None:
-        return text_seq_len, None, None
+        return text_seq_len, [text_seq_len] * batch_size, None

     if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
         raise ValueError(
@@ -165,7 +165,7 @@ def compute_text_seq_len_from_mask(
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
     per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
-    return text_seq_len, per_sample_len, encoder_hidden_states_mask
+    return text_seq_len, per_sample_len.tolist(), encoder_hidden_states_mask


class QwenTimestepProjEmbeddings(nn.Module):
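For context, a minimal standalone sketch of the new return contract (toy tensors only; `position_ids` is built by the surrounding helper outside this hunk, so it is reconstructed by hand here as a plain `arange` broadcast over the batch): a padded sample now reports its last valid position plus one, and the second return value is a plain Python list rather than `None` or a tensor.

```python
# Toy illustration of the per-sample length logic in the hunk above.
# Assumes position_ids is arange(text_seq_len) broadcast over the batch.
import torch

encoder_hidden_states = torch.randn(2, 5, 8)  # [batch, text_seq_len, dim]
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 0, 0],   # sample 0: 3 valid text tokens
     [1, 1, 1, 1, 1]],  # sample 1: no padding
    dtype=torch.bool,
)

batch_size, text_seq_len = encoder_hidden_states.shape[:2]
position_ids = torch.arange(text_seq_len).unsqueeze(0).expand(batch_size, -1)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))

print(text_seq_len, per_sample_len.tolist())  # 5 [3, 5]
```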
@@ -492,6 +492,7 @@ def __call__(
         encoder_hidden_states_mask: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        encoder_hidden_states_len: Optional[List[int]] = None,
     ) -> torch.FloatTensor:
         if encoder_hidden_states is None:
             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
@@ -537,16 +538,17 @@ def __call__(

         # Concatenate for joint attention
-        # Order: [text, image]
-        joint_query = torch.cat([txt_query, img_query], dim=1)
-        joint_key = torch.cat([txt_key, img_key], dim=1)
-        joint_value = torch.cat([txt_value, img_value], dim=1)
+        # Order: [image, text]
+        joint_query = torch.cat([img_query, txt_query], dim=1)
+        joint_key = torch.cat([img_key, txt_key], dim=1)
+        joint_value = torch.cat([img_value, txt_value], dim=1)

         # If an encoder_hidden_states_mask is provided, create a joint attention mask.
         # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding.
         # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend).
         # Only create the mask if there's actual padding, otherwise keep attention_mask=None for better SDPA performance.
+        batch_size, image_seq_len = hidden_states.shape[:2]
+        attention_kwargs = {}
         if encoder_hidden_states_mask is not None and attention_mask is None:
-            batch_size, image_seq_len = hidden_states.shape[:2]
             text_seq_len = encoder_hidden_states.shape[1]

            if encoder_hidden_states_mask.shape[0] != batch_size:
@@ -568,7 +570,8 @@ def __call__(
             )
             # Create 2D joint mask [batch_size, text_seq_len + image_seq_len]
             # The attention dispatch will normalize this and extract sequence lengths
-            attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
+            attention_kwargs['seq_len'] = [text_sample_len + image_seq_len for text_sample_len in encoder_hidden_states_len]

         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
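A small sketch (standalone toy values, not the processor itself) of what the reordered joint mask and the new `seq_len` entry look like: image tokens come first and are always valid, so each sample's total valid length is its text length from `compute_text_seq_len_from_mask` plus `image_seq_len`.

```python
# Toy illustration of the joint mask layout and the seq_len list built above.
import torch

batch_size, image_seq_len = 2, 4
encoder_hidden_states_len = [3, 5]  # per-sample valid text lengths (from the helper above)
text_seq_len = max(encoder_hidden_states_len)

image_attention_mask = torch.ones(batch_size, image_seq_len, dtype=torch.bool)  # image tokens are never padded
text_attention_mask = torch.tensor(
    [[True, True, True, False, False],
     [True, True, True, True, True]]
)

# Order matches the reordered streams: [image, text]
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
seq_len = [text_sample_len + image_seq_len for text_sample_len in encoder_hidden_states_len]

print(attention_mask.shape)  # torch.Size([2, 9])
print(seq_len)               # [7, 9]
```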
@@ -580,15 +583,16 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
+            attention_kwargs=attention_kwargs,
         )

         # Reshape back
         joint_hidden_states = joint_hidden_states.flatten(2, 3)
         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

         # Split attention outputs back
-        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
-        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+        img_attn_output = joint_hidden_states[:, :image_seq_len, :]  # Image part
+        txt_attn_output = joint_hidden_states[:, image_seq_len:, :]  # Text part

         # Apply output projections
         img_attn_output = attn.to_out[0](img_attn_output)
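And the matching split, shown on dummy tensors: with the joint sequence ordered `[image, text]`, slicing by `image_seq_len` (known from `hidden_states`) recovers the two streams without relying on the text length.

```python
# Toy check that slicing the joint output by image_seq_len recovers both streams.
import torch

batch_size, image_seq_len, text_seq_len, dim = 2, 4, 3, 8
joint_hidden_states = torch.randn(batch_size, image_seq_len + text_seq_len, dim)

img_attn_output = joint_hidden_states[:, :image_seq_len, :]  # first: image part
txt_attn_output = joint_hidden_states[:, image_seq_len:, :]  # rest: text part

assert img_attn_output.shape == (batch_size, image_seq_len, dim)
assert txt_attn_output.shape == (batch_size, text_seq_len, dim)
```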
@@ -694,6 +698,7 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        encoder_hidden_states_len: Optional[List[int]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -728,6 +733,7 @@ def forward(
             encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
             encoder_hidden_states_mask=encoder_hidden_states_mask,
             image_rotary_emb=image_rotary_emb,
+            encoder_hidden_states_len=encoder_hidden_states_len,
             **joint_attention_kwargs,
         )

@@ -947,7 +953,9 @@ def forward(
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

         # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
-        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+        if encoder_hidden_states_mask is not None and torch.all(encoder_hidden_states_mask):
+            encoder_hidden_states_mask = None
+        text_seq_len, text_seq_len_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
             encoder_hidden_states, encoder_hidden_states_mask
         )

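A quick sketch of the all-valid shortcut added here (toy mask only): when every text token is valid the mask is dropped, so the attention processor keeps `attention_mask=None` and stays on the fast mask-free SDPA path, while the helper still returns a per-sample length list.

```python
# Toy illustration of the all-valid shortcut above.
import torch

encoder_hidden_states_mask = torch.ones(2, 5, dtype=torch.bool)

if encoder_hidden_states_mask is not None and torch.all(encoder_hidden_states_mask):
    encoder_hidden_states_mask = None

print(encoder_hidden_states_mask)  # None -> downstream joint attention_mask also stays None
```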
@@ -971,6 +979,7 @@ def forward(
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
+                    text_seq_len_per_sample,
                    attention_kwargs,
                    modulate_index,
                )
@@ -982,6 +991,7 @@ def forward(
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
+                    encoder_hidden_states_len=text_seq_len_per_sample,
                    joint_attention_kwargs=attention_kwargs,
                    modulate_index=modulate_index,
                )