Description
Problem statement
Currently, Flash-DMATTN supports learnable bias tensors (attn_bias) with shape (batch_size, {1|num_kv_heads|num_heads}, {seqlen_q|0}, seqlen_k) in the standard flash_dmattn_func, which is a significant advantage over vanilla Flash Attention for models like GPT-OSS that require learnable attention biases.
However, there's a limitation: the bias head dimension cannot be broadcast over the sequence dimension. Specifically, bias shapes like (batch_size, num_heads, 1, 1) or (batch_size, num_kv_heads, 1, 1) are not currently supported, even though they are valid broadcasting patterns in PyTorch and would enable per-head scalar biases without materializing full attention matrices.
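For reference, the requested behavior matches plain PyTorch broadcasting when the bias is added to the score matrix. A minimal eager sketch of the intended semantics (shapes are arbitrary, and the eager path is only a reference, not how the fused kernel works):

import math
import torch

B, H, L_q, L_k, D = 2, 8, 128, 128, 64
q = torch.randn(B, H, L_q, D)
k = torch.randn(B, H, L_k, D)
v = torch.randn(B, H, L_k, D)

# Per-head scalar bias: valid PyTorch broadcasting, currently rejected by flash_dmattn_func
attn_bias = torch.randn(B, H, 1, 1)

# Eager reference: the bias broadcasts over both sequence dimensions of the score matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(D) + attn_bias   # (B, H, L_q, L_k)
out = torch.softmax(scores, dim=-1) @ v                       # (B, H, L_q, D)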
This limitation becomes particularly problematic when:
- Implementing attention temperature scaling per head (common in multi-task models)
- Using learnable head-wise gating mechanisms (e.g., attention sink tokens with per-head weights)
- Migrating from SDPA or other backends that support arbitrary bias broadcasting
- Working with variable-length sequences, where flash_dmattn_varlen_func currently has no bias support at all
Proposed solution
Extend both flash_dmattn_func and flash_dmattn_varlen_func to support bias tensors with sequence dimensions of size 1, enabling PyTorch-style broadcasting:
For flash_dmattn_func:
- Current: (B, {1|H_k|H}, {L_q|0}, L_k), where the sequence dims must match or be 0
- Desired: (B, {1|H_k|H}, {L_q|1|0}, {L_k|1}), allowing sequence-dimension broadcasting (a wrapper-level sketch follows this list)
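One possible way to expose this at the Python wrapper level is to normalize a size-1 sequence dimension into a stride-0 view before the kernel is called. The helper below is a hypothetical sketch, not part of the current API, and assumes the kernel could either honor zero strides or perform the equivalent broadcast inside its tile loads:

import torch

def normalize_attn_bias(attn_bias: torch.Tensor, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Hypothetical helper: turn size-1 sequence dims into stride-0 views so downstream
    # code sees a full (B, H_like, L_q, L_k) logical shape without allocating O(L_q * L_k)
    # memory. The existing L_q == 0 convention is ignored here for brevity.
    b, h, lq, lk = attn_bias.shape
    if lq not in (1, seqlen_q) or lk not in (1, seqlen_k):
        raise ValueError(
            f"attn_bias sequence dims ({lq}, {lk}) are not broadcastable "
            f"to ({seqlen_q}, {seqlen_k})"
        )
    return attn_bias.expand(b, h, seqlen_q, seqlen_k)  # view only, no copy

Whether the kernel can consume such a view directly depends on whether it indexes the bias by strides; otherwise the broadcast would need to happen when loading bias tiles.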
For flash_dmattn_varlen_func:
- Current: no bias support (removed in a recent refactor)
- Desired: (B, {1|H_k|H}, 1, 1) support for per-head scalar biases (an eager reference sketch follows this list)
- Note: full (L_q, L_k) biases in varlen mode may be deferred due to ragged-indexing complexity
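For the varlen request, the per-head scalar bias has a simple eager reference: each packed sequence attends within its own boundaries, and the (B, {1|H_k|H}, 1, 1) bias is added to every score of that sequence. A rough reference sketch (not the proposed kernel), assuming the usual packed (total_tokens, num_heads, head_dim) layout and cu_seqlens_q == cu_seqlens_k for brevity:

import math
import torch

def varlen_reference(q, k, v, cu_seqlens, attn_bias):
    # q, k, v: (total_tokens, H, D) packed across the batch; cu_seqlens: (B + 1,)
    # attn_bias: (B, H, 1, 1) per-head scalar bias, broadcast within each sequence
    D = q.shape[-1]
    outs = []
    for b in range(cu_seqlens.numel() - 1):
        s, e = cu_seqlens[b].item(), cu_seqlens[b + 1].item()
        qb = q[s:e].transpose(0, 1)  # (H, L, D)
        kb = k[s:e].transpose(0, 1)
        vb = v[s:e].transpose(0, 1)
        scores = qb @ kb.transpose(-2, -1) / math.sqrt(D) + attn_bias[b]  # bias broadcasts to (H, L, L)
        outs.append((torch.softmax(scores, dim=-1) @ vb).transpose(0, 1))  # (L, H, D)
    return torch.cat(outs, dim=0)  # (total_tokens, H, D)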
Example use cases:
import torch
import torch.nn as nn
# q, k, v, the shape variables, and the flash_dmattn_* functions are assumed to be in scope

# Per-head temperature scaling (common in mixture-of-experts attention)
head_temp = torch.randn(batch_size, num_heads, 1, 1, device='cuda')
out = flash_dmattn_func(q, k, v, attn_bias=head_temp)

# Attention sink with learnable per-head weights
sink_weight = nn.Parameter(torch.ones(1, num_heads, 1, 1))  # Shape: (1, H, 1, 1)
sink_bias = sink_weight.expand(batch_size, -1, 1, seqlen_k)
out = flash_dmattn_func(q, k, v, attn_bias=sink_bias)

# Variable-length with per-head gating (currently impossible)
gate = torch.sigmoid(gate_layer(hidden_states))  # (B, H, 1, 1)
out = flash_dmattn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k, max_q, max_k,
    attn_bias=gate  # Requested feature
)

Alternatives considered
- Materialize full bias tensors
  - Works but wastes memory: (B, H, L_q, L_k) for a per-head scalar is O(BHL²) vs O(B*H)
  - Defeats the purpose of Flash Attention's memory efficiency
  - Example: for B=32, H=32, L=4096, this is 32 * 32 * 4096 * 4096 * 2 bytes = 32 GB vs 32 * 32 * 4 bytes = 4 KB (the arithmetic is spelled out below)
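Spelling out the arithmetic behind those figures (assuming 2 bytes per element for the materialized fp16 bias and 4 bytes per element for the per-head scalars):

B, H, L = 32, 32, 4096

full_bias_bytes = B * H * L * L * 2      # materialized (B, H, L_q, L_k) bias in fp16
scalar_bias_bytes = B * H * 4            # per-head (B, H, 1, 1) bias in fp32

print(full_bias_bytes / 2**30, "GiB")    # 32.0 GiB
print(scalar_bias_bytes / 2**10, "KiB")  # 4.0 KiB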
- Apply bias outside Flash Attention
  - Requires materializing the attention matrix (defeats Flash's memory optimization)
  - Breaks gradient flow through the fused kernel
  - Significantly slower (requires additional matrix operations)
- Use custom preprocessing to expand the bias
  - Still materializes full tensors, wasting memory (see the illustration after this list)
  - Adds preprocessing overhead
  - Doesn't solve the varlen case, where sequence boundaries are ragged
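For illustration, an expand on its own is only a stride-0 view; the memory cost appears the moment the expanded bias would have to be made contiguous for the kernel (small shapes below so the snippet actually runs):

import torch

B, H, L_q, L_k = 2, 8, 1024, 1024            # small shapes for illustration
scalar_bias = torch.randn(B, H, 1, 1, dtype=torch.float16)

view = scalar_bias.expand(B, H, L_q, L_k)    # stride-0 view: no new memory is allocated
dense = view.contiguous()                    # copies into a real (B, H, L_q, L_k) buffer

print(view.stride())                         # (8, 1, 0, 0): broadcast dims have stride 0
print(dense.numel() * dense.element_size())  # 2 * 8 * 1024 * 1024 * 2 bytes = 32 MiB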
- Fall back to SDPA for models needing this feature
  - Loses Flash-DMATTN's performance benefits (dynamic mask support, better kernel tuning)
Implementation details
No response
Use case
No response
Related work
No response
Additional context
No response