Skip to content

Conversation

@matthiasdiener
Copy link
Contributor

@matthiasdiener matthiasdiener commented Nov 26, 2025

Description

The corresponding 2-stage HIP kernel amax implementation is in #369.

Partially addresses https://github.com/ROCm/frameworks-internal/issues/14303.

See https://github.com/ROCm/frameworks-internal/issues/14303#issuecomment-3554900809 for a performance analysis.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added a 2-stage Triton amax implementation, optional, but enabled by default
    • Disable by setting export NVTE_USE_ATOMIC_AMAX=1 (this will use the previous atomic implementation)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

matthiasdiener and others added 30 commits November 12, 2025 14:10
This reverts commit 7d4054e.
@matthiasdiener matthiasdiener changed the base branch from speedup-amax-kernel to dev November 26, 2025 22:53
@matthiasdiener matthiasdiener changed the title Speedup amax triton [WIP] Speedup amax triton Nov 27, 2025
@matthiasdiener matthiasdiener changed the title [WIP] Speedup amax triton Current scaling: two-stage Triton amax kernel Dec 1, 2025
@matthiasdiener matthiasdiener marked this pull request as ready for review December 1, 2025 20:00
NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The triton path is not enabled by default so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.

Also not sure about the runtime cost of adding two new pytests in level 3

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The triton path is not enabled by default so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.

I added both cases in d7259d1.

Also not sure about the runtime cost of adding two new pytests in level 3

test_numerics.py takes about 5 min, test_fusible_ops.py takes about 1 min (on gfx942), times 2 since we run it with NVTE_USE_ATOMIC_AMAX=0 and =1. Perhaps adding just the test in 188b7ca is enough?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 mins sounds okay for level 3. @ipanfilo , what do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After discussing with @wenchenvincent, we concluded that it is worth keeping the extra tests around.

NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The triton path is not enabled by default so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.

I added both cases in d7259d1.

Also not sure about the runtime cost of adding two new pytests in level 3

test_numerics.py takes about 5 min, test_fusible_ops.py takes about 1 min (on gfx942), times 2 since we run it with NVTE_USE_ATOMIC_AMAX=0 and =1. Perhaps adding just the test in 188b7ca is enough?

Quantizes the input tensor using a specified quantizer,
with an option to utilize Triton-based `cast_transpose` for performance.
"""
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wangye805 Do you remember why current scaling was disabled here (in #374)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recall I moved this line to

from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
instead of disabling curent scaling quantizer entirely, in order to resolve a circular inclusion issue

@matthiasdiener matthiasdiener merged commit bdd6c63 into dev Dec 20, 2025
5 of 6 checks passed
@matthiasdiener matthiasdiener deleted the speedup-amax-triton branch December 20, 2025 09:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants