Current scaling: two-stage Triton amax kernel #385
Merged
Commits (49 total; changes from all commits shown):
c15d93b  Current scaling: two-stage amax kernel (matthiasdiener)
51fab36  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
ae35e4c  bugfix graph capture (matthiasdiener)
77a68a7  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
c0d8e73  outline workspace allocation (matthiasdiener)
6c3507d  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
3c9de07  Proper allocation of workspace (matthiasdiener)
91249cc  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
be0e0c8  add a test to compare the accuracy of both amax implementations (matthiasdiener)
bce34da  add possibility to force using previous (atomic) kernel (matthiasdiener)
8c388cc  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
73c8d4e  2-stage Triton amax (matthiasdiener)
6388604  add copyrights (matthiasdiener)
9e6586f  don't add extra template to kernel (matthiasdiener)
18292bf  make amax_kernel_threads usable in pytorch (matthiasdiener)
a389455  update remaining calls to nvte_compute_amax (matthiasdiener)
d87ab8a  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
7d9ee16  Merge branch 'speedup-amax-kernel' into speedup-amax-triton (matthiasdiener)
fd5dead  additional copyrights (matthiasdiener)
16d3bf9  avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set (matthiasdiener)
50b34aa  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
ef532b1  remove use_block_amax parameter, more cleanups (matthiasdiener)
f933ef3  Factor workspace allocation into function (matthiasdiener)
7d4054e  expand test slightly (matthiasdiener)
63cff98  Revert "expand test slightly"
c7d44a7  guard by HIP macro, address review comments (matthiasdiener)
f92b926  bugfix workspace.data.dptr (matthiasdiener)
eba552e  various cleanups (matthiasdiener)
0d6a177  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
19901a0  Merge branch 'speedup-amax-kernel' into speedup-amax-triton (matthiasdiener)
8eda427  simplify types in allocate_amax_workspace (matthiasdiener)
be6496b  Merge branch 'speedup-amax-kernel' into speedup-amax-triton (matthiasdiener)
ed1a54b  Fixes (matthiasdiener)
c8d5bb4  add support for NVTE_USE_ATOMIC_AMAX (matthiasdiener)
5a9086a  Fuse amax_reduce + compute_scale kernels (matthiasdiener)
6990928  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
9ee618f  fix indentation (matthiasdiener)
853bb77  Merge branch 'speedup-amax-kernel' into speedup-amax-triton (matthiasdiener)
cf402b1  undo non-triton changes (matthiasdiener)
2c9cc65  [ROCm] use at::empty(0, fp32) as amax workspace for makeTransformerEn… (wangye805)
e41e1d4  Merge branch 'dev' into speedup-amax-triton (matthiasdiener)
862ec74  Merge branch 'yewang12/amax-workspace-fix' into speedup-amax-triton (matthiasdiener)
35f2d38  add more tests (matthiasdiener)
1cbb68f  Merge branch 'dev' into speedup-amax-triton (matthiasdiener)
d7259d1  add more tests and re-add comment (matthiasdiener)
42c7ac3  Merge branch 'dev' into speedup-amax-triton (matthiasdiener)
ef31ef7  Merge branch 'dev' into speedup-amax-triton (matthiasdiener)
25c91e8  restore FP8 current scaling support (matthiasdiener)
188b7ca  add test comparing atomic amax and 2-stage (matthiasdiener)
@@ -56,9 +56,6 @@ def te_quantize_triton(
     Quantizes the input tensor using a specified quantizer,
     with an option to utilize Triton-based `cast_transpose` for performance.
     """
-    from ..tensor.float8_tensor import Float8CurrentScalingQuantizer

matthiasdiener (Contributor, Author): @wangye805 Do you remember why current scaling was disabled here (in #374)?

Collaborator: I recall I moved this line to …

-    if isinstance(quantizer, Float8CurrentScalingQuantizer):
-        return tex.quantize(tensor, quantizer, output, noop_flag)
     input_tensor = tensor.contiguous()
     fake_tensor_type = input_tensor.dtype
     if not fake_tensor_type.is_floating_point:
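The diff above removes the early fallback to `tex.quantize` for `Float8CurrentScalingQuantizer`, whose per-tensor scale requires an amax over the whole input. For context, the two-stage approach named in the PR title can be sketched roughly as follows. This is a hypothetical illustration, not the PR's actual kernels: the names `partial_amax_kernel`, `final_amax_kernel`, and `two_stage_amax`, the block size, and the workspace layout are all assumptions. Stage one writes one partial amax per block into a workspace tensor; stage two reduces that workspace to a single scalar, avoiding the global atomic max used by the previous (atomic) kernel.

```python
# Hypothetical sketch of a two-stage amax reduction in Triton.
# Kernel names, block size, and workspace layout are illustrative assumptions,
# not the PR's actual implementation.
import torch
import triton
import triton.language as tl


@triton.jit
def partial_amax_kernel(x_ptr, workspace_ptr, n_elements, BLOCK: tl.constexpr):
    # Stage 1: each program computes the amax of its block and writes a single
    # partial result to the workspace (no global atomics needed).
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n_elements, other=0.0)
    tl.store(workspace_ptr + pid, tl.max(tl.abs(x.to(tl.float32)), axis=0))


@triton.jit
def final_amax_kernel(workspace_ptr, amax_ptr, n_partials, BLOCK: tl.constexpr):
    # Stage 2: one program reduces all partial results to a single scalar.
    # Assumes the number of partials fits in one block.
    offs = tl.arange(0, BLOCK)
    partials = tl.load(workspace_ptr + offs, mask=offs < n_partials, other=0.0)
    tl.store(amax_ptr, tl.max(partials, axis=0))


def two_stage_amax(x: torch.Tensor) -> torch.Tensor:
    # Allocate a one-entry-per-block workspace, then run both stages.
    BLOCK = 1024
    n = x.numel()
    n_blocks = triton.cdiv(n, BLOCK)
    workspace = torch.empty(n_blocks, dtype=torch.float32, device=x.device)
    amax = torch.empty(1, dtype=torch.float32, device=x.device)
    partial_amax_kernel[(n_blocks,)](x, workspace, n, BLOCK=BLOCK)
    final_amax_kernel[(1,)](workspace, amax, n_blocks,
                            BLOCK=triton.next_power_of_2(n_blocks))
    return amax
```

Compared with a single atomic kernel, the second stage costs one extra small launch but removes contention on a single global amax location; the workspace buffer it needs is what commits such as "Factor workspace allocation into function" and "avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set" manage.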
Reviewer: The Triton path is not enabled by default, so I think you will need to test with both NVTE_USE_ATOMIC_AMAX=1 and NVTE_USE_ATOMIC_AMAX=0 when NVTE_USE_CAST_TRANSPOSE_TRITON is 1. I'm also not sure about the runtime cost of adding two new pytests in level 3.
matthiasdiener (Author): I added both cases in d7259d1.
test_numerics.py takes about 5 min and test_fusible_ops.py takes about 1 min (on gfx942), times 2 since we run them with NVTE_USE_ATOMIC_AMAX=0 and =1. Perhaps adding just the test in 188b7ca is enough?
Reviewer: 5 minutes sounds okay for level 3. @ipanfilo, what do you think?
Reviewer: After discussing with @wenchenvincent, we concluded that it is worth keeping the extra tests around.
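To illustrate the kind of parametrized check being discussed, a hypothetical sketch follows. It reuses the `two_stage_amax` sketch above and compares it against a plain PyTorch reference; the PR's real tests in test_numerics.py and test_fusible_ops.py instead exercise the library's actual kernels under NVTE_USE_ATOMIC_AMAX=0 and =1.

```python
# Hypothetical sketch of an amax accuracy check parametrized over
# NVTE_USE_ATOMIC_AMAX, mirroring the review discussion; the real tests
# exercise TransformerEngine's own kernels rather than this standalone sketch.
import os
import pytest
import torch


@pytest.mark.parametrize("use_atomic", ["0", "1"])
def test_amax_matches_reference(use_atomic):
    # In the library, NVTE_USE_ATOMIC_AMAX selects the previous atomic kernel
    # ("1") or the two-stage kernel ("0"); here it only demonstrates the
    # parametrization, since this sketch always calls two_stage_amax.
    os.environ["NVTE_USE_ATOMIC_AMAX"] = use_atomic
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    ref = x.abs().amax().to(torch.float32)
    out = two_stage_amax(x)  # sketch from the earlier block
    torch.testing.assert_close(out.squeeze(), ref)
```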