
Commit 7c15f9a

akaitsuki-ii authored and Glaceon-Hyy committed
fix pos_ids and causal_mask of qwen2_5_vl
1 parent 77cb3d5 commit 7c15f9a

2 files changed: 61 additions, 12 deletions

diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json

Lines changed: 2 additions & 1 deletion
@@ -15,5 +15,6 @@
     15,
     23,
     31
-  ]
+  ],
+  "attn_impl": "sdpa"
 }
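
For reference, a minimal sketch of how an "attn_impl" field like the one added above can be read with a safe default. This assumes a plain json.load based reader (not the project's actual config loader), and the non-"sdpa" backend names are hypothetical:

```python
import json

def load_vision_attn_impl(path: str, default: str = "sdpa") -> str:
    # Read the vision config and fall back to `default` when "attn_impl" is
    # absent, so configs written before this commit keep working.
    with open(path) as f:
        cfg = json.load(f)
    attn_impl = cfg.get("attn_impl", default)
    if attn_impl not in ("sdpa", "fa2", "eager"):  # hypothetical backend names
        raise ValueError(f"unsupported attn_impl: {attn_impl!r}")
    return attn_impl

# e.g. attn_impl = load_vision_attn_impl(
#     "diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_vision_config.json")
```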

diffsynth_engine/models/qwen_image/qwen2_5_vl.py

Lines changed: 59 additions & 11 deletions
@@ -550,9 +550,15 @@ def forward(
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[1]]

-        # TODO: use is_causal when attention mask is causal
         if self.attn_impl == "sdpa":
-            out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True)
+            is_causal = causal_mask is None and query_states.shape[1] > 1
+            out = attention_ops.sdpa_attn(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=causal_mask,
+                is_causal=is_causal,
+            )
         else:
             # TODO: attention_mask for flash attention 2
             out = attention_ops.attention(
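
A standalone illustration of the is_causal logic above, calling PyTorch's scaled_dot_product_attention directly. It assumes attention_ops.sdpa_attn is a thin wrapper over that function and uses the (batch, heads, seq, head_dim) layout for the example; the repo's wrapper may index the sequence dimension differently, which is why the patch checks query_states.shape[1]. Note that PyTorch rejects an explicit attn_mask combined with is_causal=True, which the new condition avoids by construction:

```python
import torch
import torch.nn.functional as F

def sdpa_like(query, key, value, causal_mask=None):
    # Pass the explicit 4D mask when one exists; otherwise let SDPA apply
    # causal masking itself, but only for multi-token (prefill) queries.
    # A single-token decode step attends to the whole KV cache, so neither
    # a mask nor is_causal is needed.
    is_causal = causal_mask is None and query.shape[2] > 1
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=causal_mask, is_causal=is_causal
    )

# Prefill: 8 query tokens, no explicit mask -> is_causal=True
q = k = v = torch.randn(1, 4, 8, 64)
out_prefill = sdpa_like(q, k, v)

# Decode: 1 query token against a 9-entry KV cache -> is_causal=False, no mask
q1 = torch.randn(1, 4, 1, 64)
k9 = v9 = torch.randn(1, 4, 9, 64)
out_decode = sdpa_like(q1, k9, v9)
```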
@@ -783,6 +789,9 @@ def _update_causal_mask(
             else past_seen_tokens + sequence_length + 1
         )

+        if self.config.attn_impl == "sdpa" and self._ignore_causal_mask(attention_mask, input_tensor, past_seen_tokens):
+            return None
+
         # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask=attention_mask,
@@ -802,6 +811,32 @@ def _update_causal_mask(

         return causal_mask

+    @staticmethod
+    def _ignore_causal_mask(
+        attention_mask: Optional[torch.Tensor],
+        inputs_embeds: torch.Tensor,
+        past_seen_tokens: int,
+    ):
+        """
+        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
+        ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
+
+        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
+        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
+        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
+        passed).
+        """
+
+        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        key_value_length = query_length + past_seen_tokens
+
+        ignore_causal_mask = False
+        if (attention_mask is None or (attention_mask.ndim == 2 and torch.all(attention_mask == 1))) and (
+            query_length == 1 or key_value_length == query_length
+        ):
+            ignore_causal_mask = True
+        return ignore_causal_mask
+
     @staticmethod
     def _prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask: torch.Tensor,
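
To make the dropped-mask conditions concrete, here is the same check copied into a standalone snippet with two illustrative cases (shapes and values are made up for the example):

```python
from typing import Optional
import torch

def ignore_causal_mask(
    attention_mask: Optional[torch.Tensor],
    inputs_embeds: torch.Tensor,
    past_seen_tokens: int,
) -> bool:
    # Same rule as the new _ignore_causal_mask: the 4D mask can be skipped when
    # nothing is actually masked (mask absent or all ones) and the shape lets
    # SDPA's is_causal flag express it (pure prefill or single-token decode).
    query_length = inputs_embeds.shape[1]
    key_value_length = query_length + past_seen_tokens
    no_padding = attention_mask is None or (
        attention_mask.ndim == 2 and torch.all(attention_mask == 1)
    )
    return bool(no_padding and (query_length == 1 or key_value_length == query_length))

embeds = torch.zeros(2, 8, 16)

# Prefill with an all-ones 2D mask: is_causal can replace the 4D mask -> True
print(ignore_causal_mask(torch.ones(2, 8), embeds, past_seen_tokens=0))

# Prefill with padding (a zero in the mask): the 4D mask is still needed -> False
mask = torch.ones(2, 8)
mask[0, 0] = 0
print(ignore_causal_mask(mask, embeds, past_seen_tokens=0))
```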
@@ -825,8 +860,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
             dtype (`torch.dtype`):
                 The dtype to use for the 4D attention mask.
-            past_key_values (`Cache`):
-                The cache class that is being used currently to generate
             cache_position (`torch.LongTensor`):
                 Indices depicting the position of the input sequence tokens in the sequence.
             batch_size (`int`):
@@ -1199,14 +1232,29 @@ def forward(
         if position_ids is None:
             assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
             # calculate RoPE index once per generation in the pre-fill stage only
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
+            is_prefill = (
+                (cache_position is None or cache_position[0] == 0)
+                or self.rope_deltas is None
+                or past_key_values is None
             )
-            self.rope_deltas = rope_deltas
+            if is_prefill:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids,
+                    image_grid_thw,
+                    video_grid_thw,
+                    second_per_grid_ts,
+                    attention_mask,
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
+                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

         hidden_states, present_key_values = self.model(
             input_ids=None,
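
A small numeric sketch of the decode-time branch above, assuming rope_deltas was cached during prefill; the concrete values are made up:

```python
import torch

# Assume prefill produced a per-sample rope delta (the offset between text
# positions and the multimodal RoPE index) and the KV cache holds 12 tokens.
rope_deltas = torch.tensor([[-3]])       # shape (batch, 1), hypothetical value
cache_position = torch.tensor([12])      # first position of the new tokens
batch_size, seq_length = 1, 1            # single decode step

delta = cache_position[0] + rope_deltas  # (1, 1): 12 + (-3) = 9
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)

position_ids = torch.arange(seq_length)                      # (seq_length,)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
position_ids = position_ids.add(delta)                       # shift by cached delta
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)   # one plane per M-RoPE axis

print(position_ids)  # tensor([[[9]], [[9]], [[9]]]), shape (3, batch, seq)
```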
