Current scaling: two-stage Triton amax kernel #385
Conversation
This reverts commit 7d4054e.
Force-pushed from 43cf8ab to cf402b1.
NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
The Triton path is not enabled by default, so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.
Also, I am not sure about the runtime cost of adding two new pytests in level 3.
> The Triton path is not enabled by default, so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1.

I added both cases in d7259d1.

> Also, I am not sure about the runtime cost of adding two new pytests in level 3.

test_numerics.py takes about 5 min and test_fusible_ops.py about 1 min (on gfx942), times two since we run them with both NVTE_USE_ATOMIC_AMAX=0 and =1. Perhaps adding just the test in 188b7ca is enough?
5 mins sounds okay for level 3. @ipanfilo, what do you think?
After discussing with @wenchenvincent, we concluded that it is worth keeping the extra tests around.
Quantizes the input tensor using a specified quantizer,
with an option to utilize Triton-based `cast_transpose` for performance.
"""
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
@wangye805 Do you remember why current scaling was disabled here (in #374)?
I recall I moved this line to:
from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
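For context, the docstring quoted above describes an opt-in Triton `cast_transpose` path. As a minimal sketch only (the helper name `_use_triton_cast_transpose` and its logic are assumptions for illustration, not the code touched by this PR or #374; the absolute import assumes the relative import above resolves to `transformer_engine.pytorch.tensor.float8_tensor`), the gating could look like this:

```python
import os

# Assumption: the relative import shown in the diff corresponds to this module path.
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
)


def _use_triton_cast_transpose(quantizer) -> bool:
    # Hypothetical gate: the Triton cast_transpose path stays opt-in via
    # NVTE_USE_CAST_TRANSPOSE_TRITON, and the quantizer type check mirrors
    # the Float8CurrentScalingQuantizer import added in this diff.
    if os.getenv("NVTE_USE_CAST_TRANSPOSE_TRITON", "0") != "1":
        return False
    return isinstance(quantizer, Float8CurrentScalingQuantizer)
```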
Description
The corresponding two-stage HIP amax kernel implementation is in #369.
Partially addresses https://github.com/ROCm/frameworks-internal/issues/14303.
See https://github.com/ROCm/frameworks-internal/issues/14303#issuecomment-3554900809 for a performance analysis.
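To make the two-stage scheme concrete, here is a minimal, self-contained Triton sketch of the idea, not the PR's kernel: all names (`_partial_amax_kernel`, `_reduce_amax_kernel`, `two_stage_amax`) are illustrative. Stage one computes a per-block partial amax; stage two folds the partials in a single program, so no global atomic max is needed (the atomic path remains available via NVTE_USE_ATOMIC_AMAX=1, see Changes below).

```python
# Minimal sketch of a two-stage amax reduction, assuming a contiguous input.
import torch
import triton
import triton.language as tl


@triton.jit
def _partial_amax_kernel(x_ptr, partial_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Stage 1: each program reduces one block and writes its partial max(|x|).
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs, mask=offs < n_elements, other=0.0)
    tl.store(partial_ptr + pid, tl.max(tl.abs(x), axis=0))


@triton.jit
def _reduce_amax_kernel(partial_ptr, amax_ptr, n_partials, BLOCK_SIZE: tl.constexpr):
    # Stage 2: a single program folds all per-block partials into one amax.
    acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
    for start in range(0, n_partials, BLOCK_SIZE):
        offs = start + tl.arange(0, BLOCK_SIZE)
        p = tl.load(partial_ptr + offs, mask=offs < n_partials, other=0.0)
        acc = tl.maximum(acc, p)
    tl.store(amax_ptr, tl.max(acc, axis=0))


def two_stage_amax(x: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    """Return max(|x|) as a 1-element float32 tensor on x's device."""
    x = x.contiguous()
    n = x.numel()
    n_blocks = triton.cdiv(n, block_size)
    partial = torch.empty(n_blocks, dtype=torch.float32, device=x.device)
    amax = torch.empty(1, dtype=torch.float32, device=x.device)
    _partial_amax_kernel[(n_blocks,)](x, partial, n, BLOCK_SIZE=block_size)
    _reduce_amax_kernel[(1,)](partial, amax, n_blocks, BLOCK_SIZE=block_size)
    return amax
```

With current scaling, the resulting amax would then feed the FP8 scale computation before the cast/transpose writes the quantized output.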
Type of change
Changes
export NVTE_USE_ATOMIC_AMAX=1 (this will use the previous atomic implementation)

Checklist: