Skip to content

Refactor qheads_per_kvhead calculations for clarity#234

Merged
LoserCheems merged 1 commit intomainfrom
optim-triton-version
Mar 8, 2026
Merged

Refactor qheads_per_kvhead calculations for clarity#234
LoserCheems merged 1 commit intomainfrom
optim-triton-version

Conversation

@LoserCheems
Copy link
Copy Markdown
Collaborator

Summary

  • Improve clarity and consistency in the calculation of qheads_per_kvhead in forward functions.

Root Cause

  • The previous implementation had redundant calculations that could lead to confusion.

Changes

  • Refactored the calculation of qheads_per_kvhead and introduced a separate variable for the packed case.

Reproduction

  • No specific bug to reproduce; this is a code clarity improvement.

Tests

  • Existing tests validate the functionality; no new tests required.

Compatibility

  • No backward compatibility issues.

Checklist

Copilot AI review requested due to automatic review settings March 8, 2026 09:07
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors the qheads_per_kvhead variable calculations in the four forward dispatch functions within flash_fwd.py to improve clarity. Previously, a single variable qheads_per_kvhead was conditionally set based on pack_gqa and then also re-computed inline where the unconditional value was needed. The refactoring splits this into two clearly named variables.

Changes:

  • Split the single qheads_per_kvhead variable into two: qheads_per_kvhead (always the true ratio num_heads_q // num_heads_kv) and qheads_per_kvhead_packgqa (conditionally 1 when pack_gqa is False).
  • Replaced the inline num_heads_q // num_heads_kv expression at each kernel call site with the pre-computed qheads_per_kvhead variable.
  • Applied the same refactoring consistently across all four forward functions: _flash_attn_base_forward, _flash_attn_varlen_base_forward, _flash_attn_sm90_forward, and _flash_attn_varlen_sm90_forward.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@LoserCheems LoserCheems merged commit cbccfab into main Mar 8, 2026
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants