
Commit ccfd3ec

Updates attention tensor shape documentation for MQA/GQA
Clarifies that attention mask and bias tensors support multiple shape formats to accommodate Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) patterns in addition to the standard multi-head attention format. Adds explicit documentation for supported shapes: standard num_heads format, num_kv_heads format, and broadcast-compatible single head format.
1 parent 9cf0f04 commit ccfd3ec

File tree

1 file changed (+2, -2)


flash_dmattn/integrations/flash_dynamic_mask_attention.py

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ def flash_dynamic_mask_attention_forward(
     query (torch.Tensor): The query tensor of shape (batch_size, num_heads, query_len, head_dim).
     key (torch.Tensor): The key tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
     value (torch.Tensor): The value tensor of shape (batch_size, num_kv_heads, key_len, head_dim).
-    attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_kv_heads, query_len, key_len).
-    attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_kv_heads, query_len, key_len).
+    attention_mask (Optional[torch.Tensor]): The attention mask boolean tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
+    attention_bias (Optional[torch.Tensor]): The attention bias float tensor of shape (batch_size, num_heads, query_len, key_len), also supports (batch_size, num_kv_heads, query_len, key_len) or (batch_size, 1, query_len, key_len) for MQA/GQA.
     scaling (Optional[float]): The scaling factor for the attention scores.
     softcap (Optional[float]): The softcap value for the attention scores.
     **kwargs: Additional keyword arguments.
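To make the documented shape contract concrete, the sketch below shows a hypothetical validation helper (not part of flash_dmattn) that accepts exactly the three head-dimension options the updated docstring lists: `num_heads` (standard MHA), `num_kv_heads` (one mask per KV group, as in MQA/GQA), or `1` (broadcast across all heads).

```python
def is_valid_mask_shape(mask_shape, batch_size, num_heads, num_kv_heads,
                        query_len, key_len):
    """Hypothetical check: does a 4-D mask/bias shape match the docstring?

    Accepted head dimensions: num_heads (standard multi-head attention),
    num_kv_heads (Multi-Query / Grouped-Query Attention), or 1
    (broadcast-compatible single-head mask shared by every head).
    """
    if len(mask_shape) != 4:
        return False
    b, h, q, k = mask_shape
    return (
        b == batch_size
        and h in (num_heads, num_kv_heads, 1)
        and q == query_len
        and k == key_len
    )


# With batch_size=2, num_heads=8, num_kv_heads=2, query_len=16, key_len=16:
# all three documented head dimensions (8, 2, 1) pass; any other fails.
```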
