@@ -560,16 +560,15 @@ def __call__(
                     f"must match encoder_hidden_states sequence length ({text_seq_len})."
                 )

-            # Only create mask if there's actual padding (i.e., some False/0 values)
-            # When all values are True/1.0, passing attention_mask=None is more efficient for SDPA
+            # Create joint attention mask
+            # torch.compile compatible: always create mask when encoder_hidden_states_mask is provided
             text_attention_mask = encoder_hidden_states_mask.bool()
-            if not text_attention_mask.all():
-                image_attention_mask = torch.ones(
-                    (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
-                )
-                # Create 2D joint mask [batch_size, text_seq_len + image_seq_len]
-                # The attention dispatch will normalize this and extract sequence lengths
-                attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            image_attention_mask = torch.ones(
+                (batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device
+            )
+            # Create 2D joint mask [batch_size, text_seq_len + image_seq_len]
+            # The attention dispatch will normalize this and extract sequence lengths
+            attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)

         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
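The change removes the data-dependent branch: the 2D joint mask of shape [batch_size, text_seq_len + image_seq_len] is now built whenever encoder_hidden_states_mask is provided, so the traced graph keeps a fixed structure under torch.compile. Below is a minimal standalone sketch of that mask construction; the tensor sizes are made up, and plain SDPA stands in for dispatch_attention_fn purely for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes; in the model these come from the text encoder and image latents.
batch_size, text_seq_len, image_seq_len, dim = 2, 7, 16, 8

# encoder_hidden_states_mask: 1.0 for real text tokens, 0.0 for padding.
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len)
encoder_hidden_states_mask[0, 5:] = 0.0  # pretend sample 0 has two padded tokens

# Always build the joint mask (no data-dependent branch), mirroring the diff above.
text_attention_mask = encoder_hidden_states_mask.bool()
image_attention_mask = torch.ones(batch_size, image_seq_len, dtype=torch.bool)
attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
# -> shape [batch_size, text_seq_len + image_seq_len]

# Stand-in for dispatch_attention_fn: SDPA takes a boolean mask broadcastable
# over (..., query_len, key_len), so expand the 2D key mask accordingly.
q = k = v = torch.randn(batch_size, 1, text_seq_len + image_seq_len, dim)
sdpa_mask = attention_mask[:, None, None, :]  # masks padded keys for every query
out = F.scaled_dot_product_attention(q, k, v, attn_mask=sdpa_mask)
print(out.shape)  # torch.Size([2, 1, 23, 8])
```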