[BUG FIX] Improve packed GQA handling #224
Improves pointer setup for packed grouped-query layouts to support efficient tiling and reduce indexing overhead
Improves packed GQA handling to fix head indexing, masking, and storage for correctness across variable lengths and local/causal settings
Pull request overview
This pull request aims to fix bugs in the packed grouped-query attention (GQA) implementation, specifically addressing issues with head indexing, sequence length masking, and pointer construction.
Changes:
- Added a new `make_pack_gqa_ptrs` function in `seqlen_info.py` for constructing pointers in the packed GQA layout
- Modified `flash_fwd.py` to support `pack_gqa` mode with conditional pointer construction and masking logic
- Enhanced masking to properly handle variable-length sequences with packed GQA by adding explicit seqlen masking for the first iteration when not using causal/local attention
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| flash_sparse_attn/ops/triton/seqlen_info.py | Adds make_pack_gqa_ptrs function for computing memory pointers in packed GQA layout |
| flash_sparse_attn/ops/triton/flash_fwd.py | Implements packed GQA support with conditional logic for pointer construction, masking, and grid computation; fixes window_size checking and adds first-iteration masking for non-causal/non-local cases |
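To make the "packed GQA" layout in the table above concrete: with `H_q` query heads sharing `H_kv` KV heads, each KV head owns a group of `H_q // H_kv` query heads, and the kernel's M dimension enumerates (sequence position, query head within group) pairs. The helper below is a hypothetical illustration of that mapping, not code from the PR; the real kernel's convention may differ in detail.

```python
# Illustrative index math for packed GQA (hypothetical helper, not the PR's code).
# Each packed row index m splits into a sequence position and a query head
# within the current KV head's group.

def unpack_gqa_row(m_idx: int, qheads_per_kvhead: int) -> tuple[int, int]:
    """Map a packed row index to (sequence position, query head within group)."""
    seq_idx = m_idx // qheads_per_kvhead
    qhead_in_group = m_idx % qheads_per_kvhead
    return seq_idx, qhead_in_group

# Example: 8 query heads over 2 KV heads -> 4 query heads per KV head.
print(unpack_gqa_row(10, 4))  # row 10 -> sequence position 2, query head 2
```

This is why the grid's M extent grows by a factor of `qheads_per_kvhead` in packed mode while the head axis shrinks to the number of KV heads.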
```diff
- is_local = window_size is not None
+ is_local = window_size[0] is not None or window_size[1] is not None
```
This line will raise a `TypeError` if `window_size` is `None`. The original code checked `window_size is not None` before indexing. Consider adding a guard: `is_local = window_size is not None and (window_size[0] is not None or window_size[1] is not None)`
Suggested change:

```diff
- is_local = window_size[0] is not None or window_size[1] is not None
+ is_local = window_size is not None and (
+     window_size[0] is not None or window_size[1] is not None
+ )
```
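The failure mode and the suggested guard can be exercised directly in plain Python. This is a minimal sketch with standalone stand-in functions, not the kernel code itself:

```python
# Stand-in reproductions of the two variants (illustrative, not the PR's code).

def is_local_unsafe(window_size):
    # Indexes without a None check -- raises TypeError when window_size is None.
    return window_size[0] is not None or window_size[1] is not None

def is_local_safe(window_size):
    # Guards against window_size being None before indexing into it.
    return window_size is not None and (
        window_size[0] is not None or window_size[1] is not None
    )

assert is_local_safe(None) is False
assert is_local_safe((None, None)) is False
assert is_local_safe((128, None)) is True
try:
    is_local_unsafe(None)
except TypeError:
    print("unsafe version raises TypeError on window_size=None")
```

Short-circuiting in `and` makes the safe version cheap: the tuple is only indexed when it is known to exist.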
```python
    TILE_M: tl.constexpr,
    TILE_K: tl.constexpr,
    QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr,
):
```
The new function `make_pack_gqa_ptrs` lacks documentation. Consider adding a docstring explaining its purpose, parameters, and return value, similar to other functions in this module like `get_seqlen_info`.
Suggested change:

```python
):
    """
    Compute pointers into a packed GQA layout for a block of query positions and heads.

    This helper builds a block of pointers suitable for Triton loads when queries are
    laid out in a "packed GQA" format, where multiple query heads share a single KV
    head. The M dimension is interpreted as interleaving query heads within each KV
    head according to ``QHEADS_PER_KVHEAD_PACKGQA``.

    Args:
        base_ptrs: Base pointer (or tensor of base pointers) to the start of the
            sequence for the current batch/sample.
        m_block: Index of the current M block along the sequence/head-merged dimension.
        head_idx: Index of the KV head for which we are generating query pointers.
        stride_head: Stride between consecutive heads in memory.
        stride_seq: Stride between consecutive sequence positions in memory.
        TILE_M (tl.constexpr): Block size in the M dimension.
        TILE_K (tl.constexpr): Block size in the K dimension (e.g., vector width).
        QHEADS_PER_KVHEAD_PACKGQA (tl.constexpr): Number of query heads packed per KV head.

    Returns:
        A tensor of pointers suitable for Triton loads:

        * If ``TILE_K > 1``, a 2D tensor of shape ``[TILE_M, TILE_K]`` where each
          row corresponds to a (sequence position, query head) pair and each
          column to a K-element within the vector.
        * If ``TILE_K == 1``, a 1D tensor of shape ``[TILE_M]`` with one pointer
          per (sequence position, query head) pair.
    """
```
```python
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    pack_gqa: bool = False,
```
The PR description claims "Added tests to validate the correctness of packed GQA handling", but no test files are included in the changes, and the checklist shows "Adds or updates tests" is unchecked. The new `pack_gqa` functionality lacks test coverage. Consider adding tests that verify the correctness of packed GQA handling with various sequence lengths and local/causal settings, as mentioned in the PR description.
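One kernel-independent property such tests could check is the GQA grouping invariant itself: attention computed over all query heads (with KV heads repeated to match) must equal attention computed per KV-head group. A pure-NumPy reference sketch of that check is below; a real test would replace `attn` with a call into the library's forward entry point with `pack_gqa=True`.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attn(q, k, v):
    # Reference GQA attention. q: [Hq, S, D], k/v: [Hkv, S, D], Hq % Hkv == 0.
    rep = q.shape[0] // k.shape[0]
    k = np.repeat(k, rep, axis=0)   # expand KV heads to match query heads
    v = np.repeat(v, rep, axis=0)
    s = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(s) @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 16, 8))   # 4 query heads
k = rng.standard_normal((2, 16, 8))   # 2 KV heads -> 2 query heads per KV head
v = rng.standard_normal((2, 16, 8))

out = attn(q, k, v)
# Each KV-head group, computed in isolation, must match the full result.
for g in range(2):
    ref = attn(q[2 * g:2 * g + 2], k[g:g + 1], v[g:g + 1])
    assert np.allclose(out[2 * g:2 * g + 2], ref)
print("GQA group equivalence holds")
```

Parametrizing such a test over sequence lengths, `cu_seqlens` boundaries, and causal/local window settings would cover the cases this PR touches.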
```python
    m_block,
    head_idx,
    stride_lh,
    stride_lb,
```
For packed GQA in the non-varlen case, `stride_lb` (the batch stride for LSE) is being passed as `stride_seq`, but the correct value should be `1` (the stride of the innermost dimension of LSE). The LSE tensor has shape `(batch, num_heads_q, seqlen)`, so after `offset_batch_Q` returns `lse_base` pointing to `lse[batch_idx, 0, 0]`, accessing `lse[batch_idx, q_head, m_idx]` requires offset `q_head * seqlen + m_idx * 1`. The `make_pack_gqa_ptrs` function expects `stride_seq` to be the stride for the `m_idx` dimension, which is `1` for LSE, not `stride_lb`.
Suggested change:

```diff
-    stride_lb,
+    1,
```
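The offset claim in the comment above can be verified directly from NumPy's stride machinery; the shapes below are illustrative:

```python
import numpy as np

batch, num_heads_q, seqlen = 2, 4, 16
lse = np.arange(batch * num_heads_q * seqlen).reshape(batch, num_heads_q, seqlen)

# Element strides of a contiguous (batch, num_heads_q, seqlen) tensor.
stride_lb, stride_lh, stride_lm = [s // lse.itemsize for s in lse.strides]
assert (stride_lb, stride_lh, stride_lm) == (num_heads_q * seqlen, seqlen, 1)

# After offsetting the base to lse[b, 0, 0], reaching lse[b, q_head, m] takes
# q_head * stride_lh + m * 1 -- the innermost stride is 1, not stride_lb.
b, q_head, m = 1, 3, 5
flat = lse.reshape(-1)
base = b * stride_lb
assert flat[base + q_head * stride_lh + m * 1] == lse[b, q_head, m]
print("innermost LSE stride is", stride_lm)  # innermost LSE stride is 1
```

Passing `stride_lb` where the row stride is expected would instead step a whole batch (`num_heads_q * seqlen` elements) per sequence position, writing LSE values far outside the intended rows.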