
Conversation


LoserCheems commented on Sep 13, 2025

Summary

#163

  • Add support for mask and bias tensors with shapes (B, H, Q, K), (B, H_k, Q, K), and (B, 1, Q, K).

Design

  • Flexible head indexing using h_mask/h_bias to select the mask/bias head (see the sketch after this list):
    • H-aligned: head = h
    • H_k-aligned: head = h % num_kv_heads
    • Broadcast: head = 0
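
A minimal sketch of that selection rule, assuming the mapping listed above; the identifiers are illustrative, not the kernel's actual names:

    // Illustrative helper mirroring the selection rules above; the real kernel
    // derives the index from its parameter struct and block coordinates.
    inline int mask_bias_head_index(int h, int num_kv_heads, int heads_in_tensor) {
        if (heads_in_tensor == 1) {
            return 0;                 // broadcast: one head shared by all query heads
        }
        if (heads_in_tensor == num_kv_heads) {
            return h % num_kv_heads;  // H_k-aligned, per the mapping in the Design list
        }
        return h;                     // H-aligned: one mask/bias head per query head
    }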

Changes

  • Public/API:
    • Allow mask/bias shapes with head dim ∈ {1, num_kv_heads, num_heads}.
    • Params expose h_mask and h_bias to drive kernel-side head selection (see the validation sketch after this list).
  • Kernel:
    • Unified head-index selection across fwd/bwd/splitKV/varlen.
    • Mask loads use the conversion copy path during the tiled copy, so dtype conversion happens in-flight.
  • Python:
    • Interface accepts flexible head dims and applies broadcast semantics.
    • Padding path uses consistent, non-influential defaults for masked-out positions.
  • Misc:
    • Resolved signature/argument ordering in set_params_fprop/set_params_dgrad.
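
A hedged sketch of the API-side check and parameter wiring implied by the Public/API bullets above; it is not taken from flash_api.cpp, and the exact variable names, messages, and placement (e.g. in the forward entry point before set_params_fprop) are assumptions:

    // Accept mask/bias head dims of 1, num_heads_k, or num_heads, then record
    // the per-tensor head counts so the kernel can pick the right head index.
    const int num_heads_mask = mask.size(1);
    const int num_heads_bias = bias.size(1);
    TORCH_CHECK(num_heads_mask == 1 || num_heads_mask == num_heads_k || num_heads_mask == num_heads,
                "mask must have 1, num_heads_k, or num_heads heads");
    TORCH_CHECK(num_heads_bias == 1 || num_heads_bias == num_heads_k || num_heads_bias == num_heads,
                "bias must have 1, num_heads_k, or num_heads heads");
    params.h_mask = num_heads_mask;
    params.h_bias = num_heads_bias;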

Implementation Notes

  • Conversion is fused with the tiled copy; no extra passes (see the sketch after this list).
  • Maintains vectorization where safe; clean tail handling.
  • SplitKV uses correct per-kv-head index for mask/bias.
  • Guarded to avoid undue register/shared memory pressure.
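
The fusion idea in scalar form, as a minimal sketch only; the actual kernel performs this inside its tiled-copy path with vectorized loads, and the function below is purely illustrative:

    // Convert the element type while the tile is copied instead of running a
    // separate conversion pass; out-of-range tail positions receive a neutral
    // default so they cannot influence the result.
    template <typename SrcT, typename DstT>
    void copy_convert_tile(const SrcT* src, DstT* dst, int tile_size, int n_valid, DstT neutral) {
        for (int i = 0; i < tile_size; ++i) {
            dst[i] = (i < n_valid) ? static_cast<DstT>(src[i]) : neutral;
        }
    }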

Tests

  • Shapes: head dims H, H_k, and 1; uniform and varlen sequences.
  • Dtypes: fp16/bf16 for Q/K/V and outputs.
  • Correctness: forward/backward equivalence and gradient checks.
  • Performance: microbenchmarks (latency/TFLOPs) across typical configs.
  • Integration: end-to-end for splitKV and varlen.

Docs

  • Document supported mask/bias shapes and broadcast semantics.
  • Note that conversion occurs in-kernel; no user changes required.
  • Update examples to show head-dim flexibility.

Checklist

  • Linked issue provided
  • API stable
  • Tests added or updated
  • Docs added or updated
  • No known performance regressions

Introduces h_mask and h_bias fields to track the number of heads
in attention mask and bias structures respectively.

Enables better head dimension management and validation in
flash attention operations.

Introduces dynamic head index calculation for mask and bias tensors to support different head configurations.

Previously used fixed head-ratio calculations; now supports three scenarios:
- Single head broadcasting (h_mask/h_bias == 1)
- Multi-head with ratio-based indexing (h_mask/h_bias == h_k)
- Direct head indexing (fallback case)

Enables more flexible attention masking and bias application across different multi-head attention configurations.

Introduces conditional head index calculation for mask and bias operations based on tensor dimensions. Supports scenarios where mask/bias tensors have a single head (h=1), match key heads (h=h_k), or match query heads (h=h_q).

Replaces hardcoded head index division with dynamic selection logic that adapts to different tensor head configurations in the flash attention backward kernel.

Adds support for mask and bias tensors with 1, num_heads_k, or num_heads head dimensions instead of only num_heads_k.

Enables more flexible attention patterns by allowing masks and biases to be broadcast across different head configurations. Updates parameter passing to track separate head counts for masks and biases, and adds appropriate validation checks.

Temporarily disables variable-length attention variants to focus on core functionality improvements.

Clarifies that attention mask and bias parameters support multiple tensor shapes
to accommodate Multi-Query Attention (MQA) and Grouped Query Attention (GQA)
patterns, in addition to the standard multi-head attention format.

Adds explicit documentation for supported shapes including broadcast-compatible
dimensions for flexible attention implementations.

Clarifies that attention mask and bias tensors support multiple shape formats
to accommodate Multi-Query Attention (MQA) and Grouped-Query Attention (GQA)
patterns in addition to the standard multi-head attention format.

Adds explicit documentation for supported shapes: standard num_heads format,
num_kv_heads format, and broadcast-compatible single head format.

Copilot AI left a comment

Pull Request Overview

This PR adds support for flexible head dimensions in attention masks and biases, allowing for shapes (B, H, Q, K), (B, H_k, Q, K), and (B, 1, Q, K) instead of being restricted to only (B, H_k, Q, K).

  • Introduces kernel-side head indexing logic to handle different head dimension configurations
  • Updates API and parameter structures to track mask and bias head counts separately
  • Implements proper broadcast semantics and conversion paths for different head arrangements

Reviewed Changes

Copilot reviewed 6 out of 7 changed files in this pull request and generated 2 comments.

Summary per file:

  • flash_dmattn/integrations/flash_dynamic_mask_attention.py: Updates documentation to reflect flexible mask/bias shape support
  • flash_dmattn/flash_dmattn_interface.py: Updates API documentation for flexible head dimensions
  • csrc/flash_dmattn/src/flash_fwd_kernel.h: Implements head indexing logic for mask/bias in forward kernel
  • csrc/flash_dmattn/src/flash_bwd_kernel.h: Implements head indexing logic for mask/bias in backward kernel
  • csrc/flash_dmattn/src/flash.h: Adds h_mask and h_bias fields to parameter structures
  • csrc/flash_dmattn/flash_api.cpp: Updates function signatures and implements flexible head dimension handling


const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k

Copilot AI Sep 13, 2025

The comment format is inconsistent between mask and bias parameters. The bias comment has an extra comma after the first shape specification that should be removed to match the mask comment format.

Suggested change
at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k, or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k

const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
const at::Tensor &mask, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k

Copilot AI Sep 13, 2025

The order of shape specifications in the bias comment differs from the mask comment above it. For consistency, the bias comment should list shapes in the same order: num_heads first, then num_heads_k, then 1.

Suggested change
const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x num_heads x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k
const at::Tensor &bias, // batch_size x num_heads x seqlen_q x seqlen_k or batch_size x num_heads_k x seqlen_q x seqlen_k or batch_size x 1 x seqlen_q x seqlen_k

LoserCheems merged commit e854654 into main on Sep 13, 2025
3 of 4 checks passed
LoserCheems deleted the support-all-shape-of-mask/bias branch on November 13, 2025