[FEATURE REQUEST] Support {num_heads, num_kv_heads, 1} shaped bias in attention functions #189

@LoserCheems

Description

Problem statement

Currently, Flash-DMATTN supports learnable bias tensors (attn_bias) with shape (batch_size, {1|num_kv_heads|num_heads}, {seqlen_q|0}, seqlen_k) in the standard flash_dmattn_func, which is a significant advantage over vanilla Flash Attention for models like GPT-OSS that require learnable attention biases.

However, there is a limitation: the bias cannot use size-1 (broadcast) sequence dimensions. Specifically, bias shapes such as (batch_size, num_heads, 1, 1) or (batch_size, num_kv_heads, 1, 1) are not currently supported, even though they are valid broadcasting patterns in PyTorch and would enable per-head scalar biases without materializing full attention matrices.
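
As a quick sanity check, these shapes already broadcast cleanly against the attention-score shape under standard PyTorch rules (the shapes below are arbitrary examples):

import torch

# A per-head scalar bias broadcasts against the (B, H, L_q, L_k) score shape.
B, H, L_q, L_k = 4, 8, 512, 512
print(torch.broadcast_shapes((B, H, 1, 1), (B, H, L_q, L_k)))    # torch.Size([4, 8, 512, 512])
print(torch.broadcast_shapes((B, H, 1, L_k), (B, H, L_q, L_k)))  # torch.Size([4, 8, 512, 512])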

This limitation becomes particularly problematic when:

  1. Implementing attention temperature scaling per head (common in multi-task models)
  2. Using learnable head-wise gating mechanisms (e.g., attention sink tokens with per-head weights)
  3. Migrating from SDPA or other backends that support arbitrary bias broadcasting (see the SDPA sketch after this list)
  4. Working with variable-length sequences where flash_dmattn_varlen_func currently has no bias support at all
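
For the SDPA migration case in item 3, the pattern being ported typically looks like the following minimal sketch (tensor shapes are arbitrary; SDPA adds a float attn_mask to the scores and allows it to broadcast):

import torch
import torch.nn.functional as F

B, H, L_q, L_k, D = 2, 8, 128, 128, 64
q = torch.randn(B, H, L_q, D)
k = torch.randn(B, H, L_k, D)
v = torch.randn(B, H, L_k, D)

# SDPA accepts any float attn_mask that broadcasts to (B, H, L_q, L_k),
# including a per-head scalar bias of shape (B, H, 1, 1).
per_head_bias = torch.randn(B, H, 1, 1)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=per_head_bias)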

Proposed solution

Extend both flash_dmattn_func and flash_dmattn_varlen_func to support bias tensors with sequence dimensions of size 1, enabling PyTorch-style broadcasting:

For flash_dmattn_func:

  • Current: (B, {1|H_k|H}, {L_q|0}, L_k) where sequence dims must match or be 0
  • Desired: (B, {1|H_k|H}, {L_q|1|0}, {L_k|1}) allowing sequence-dimension broadcasting

For flash_dmattn_varlen_func:

  • Current: No bias support (removed in a recent refactor)
  • Desired: (B, {1|H_k|H}, 1, 1) support for per-head scalar biases
    • Note: Full (L_q, L_k) biases in varlen mode may be deferred due to ragged indexing complexity
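
Stated as a shape rule, the proposal for flash_dmattn_func amounts to the check below; is_valid_bias_shape is a hypothetical helper used only to make the rule precise, not part of the API:

# Hypothetical check mirroring the proposed rule (B, {1|H_k|H}, {L_q|1|0}, {L_k|1}).
def is_valid_bias_shape(bias_shape, B, H, H_k, L_q, L_k):
    b, h, sq, sk = bias_shape
    return b == B and h in (1, H_k, H) and sq in (L_q, 1, 0) and sk in (L_k, 1)

assert is_valid_bias_shape((4, 8, 1, 1), B=4, H=8, H_k=2, L_q=512, L_k=512)        # per-head scalar
assert is_valid_bias_shape((4, 2, 512, 1), B=4, H=8, H_k=2, L_q=512, L_k=512)      # per-query bias at KV-head granularity
assert not is_valid_bias_shape((4, 8, 256, 1), B=4, H=8, H_k=2, L_q=512, L_k=512)  # arbitrary sequence sizes still rejected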

Example use cases:

# Per-head temperature scaling (common in mixture-of-experts attention)
head_temp = torch.randn(batch_size, num_heads, 1, 1, device='cuda')
out = flash_dmattn_func(q, k, v, attn_bias=head_temp)

# Attention sink with learnable per-head weights
sink_weight = nn.Parameter(torch.ones(1, num_heads, 1, 1))  # Shape: (1, H, 1, 1)
sink_bias = sink_weight.expand(batch_size, -1, 1, seqlen_k)
out = flash_dmattn_func(q, k, v, attn_bias=sink_bias)

# Variable-length with per-head gating (currently impossible)
gate = torch.sigmoid(gate_layer(hidden_states))  # (B, H, 1, 1)
out = flash_dmattn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k, max_q, max_k,
    attn_bias=gate  # Requested feature
)
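
For testing the requested behavior, a plain eager reference such as the one below could serve as the correctness oracle; attention_reference is purely illustrative (no masking, dropout, or GQA handling) and not part of Flash-DMATTN:

import math
import torch

def attention_reference(q, k, v, attn_bias=None):
    # q: (B, H, L_q, D); k, v: (B, H, L_k, D)
    # attn_bias: anything broadcastable to (B, H, L_q, L_k), e.g. (B, H, 1, 1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_bias is not None:
        scores = scores + attn_bias  # ordinary PyTorch broadcasting
    return torch.softmax(scores, dim=-1) @ v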

Alternatives considered

  1. Materialize full bias tensors

    • Works, but wastes memory: a full (B, H, L_q, L_k) tensor for a per-head scalar is O(B·H·L²) versus O(B·H)
    • Defeats the purpose of Flash Attention's memory efficiency
    • Example: for B=32, H=32, L=4096, this is 32 × 32 × 4096 × 4096 × 2 bytes ≈ 32 GB versus 32 × 32 × 4 bytes = 4 KB (spelled out in the sketch after this list)
  2. Apply bias outside Flash Attention

    • Requires materializing attention matrix (defeats Flash's memory optimization)
    • Breaks gradient flow through fused kernel
    • Significantly slower (requires additional matrix operations)
  3. Use custom preprocessing to expand bias

    • Still materializes full tensors, wasting memory
    • Adds preprocessing overhead
    • Doesn't solve varlen case where sequence boundaries are ragged
  4. Fall back to SDPA for models needing this feature

    • Loses Flash-DMATTN's performance benefits (dynamic mask support, better kernel tuning)
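
The arithmetic behind the figures in alternative 1, spelled out (2-byte elements for the materialized bias and 4-byte per-head scalars, matching the numbers quoted above):

B, H, L = 32, 32, 4096
full_bias_bytes = B * H * L * L * 2  # materialized fp16 bias: 34,359,738,368 bytes ≈ 32 GB
per_head_bytes = B * H * 4           # one fp32 scalar per (batch, head): 4,096 bytes = 4 KB
print(full_bias_bytes, per_head_bytes)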

Implementation details

No response

Use case

No response

Related work

No response

Additional context

No response

Labels

feature (New feature request)
