
Commit 9cf0f04

Updates attention mask and bias documentation for MQA/GQA
Clarifies that the attention mask and bias parameters accept multiple tensor shapes to accommodate Multi-Query Attention (MQA) and Grouped Query Attention (GQA) patterns, in addition to the standard multi-head attention format. Adds explicit documentation for the supported shapes, including the broadcast-compatible head dimensions (nheads, nheads_k, or 1) used by these attention variants.
1 parent 1050261 commit 9cf0f04


1 file changed: +6 −2 lines


flash_dmattn/flash_dmattn_interface.py

Lines changed: 6 additions & 2 deletions
@@ -361,10 +361,14 @@ def flash_dmattn_func(
     key: torch.Tensor. The key tensor of shape (batch_size, seqlen, nheads_k, headdim)
     value: torch.Tensor. The value tensor of shape (batch_size, seqlen, nheads_k, headdim)
     attn_mask: torch.Tensor, optional. The attention mask boolean tensor of
-        shape (batch_size, nheads_k, seqlen_q, seqlen_k) to apply to the attention scores.
+        shape (batch_size, nheads, seqlen_q, seqlen_k) to apply to the attention scores.
+        Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+        (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
         If None, no mask is applied.
     attn_bias: torch.Tensor, optional. The attention bias float tensor of
-        shape (batch_size, nheads_k, seqlen_q, seqlen_k) to add to the attention scores.
+        shape (batch_size, nheads, seqlen_q, seqlen_k) to add to the attention scores.
+        Also supports shape (batch_size, nheads_k, seqlen_q, seqlen_k) or
+        (batch_size, 1, seqlen_q, seqlen_k) for MQA/GQA.
         If None, no bias is applied.
     is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     scale: float. The scaling of QK^T before applying softmax.
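A minimal sketch of the broadcastable shapes the updated docstring describes: the head dimension of attn_mask and attn_bias may be nheads, nheads_k, or 1, and is broadcast against the query heads inside the kernel. The import path and the commented-out call are assumptions based on this file's path and the documented signature, not verified against the package.

import torch

# Shapes taken from the docstring above: query uses nheads heads, key/value
# use nheads_k heads (nheads_k < nheads for GQA, nheads_k == 1 for MQA).
batch_size, seqlen_q, seqlen_k = 2, 128, 128
nheads, nheads_k, headdim = 8, 2, 64

# Flash-style kernels typically require CUDA and fp16/bf16; fall back to CPU
# tensors here only to illustrate the shapes.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

query = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype)
key = torch.randn(batch_size, seqlen_k, nheads_k, headdim, device=device, dtype=dtype)
value = torch.randn(batch_size, seqlen_k, nheads_k, headdim, device=device, dtype=dtype)

# Any of the three documented head dimensions is accepted:
#   (batch_size, nheads,   seqlen_q, seqlen_k)  - one mask/bias per query head
#   (batch_size, nheads_k, seqlen_q, seqlen_k)  - one per KV head (GQA)
#   (batch_size, 1,        seqlen_q, seqlen_k)  - shared across all heads (MQA)
attn_mask = torch.ones(batch_size, nheads_k, seqlen_q, seqlen_k, device=device, dtype=torch.bool)
attn_bias = torch.zeros(batch_size, 1, seqlen_q, seqlen_k, device=device, dtype=dtype)

# Hypothetical call, assuming flash_dmattn_func is importable from the file
# changed in this commit; keyword names follow the docstring.
# from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
# out = flash_dmattn_func(query, key, value, attn_mask=attn_mask,
#                         attn_bias=attn_bias, is_causal=True)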
