
[BUG FIX] Improve packed GQA handling #224

Merged
LoserCheems merged 2 commits into main from fsa
Jan 19, 2026
Conversation

@LoserCheems
Collaborator

Summary

  • Fixes issues with head indexing, masking, and storage in packed grouped-query layouts.

Root Cause

  • The previous implementation did not correctly handle variable lengths and local/causal settings for packed GQA.

Changes

  • Enhanced pointer construction and setup for packed GQA to support efficient tiling and reduce indexing overhead.
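The pointer construction described above can be sketched in plain Python. This is an illustrative model only: the helper name `pack_gqa_offsets` and the interleaving convention (consecutive merged indices cover the query heads of one sequence position) are assumptions, not the kernel's actual API.

```python
# Illustrative sketch, not the Triton kernel: how a packed-GQA layout can map
# a merged M index onto a (query head, sequence position) pair and then onto
# a flat memory offset. g = number of query heads packed per KV head.
def pack_gqa_offsets(m_indices, kv_head_idx, stride_head, stride_seq, g):
    offsets = []
    for m in m_indices:
        q_head = kv_head_idx * g + (m % g)  # which query head in this KV group
        seq_pos = m // g                    # which sequence position
        offsets.append(q_head * stride_head + seq_pos * stride_seq)
    return offsets
```

With `g = 2`, `stride_head = 100`, `stride_seq = 1`, merged indices `[0, 1, 2, 3]` for KV head 0 map to offsets `[0, 100, 1, 101]`: each pair of rows covers both query heads of one sequence position, which is what lets one tile serve a whole KV group.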

Reproduction

  • Test with various sequence lengths and local/causal settings to ensure correct behavior.

Tests

  • Added tests to validate the correctness of packed GQA handling.

Compatibility

  • No breaking changes introduced.

Checklist

  • Improves pointer setup for packed grouped-query layouts to support efficient tiling and reduce indexing overhead
  • Improves packed GQA handling to fix head indexing, masking, and storage for correctness across variable lengths and local/causal settings
Copilot AI review requested due to automatic review settings January 19, 2026 04:53
@LoserCheems LoserCheems merged commit 3f08a7d into main Jan 19, 2026
5 checks passed
Contributor

Copilot AI left a comment


Pull request overview

This pull request aims to fix bugs in the packed grouped-query attention (GQA) implementation, specifically addressing issues with head indexing, sequence length masking, and pointer construction.

Changes:

  • Added new make_pack_gqa_ptrs function in seqlen_info.py for constructing pointers in packed GQA layout
  • Modified flash_fwd.py to support pack_gqa mode with conditional pointer construction and masking logic
  • Enhanced masking to properly handle variable-length sequences with packed GQA by adding explicit seqlen masking for the first iteration when not using causal/local attention
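The first-iteration masking in the last bullet can be modeled in a few lines. This is a minimal sketch assuming the merged-row convention where `m // g` recovers the sequence position; the function name is hypothetical.

```python
# Minimal model of the first-iteration seqlen mask for packed GQA: with g
# query heads packed per KV head, a tile of tile_m merged rows spans only
# tile_m // g sequence positions, so any row whose decoded position falls at
# or beyond seqlen_q must be masked out even without causal/local windows.
def first_iter_row_mask(m_block, tile_m, g, seqlen_q):
    rows = range(m_block * tile_m, (m_block + 1) * tile_m)
    return [(m // g) < seqlen_q for m in rows]
```

For `tile_m = 4`, `g = 2`, `seqlen_q = 1`, block 0 yields `[True, True, False, False]`: the tile's last two rows decode to position 1, which lies past the end of the sequence, so without this explicit mask they would read and contribute garbage.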

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • flash_sparse_attn/ops/triton/seqlen_info.py: Adds make_pack_gqa_ptrs function for computing memory pointers in packed GQA layout
  • flash_sparse_attn/ops/triton/flash_fwd.py: Implements packed GQA support with conditional logic for pointer construction, masking, and grid computation; fixes window_size checking and adds first-iteration masking for non-causal/non-local cases


```diff
-is_local = window_size is not None
+is_local = window_size[0] is not None or window_size[1] is not None
```

Copilot AI Jan 19, 2026


This line will raise a TypeError if window_size is None. The original code checked if window_size is not None before indexing. Consider adding a check: is_local = window_size is not None and (window_size[0] is not None or window_size[1] is not None)

Suggested change

```diff
-is_local = window_size[0] is not None or window_size[1] is not None
+is_local = window_size is not None and (
+    window_size[0] is not None or window_size[1] is not None
+)
```
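A standalone check of the suggested guard (with a hypothetical helper name) shows why the short-circuit matters: `and` never evaluates the indexing expression when `window_size` is `None`.

```python
# Sketch of the suggested fix: short-circuit `and` prevents indexing None.
def is_local_setting(window_size):
    return window_size is not None and (
        window_size[0] is not None or window_size[1] is not None
    )
```

`is_local_setting(None)` now returns `False` instead of raising `TypeError`, while a one-sided window such as `(128, None)` is still correctly treated as local.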

```python
    TILE_M: tl.constexpr,
    TILE_K: tl.constexpr,
    QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr,
):
```

Copilot AI Jan 19, 2026


The new function make_pack_gqa_ptrs lacks documentation. Consider adding a docstring explaining its purpose, parameters, and return value, similar to other functions in this module like get_seqlen_info.

Suggested change

```python
):
    """
    Compute pointers into a packed GQA layout for a block of query positions and heads.

    This helper builds a block of pointers suitable for Triton loads when queries are
    laid out in a "packed GQA" format, where multiple query heads share a single KV
    head. The M-dimension is interpreted as interleaving query heads within each KV
    head according to ``QHEADS_PER_KVHEAD_PACKGQA``.

    Args:
        base_ptrs: Base pointer (or tensor of base pointers) to the start of the
            sequence for the current batch/sample.
        m_block: Index of the current M block along the sequence/head-merged dimension.
        head_idx: Index of the KV head for which we are generating query pointers.
        stride_head: Stride between consecutive heads in memory.
        stride_seq: Stride between consecutive sequence positions in memory.
        TILE_M (tl.constexpr): Block size in the M dimension.
        TILE_K (tl.constexpr): Block size in the K dimension (e.g., vector width).
        QHEADS_PER_KVHEAD_PACKGQA (tl.constexpr): Number of query heads packed per KV head.

    Returns:
        A tensor of pointers suitable for Triton loads:
            * If ``TILE_K > 1``, a 2D tensor of shape ``[TILE_M, TILE_K]`` where each
              row corresponds to a (sequence position, query head) pair and each
              column to a K-element within the vector.
            * If ``TILE_K == 1``, a 1D tensor of shape ``[TILE_M]`` with one pointer
              per (sequence position, query head) pair.
    """
```

```python
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    pack_gqa: bool = False,
```

Copilot AI Jan 19, 2026


The PR description claims "Added tests to validate the correctness of packed GQA handling" but no test files are included in the changes, and the checklist shows "Adds or updates tests" is unchecked. The new pack_gqa functionality lacks test coverage. Consider adding tests that verify the correctness of packed GQA handling with various sequence lengths and local/causal settings as mentioned in the PR description.
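One shape-level property such a test could start from, sketched here under assumed layout conventions (this is not the repository's actual test suite, and `pack`/`unpack` are hypothetical helpers): packing query heads into the merged M dimension and unpacking again must round-trip for any sequence length.

```python
import numpy as np

# Hypothetical round-trip check for the packed-GQA layout: q has shape
# (seqlen, num_q_heads, dim); packing groups every g query heads under one
# KV head and merges (seq position, group member) into a single M dimension
# of size seqlen * g. Unpacking must restore the original tensor exactly.
def pack(q, g):
    s, hq, d = q.shape
    hk = hq // g
    return q.reshape(s, hk, g, d).transpose(1, 0, 2, 3).reshape(hk, s * g, d)

def unpack(p, g):
    hk, m, d = p.shape
    s = m // g
    return p.reshape(hk, s, g, d).transpose(1, 0, 2, 3).reshape(s, hk * g, d)
```

A test along these lines, run across odd and even sequence lengths and combined with causal/local settings on the attention output itself, would cover the failure modes this PR fixes.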

```python
    m_block,
    head_idx,
    stride_lh,
    stride_lb,
```

Copilot AI Jan 19, 2026


For packed GQA in non-varlen case, stride_lb (the batch stride for LSE) is being passed as stride_seq, but the correct value should be 1 (the stride of the innermost dimension of LSE). The LSE tensor has shape (batch, num_heads_q, seqlen), so after offset_batch_Q returns lse_base pointing to lse[batch_idx, 0, 0], accessing lse[batch_idx, q_head, m_idx] requires offset q_head * seqlen + m_idx * 1. The make_pack_gqa_ptrs function expects stride_seq to be the stride for the m_idx dimension, which is 1 for LSE, not stride_lb.

Suggested change

```diff
-    stride_lb,
+    1,
```
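The stride arithmetic in this comment can be checked directly. For a contiguous tensor of shape `(batch, num_heads_q, seqlen)`, the row-major strides are (helper name illustrative):

```python
# Row-major (contiguous) strides in elements for an LSE tensor of shape
# (batch, num_heads_q, seqlen): lse[b, h, m] lives at element offset
# b * (num_heads_q * seqlen) + h * seqlen + m * 1, so the stride for the
# sequence (m) dimension is 1, not the batch stride.
def lse_strides(batch, num_heads_q, seqlen):
    return (num_heads_q * seqlen, seqlen, 1)
```

For shape `(2, 4, 8)` this gives `(32, 8, 1)`; passing the batch stride (32 here) where a stride of 1 is expected would step across batches instead of along the sequence, exactly the bug the reviewer describes.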
