Commit a61db8c

Refines documentation for attention_mask and attention_bias parameters in flash_dynamic_mask_attention_forward

1 parent ca66b6c

File tree

1 file changed (+2 -2 lines)

flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward(
     query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
     key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
     value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-    attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
-    attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
+    attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len).
+    attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, {num_heads|num_kv_heads|1}, query_len, key_len), if attention_mask is None, also supports (batch_size, {num_heads|num_kv_heads|1}, key_len).
     scaling (Optional[float]): The scaling factor for the attention scores.
     softcap (Optional[float]): The softcap value for the attention scores.
     **kwargs: Additional keyword arguments.
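For context, the revised docstring collapses the MQA/GQA wording into a single head-dimension placeholder, {num_heads|num_kv_heads|1}, and documents a key_len-only bias layout that applies only when attention_mask is None. The sketch below is illustrative only: check_mask_bias_shapes is a hypothetical helper written against the documented shapes, not part of flash_dmattn, and the example tensors merely show which layouts the updated docstring permits.

import torch

# Hypothetical shape check for the documented mask/bias layouts.
# Not part of flash_dmattn; it only mirrors the docstring above.
def check_mask_bias_shapes(query, key, attention_mask=None, attention_bias=None):
    batch_size, num_heads, query_len, _ = query.shape
    _, num_kv_heads, key_len, _ = key.shape
    allowed_heads = {num_heads, num_kv_heads, 1}

    if attention_mask is not None:
        # Mask must be boolean with shape (batch, {num_heads|num_kv_heads|1}, q_len, k_len).
        assert attention_mask.dtype == torch.bool
        b, h, q, k = attention_mask.shape
        assert (b, q, k) == (batch_size, query_len, key_len) and h in allowed_heads

    if attention_bias is not None:
        if attention_bias.dim() == 4:
            # Bias shape (batch, {num_heads|num_kv_heads|1}, q_len, k_len).
            b, h, q, k = attention_bias.shape
            assert (b, q, k) == (batch_size, query_len, key_len) and h in allowed_heads
        else:
            # The (batch, heads, key_len) form is documented only for the
            # case where attention_mask is None.
            assert attention_mask is None and attention_bias.dim() == 3
            b, h, k = attention_bias.shape
            assert (b, k) == (batch_size, key_len) and h in allowed_heads


# Example tensors matching the documented shapes (GQA-style: 8 query heads, 2 KV heads).
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 2, 128, 64)
mask = torch.ones(2, 2, 128, 128, dtype=torch.bool)  # (batch, num_kv_heads, q_len, k_len)
bias = torch.zeros(2, 1, 128, 128)                    # (batch, 1, q_len, k_len)
check_mask_bias_shapes(q, k, mask, bias)

bias_3d = torch.zeros(2, 8, 128)                      # (batch, num_heads, key_len), mask omitted
check_mask_bias_shapes(q, k, None, bias_3d)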
