Conversation

@Micky774 (Contributor) commented Sep 5, 2025

Description

This PR disentangles the backend Triton implementation from the front-end API, creating a unified intermediate, te_norm_fwd_triton, which serves as a generalized dispatch function. The PR is fully backwards compatible: te_rmsnorm_fwd_triton and te_layernorm_fwd_triton are preserved and implemented as thin wrappers around te_norm_fwd_triton.
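For illustration, here is a minimal sketch of the wrapper layout described above. The signatures are assumptions (the real functions take TE-specific arguments such as quantizers and output buffers), and the plain-torch math only stands in for the actual Triton kernel launch:

```python
import torch

def te_norm_fwd_triton(inp, weight, eps, bias=None, is_rmsnorm=False):
    # Single dispatch point: argument handling, output allocation, and the
    # kernel launch live here once for both norms. The torch ops below are
    # only a stand-in for the Triton kernel.
    x = inp.float()
    if is_rmsnorm:
        rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        out = x * rstd * weight
    else:
        mean = x.mean(-1, keepdim=True)
        rstd = torch.rsqrt(x.var(-1, unbiased=False, keepdim=True) + eps)
        out = (x - mean) * rstd * weight + bias
    return out.to(inp.dtype)

# Backwards-compatible thin wrappers: a bug fixed in te_norm_fwd_triton is
# fixed for both norms at once.
def te_rmsnorm_fwd_triton(inp, weight, eps, **kwargs):
    return te_norm_fwd_triton(inp, weight, eps, is_rmsnorm=True, **kwargs)

def te_layernorm_fwd_triton(inp, weight, bias, eps, **kwargs):
    return te_norm_fwd_triton(inp, weight, eps, bias=bias, is_rmsnorm=False, **kwargs)
```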

This way, when bugs appear, we fix them once rather than duplicating the fix across both norms.

Consequently, there are some changes to the imports to accommodate this restructuring. This PR also includes a minor cleanup/simplification of previously redundant behavior in the layernorm fwd implementation, as well as support for Float8CurrentScalingQuantizer.

FWIW I don't think we can apply a similar unification to the backward passes, as it seems that -- at least for layernorm -- the backward implementations are quite specialized and use asymmetric heuristics.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added a unified dispatch function te_norm_fwd_triton in a new triton_kernels/norms.py, with te_rmsnorm_fwd_triton and te_layernorm_fwd_triton preserved as thin wrappers
  • Updated imports across modules to reference the consolidated location
  • Simplified previously redundant behavior in the layernorm fwd implementation
  • Added support for Float8CurrentScalingQuantizer

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Copilot AI left a comment

Pull Request Overview

This PR refactors the Triton normalization (RMSNorm and LayerNorm) implementations by creating a unified dispatch mechanism. It introduces a new te_norm_fwd_triton function that serves as a generalized entry point for both norm types, while preserving backward compatibility by maintaining the existing API functions as thin wrappers.

Key changes include:

  • Created a unified te_norm_fwd_triton dispatch function in a new norms.py file
  • Modified kernel signatures to support both RMSNorm and LayerNorm use cases (see the sketch after this list)
  • Updated imports across multiple modules to reference the new consolidated location
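A hedged sketch (not the PR's actual kernel or signature) of the kind of signature unification described above: a single Triton forward kernel parameterized by a compile-time IS_RMSNORM flag, so the mean and bias work is compiled out on the RMSNorm path:

```python
import triton
import triton.language as tl

@triton.jit
def _norm_fwd_kernel(X, Y, W, B, Mu, Rstd,  # B and Mu may be dummy pointers on the RMSNorm path
                     stride, N, eps,
                     IS_RMSNORM: tl.constexpr, BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N  # guards loads and stores against running past the end of the row
    x = tl.load(X + row * stride + cols, mask=mask, other=0.0).to(tl.float32)

    if IS_RMSNORM:
        mu = 0.0
    else:
        mu = tl.sum(x, axis=0) / N
        tl.store(Mu + row, mu)

    xc = tl.where(mask, x - mu, 0.0)
    rstd = 1.0 / tl.sqrt(tl.sum(xc * xc, axis=0) / N + eps)
    tl.store(Rstd + row, rstd)

    w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
    if IS_RMSNORM:
        y = xc * rstd * w
    else:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        y = xc * rstd * w + b
    tl.store(Y + row * stride + cols, y, mask=mask)
```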

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

Summary per file:

  • transformer_engine/pytorch/triton_kernels/norms.py: new file containing unified norm dispatch logic and relocated function implementations
  • transformer_engine/pytorch/triton_kernels/rmsnorm.py: removed te_rmsnorm_fwd_triton and updated the kernel signature for unification
  • transformer_engine/pytorch/triton_kernels/layernorm.py: removed forward/backward functions and simplified the reduction kernel signature
  • transformer_engine/pytorch/ops/basic/rmsnorm.py: updated import to reference the new norms module
  • transformer_engine/pytorch/ops/basic/layer_norm.py: updated import to reference the new norms module
  • transformer_engine/pytorch/module/layernorm_mlp.py: consolidated imports from the new norms module
  • transformer_engine/pytorch/module/layernorm_linear.py: consolidated imports from the new norms module
  • transformer_engine/pytorch/module/_common.py: consolidated imports from the new norms module
  • tests/pytorch/triton_kernels/test_norms.py: updated imports to reference the new norms module


@Micky774 Micky774 marked this pull request as draft September 5, 2025 22:35
@Micky774 Micky774 marked this pull request as ready for review September 5, 2025 23:09
@Micky774 (Contributor, Author) commented Sep 5, 2025

Note that some of the layernorm tests are currently failing, citing NaN values in the expected tensor (i.e. the HIP reference kernel). I tried including @alextmagro's PR #303, but it still fails. @alextmagro, have you seen an error like this as well? Is it something related?

@alextmagro (Contributor) commented:

> Note that some of the layernorm tests are currently failing, citing NaN values in the expected tensor (i.e. the HIP reference kernel). I tried including @alextmagro's PR #303, but it still fails. @alextmagro, have you seen an error like this as well? Is it something related?

I haven't seen anything like that.

@@ -0,0 +1,348 @@
import torch
Collaborator commented on this line:

Copyright

@Micky774 (Contributor, Author) commented Sep 8, 2025

Turns out something in this PR gives the layernorm kernel bad memory behavior. Specifically, it mutates either the weight tensor or the bias tensor in the test. I believe this happens because they are allocated contiguously on the GPU with respect to each other (i.e. first the input array, then gamma, then bias), which leads me to suspect there is some kind of masking problem in the layernorm kernel, but I have not been able to pinpoint it yet.

Everything seems to work on dev, but I don't have any functional changes aside from a logical simplification of non-atomic layernorm fwd cases. Most of it is just variable renaming.
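An illustrative sketch (not the project's actual test code) of one way to pin down which tensor is being clobbered: snapshot gamma and bias before the launch and compare afterwards. With the tensors allocated back-to-back, an out-of-bounds write caused by a bad mask shows up as a diff in the neighbouring tensor.

```python
import torch

def assert_params_untouched(norm_fwd, x, gamma, beta, eps=1e-5):
    # Clone the parameters before launching the kernel under test.
    gamma_ref, beta_ref = gamma.clone(), beta.clone()
    norm_fwd(x, gamma, beta, eps)
    # Exact comparison: the forward pass must never write into its parameters.
    torch.testing.assert_close(gamma, gamma_ref, rtol=0, atol=0,
                               msg="kernel wrote into gamma")
    torch.testing.assert_close(beta, beta_ref, rtol=0, atol=0,
                               msg="kernel wrote into beta")
```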
