
Conversation


@slowlyC slowlyC commented Jan 18, 2026

Summary

When USE_G is enabled in chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64, multiplying b_q by b_g_exp[None, :] promotes the result from b_q's original dtype (e.g., bfloat16) to float32 (b_g's dtype).

This PR adds .to(b_q.dtype) to preserve the original data type after the multiplication, ensuring correct dtype in the subsequent tl.dot operations.
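For context, here is a minimal standalone PyTorch sketch of the promotion behavior the fix targets (illustrative only; the actual code is a Triton kernel, and the names b_q / b_g_exp are borrowed from it; Triton applies an analogous mixed-precision promotion rule):

    import torch

    # Hypothetical stand-in values: in the kernel, b_q is loaded in the model
    # dtype (e.g. bfloat16) while b_g_exp comes from an exp() in float32.
    b_q = torch.randn(64, 16, dtype=torch.bfloat16)
    b_g_exp = torch.rand(16, dtype=torch.float32)

    promoted = b_q * b_g_exp[None, :]             # bf16 * fp32 -> fp32 (implicit upcast)
    print(promoted.dtype)                         # torch.float32

    # The pattern this PR applies: cast the product back to the original dtype
    # so the subsequent matmul (tl.dot in the kernel) sees bf16 inputs again.
    b_q_fixed = (b_q * b_g_exp[None, :]).to(b_q.dtype)
    print(b_q_fixed.dtype)                        # torch.bfloat16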

Changes

  • Add .to(b_q.dtype) after b_q * b_g_exp[None, :] for all K dimension blocks (b_dh1-b_dh4)
  • Remove redundant .to(b_q.dtype) in tl.dot calls since b_q now has the correct dtype

Summary by CodeRabbit

  • Bug Fixes
    • Improved data type consistency in mixed-precision operations to reduce potential computation errors.


…o change to b_g.dtype, and it needs to be converted back to b_q.dtype.

coderabbitai bot commented Jan 18, 2026

Walkthrough

The file fla/ops/common/chunk_delta_h.py was updated to improve dtype handling consistency. After elementwise multiplications with per-block scaling, the result is now explicitly cast to the original b_q dtype. Redundant dtype casts in subsequent dot-product computations were removed, ensuring dtype consistency without implicit upcasting while maintaining the same control flow.

Changes

Cohort / File(s): Dtype Consistency Improvements (fla/ops/common/chunk_delta_h.py)
Summary: Added an explicit dtype cast after the b_g_exp multiplication to maintain the original b_q dtype; removed redundant .to(b_q.dtype) casts in tl.dot operations, relying on consistent dtype handling throughout the computation.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Poem

🐰 A rabbit hops through dtypes with care,
Casting with precision, beyond compare,
No redundant paths, just clean and true,
Mixed-precision woes? Now we're through! ✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title clearly identifies the specific fix (dtype conversion of b_q), the context (the bwd_dhu kernel), and the condition (USE_G enabled), accurately summarizing the main change.



@gemini-code-assist

Summary of Changes

Hello @slowlyC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical data type issue within the chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64 kernel. When the USE_G flag is active, an intermediate multiplication operation was inadvertently promoting the b_q tensor's data type to float32, potentially leading to precision errors in gradient computations. The changes ensure that b_q maintains its intended data type throughout the kernel, thereby preserving numerical stability and correctness for backward pass operations.

Highlights

  • Data Type Preservation: Ensures the b_q tensor retains its original data type (e.g., bfloat16) after multiplication with b_g_exp when USE_G is enabled, preventing unintended float32 promotion.
  • Redundant Cast Removal: Eliminates unnecessary .to(b_q.dtype) calls in subsequent tl.dot operations, as b_q now consistently holds the correct data type.




@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a data type issue in the chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64 kernel where b_q was being unintentionally upcast. The fix, which explicitly casts b_q back to its original dtype, is sound. The related cleanup of removing redundant casts in tl.dot is also a good improvement. I've added one suggestion to make the type casting even more explicit and robust. Overall, a solid fix.

     if USE_G:
         b_dh1 *= bg_last_exp
-        b_q = b_q * b_g_exp[None, :]
+        b_q = (b_q * b_g_exp[None, :]).to(b_q.dtype)

Severity: medium

For improved clarity and robustness, it's better to explicitly use the element type of the source tensor q for casting. This avoids any potential ambiguity about what b_q.dtype refers to during reassignment and makes it clear that the intent is to cast back to the original data type of q.

This suggestion also applies to the similar modifications on lines 414, 428, and 442.

Suggested change:
-        b_q = (b_q * b_g_exp[None, :]).to(b_q.dtype)
+        b_q = (b_q * b_g_exp[None, :]).to(q.dtype.element_ty)
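To make the suggested idiom concrete, here is a minimal hypothetical Triton sketch (not the FLA kernel; scale_kernel and its arguments are made up for illustration) showing that ptr.dtype.element_ty names the element type of the source pointer, so the cast target does not depend on what b_q currently holds:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(q, g, out, N: tl.constexpr):
        offs = tl.arange(0, N)
        b_q = tl.load(q + offs)          # loaded in q's element type (e.g. bf16)
        b_g = tl.load(g + offs)          # float32 gate values
        # bf16 * fp32 promotes to fp32; cast back to the pointer's element type.
        b_q = (b_q * tl.exp(b_g)).to(q.dtype.element_ty)
        tl.store(out + offs, b_q)

    if torch.cuda.is_available():
        q = torch.randn(128, device="cuda", dtype=torch.bfloat16)
        g = torch.rand(128, device="cuda")
        out = torch.empty_like(q)
        scale_kernel[(1,)](q, g, out, N=128)
        print(out.dtype)                 # torch.bfloat16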

@slowlyC slowlyC changed the title Fix b_q dtype in bwd_dhu kernel when USE_G is enabled [Fix] b_q dtype in bwd_dhu kernel when USE_G is enabled Jan 18, 2026
@slowlyC slowlyC changed the title [Fix] b_q dtype in bwd_dhu kernel when USE_G is enabled [Fix][GDN] b_q dtype in bwd_dhu kernel when USE_G is enabled Jan 18, 2026
@slowlyC slowlyC changed the title [Fix][GDN] b_q dtype in bwd_dhu kernel when USE_G is enabled [Fix][GDN] convert b_q dtype in bwd_dhu kernel when USE_G is enabled Jan 18, 2026
@zhiyuan1i zhiyuan1i added and then removed the help wanted (Extra attention is needed) label Jan 18, 2026

zhiyuan1i commented Jan 18, 2026

Thanks for the contribution! @slowlyC
However, I'd like to keep tf32 precision for the dh backward matmul.
The tf32 usage here was intentional: b_q is typically L2-normalized with values in the 0-1 range, and when multiplied by small gate values, bf16 precision becomes insufficient.
I've compared the results and observed that dh0 error increases by ~17% (diff) and ~31% (ratio) with this change. This is particularly concerning because errors in dh accumulate across chunks during the backward pass - each chunk's gradient computation depends on the dh state from subsequent chunks.
This becomes even more critical in context parallelism (CP) for ultra-long sequences, where earlier ranks depend on dh0 from later ranks to compute their gradients. The precision loss would propagate and accumulate across all CP ranks, potentially affecting training stability at scale.
KDA-related ops also use tf32 matmul for the same reason. Unless there's a significant performance bottleneck here, I'd prefer to keep the current precision.
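As a rough standalone illustration of this precision argument (hypothetical numbers, not the benchmark behind the screenshots below): keeping the product of near-unit values and small gate factors in fp32, as the current tf32 path does, loses far less information than rounding it to bf16 first.

    import torch

    torch.manual_seed(0)
    # Hypothetical stand-ins: L2-normalized rows with entries roughly in 0-1,
    # scaled by small gate factors, mimicking b_q * b_g_exp before it feeds tl.dot.
    q = torch.nn.functional.normalize(torch.rand(4096, 64), dim=-1)
    gate = torch.full((64,), 1e-2)

    ref = q.double() * gate.double()                                       # float64 reference
    err_fp32 = ((q * gate).double() - ref).abs().mean().item()             # product kept in fp32
    err_bf16 = ((q * gate).bfloat16().double() - ref).abs().mean().item()  # product rounded to bf16, as after this PR
    print(f"mean abs error vs fp64, fp32 product: {err_fp32:.2e}")
    print(f"mean abs error vs fp64, bf16 product: {err_bf16:.2e}")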
[Attached screenshots: numerical error comparison results]

The test only covers 2048 tokens. CC @yzhangcs

@zhiyuan1i zhiyuan1i added the help wanted (Extra attention is needed) label Jan 18, 2026
@zhiyuan1i zhiyuan1i force-pushed the main branch 3 times, most recently from 2b3db51 to 53dda79 on January 22, 2026 at 07:04
