
Commit 268e657

Fixes attention bias handling in _flash_dynamic_mask_attention_forward for 4D attention masks
1 parent: da1bf26

File tree

1 file changed: +3 −2 lines


flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -85,7 +85,9 @@ def _flash_dynamic_mask_attention_forward(
         query_states, key_states, value_states, attention_bias, target_dtype
     )

-    if attention_mask is not None:
+    if attention_mask is not None and attention_mask.dim() == 4:
+        if attention_bias.dim() == 3:
+            attention_bias = attention_bias.unsqueeze(-2)
         attention_bias = attention_bias.masked_fill(
             ~attention_mask,
             min_dtype
@@ -98,7 +100,6 @@ def _flash_dynamic_mask_attention_forward(
         )
         attention_mask = torch.zeros_like(attention_bias, dtype=torch.bool, device=attention_bias.device)
         attention_mask = attention_mask.scatter(-1, topk_indices, topk_values != min_dtype)
-        attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype)

     out = flash_fn(
         query_states, key_states, value_states, attn_mask=attention_mask, attn_bias=attention_bias, scale=softmax_scale, is_causal=is_causal
```
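For context, a minimal standalone sketch (not part of the commit) of the shape handling the new branch addresses: when the boolean attention_mask is 4D, e.g. (batch, num_heads, query_len, key_len), a 3D attention_bias of shape (batch, num_heads, key_len) first gets a singleton query dimension via unsqueeze(-2) so that masked_fill can broadcast it against the mask. The concrete shapes and tensor contents below are assumptions for illustration only.

```python
import torch

# Hypothetical shapes for illustration; the real function derives them from
# query_states / key_states inside _flash_dynamic_mask_attention_forward.
batch, num_heads, query_len, key_len = 2, 4, 8, 16
min_dtype = torch.finfo(torch.float16).min

attention_bias = torch.zeros(batch, num_heads, key_len, dtype=torch.float16)   # 3D bias
attention_mask = torch.rand(batch, num_heads, query_len, key_len) > 0.5        # 4D bool mask

if attention_mask is not None and attention_mask.dim() == 4:
    if attention_bias.dim() == 3:
        # (batch, num_heads, key_len) -> (batch, num_heads, 1, key_len) so it
        # broadcasts against the (batch, num_heads, query_len, key_len) mask.
        attention_bias = attention_bias.unsqueeze(-2)
    # Out-of-place masked_fill broadcasts bias and mask to a common shape,
    # filling masked-out positions with the dtype's minimum value.
    attention_bias = attention_bias.masked_fill(~attention_mask, min_dtype)

print(attention_bias.shape)  # torch.Size([2, 4, 8, 16]) -- bias is now 4D
```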
