
Commit a2b5309 (parent 7d4cf23)

Normalize shape notation in create_mask docstrings (use key_len instead of {key_len|1})

1 file changed: +4, -4 lines

1 file changed

+4
-4
lines changed

flash_dmattn/utils/mask.py

Lines changed: 4 additions & 4 deletions
@@ -65,9 +65,9 @@ def create_mask(
 
     Args:
         attention_bias (torch.Tensor): The attention bias tensor of shape
-            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
-        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
-            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
+        attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape
+            (batch_size, seq_len) or ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
         batch_size (int): The batch size.
         query_len (int): The sequence length of the query.
         key_len (int): The sequence length of the key.
@@ -76,7 +76,7 @@ def create_mask(
 
     Returns:
         attention (Tensor): The attention mask tensor of shape
-            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, {key_len|1}).
+            ({batch_size|1}, {num_heads|num_kv_heads|1}, {query_len|1}, key_len).
     """
 
     # If attention_mask is of shape (batch_size, seq_len), reshape it to (batch_size, 1, 1, key_len)
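
In this notation, {name|1} marks a dimension that may either equal the named size or be 1 and broadcast, while a bare key_len (the point of this commit) must match exactly. The reshape mentioned in the final context line above can be sketched as follows; expand_2d_mask is a hypothetical helper written for illustration, assuming seq_len == key_len, and is not part of the flash_dmattn API.

import torch

# Minimal sketch of the reshape described by the comment above, assuming
# seq_len == key_len. expand_2d_mask is a hypothetical helper for
# illustration only; it is not part of the flash_dmattn API.
def expand_2d_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Expand a (batch_size, seq_len) bool mask to (batch_size, 1, 1, key_len)."""
    if attention_mask.dim() == 2:
        # Insert singleton num_heads and query_len dimensions so the mask
        # broadcasts against a ({batch_size|1}, {num_heads|num_kv_heads|1},
        # {query_len|1}, key_len) attention bias.
        attention_mask = attention_mask[:, None, None, :]
    return attention_mask

mask = torch.ones(2, 128, dtype=torch.bool)  # (batch_size, seq_len)
print(expand_2d_mask(mask).shape)            # torch.Size([2, 1, 1, 128])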
