
Commit 3d91162

Merge pull request #192 from SmallDoges/fix-189
Fix attention_mask and attention_bias shape descriptions and remove redundant checks
2 parents 4dd087d + 623b75d commit 3d91162

File tree

2 files changed: +3 additions, -8 deletions

flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 3 additions & 3 deletions

@@ -30,10 +30,10 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
-            (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         scaling (Optional[float]): The scaling factor for the attention scores.
         window_size (Optional[int]): The size of the window to keep.
         softcap (Optional[float]): The softcap value for the attention scores.
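
For context, here is a minimal sketch of mask and bias tensors that match the corrected shape description, assuming the kernel broadcasts singleton dimensions as the docstring now implies; the sizes and tensor names are illustrative and not taken from the repository.

```python
import torch

# Illustrative sizes, not from the repository.
batch_size, num_heads, num_kv_heads = 2, 8, 2
query_len, key_len, head_dim = 128, 256, 64

query = torch.randn(batch_size, num_heads, query_len, head_dim)
key = torch.randn(batch_size, num_kv_heads, key_len, head_dim)
value = torch.randn(batch_size, num_kv_heads, key_len, head_dim)

# Boolean mask with singleton batch/head/query dims, per the updated docstring:
# ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1})
attention_mask = torch.ones(1, 1, 1, key_len, dtype=torch.bool)

# Float bias broadcast over batch and query positions, same shape convention.
attention_bias = torch.zeros(1, num_kv_heads, 1, key_len, dtype=query.dtype)
```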

flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py

Lines changed: 0 additions & 5 deletions

@@ -626,11 +626,6 @@ def _flash_dynamic_mask_attention_forward(
     ):
         min_dtype = torch.finfo(query_states.dtype).min
         if attention_mask is not None:
-            if attention_mask.dim() == 4 and attention_bias.dim() == 3:
-                attention_bias = attention_bias.unsqueeze(-2).expand(-1, -1, query_length, -1)
-            if attention_mask.dim() == 3 and attention_bias.dim() == 4:
-                attention_mask = attention_mask.unsqueeze(-2).expand(-1, -1, query_length, -1)
-
             topk_values, topk_indices = torch.topk(
                 attention_bias.masked_fill(~attention_mask, min_dtype).detach(),
                 window_size, dim=-1, largest=True, sorted=False
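
The deleted branches reconciled 3D and 4D shapes before the top-k step; with mask and bias now documented as broadcast-compatible 4D tensors, masked_fill handles the broadcasting directly, so the checks were redundant. Below is a standalone sketch of that selection step under the assumption that both tensors are already broadcastable; select_topk_window is a hypothetical helper name, not part of the library.

```python
import torch

def select_topk_window(attention_mask, attention_bias, window_size):
    # Hypothetical helper mirroring the hunk above; assumes attention_mask and
    # attention_bias are broadcast-compatible 4D tensors.
    min_dtype = torch.finfo(attention_bias.dtype).min
    # Masked-out positions get the most negative representable value so they
    # can never be selected by top-k.
    scores = attention_bias.masked_fill(~attention_mask, min_dtype).detach()
    # Keep the window_size highest-bias key positions along the key dimension.
    topk_values, topk_indices = torch.topk(
        scores, window_size, dim=-1, largest=True, sorted=False
    )
    return topk_values, topk_indices
```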
