Commit 42e118d

Fix attention_mask and attention_bias shape descriptions in docstring
1 parent 502a1a4 commit 42e118d

File tree

1 file changed (+3, -3)

flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 3 additions & 3 deletions
@@ -30,10 +30,10 @@ def flash_dynamic_mask_attention_forward(
         query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
         key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
         value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape
-            (batch_size, {num_heads|num_kv_heads|1}, {query_len|0}, key_len).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
         scaling (Optional[float]): The scaling factor for the attention scores.
         window_size (Optional[int]): The size of the window to keep.
         softcap (Optional[float]): The softcap value for the attention scores.
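
To make the corrected shape convention concrete, the sketch below constructs tensors with the shapes the updated docstring describes and checks that the singleton dimensions broadcast against the full attention-score shape. It is a minimal illustration only: the dimension sizes, variable names, and the broadcast check are assumptions for this example and are not taken from flash_dmattn itself (in particular, passing num_kv_heads in the head dimension relies on the kernel's own GQA handling rather than plain PyTorch broadcasting).

import torch

# Illustrative sizes (assumptions for this sketch, not from the library).
batch_size, num_heads, num_kv_heads = 2, 8, 2
query_len, key_len, head_dim = 16, 16, 64

# Tensors with the shapes documented for query/key/value.
query = torch.randn(batch_size, num_heads, query_len, head_dim)
key = torch.randn(batch_size, num_kv_heads, key_len, head_dim)
value = torch.randn(batch_size, num_kv_heads, key_len, head_dim)

# attention_mask: either a 2D padding mask of shape (batch_size, seq_len) ...
mask_2d = torch.ones(batch_size, key_len, dtype=torch.bool)

# ... or a 4D boolean mask where each of the batch, head, query, and key
# dimensions may be broadcastable (size 1), per the corrected docstring.
mask_4d = torch.ones(1, num_heads, 1, key_len, dtype=torch.bool)

# attention_bias: a float tensor with the same 4D convention, e.g. one bias
# value per key position shared across heads and query positions.
bias_4d = torch.zeros(batch_size, 1, 1, key_len)

# The singleton dimensions expand against the attention-score shape
# (batch_size, num_heads, query_len, key_len).
scores_shape = (batch_size, num_heads, query_len, key_len)
print(torch.broadcast_shapes(mask_4d.shape, scores_shape))  # torch.Size([2, 8, 16, 16])
print(torch.broadcast_shapes(bias_4d.shape, scores_shape))  # torch.Size([2, 8, 16, 16])

The 2D form corresponds to the usual per-token padding mask, while the 4D forms let callers supply head- or position-specific masks and biases without materializing a full score-shaped tensor.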

0 commit comments
