Current scaling: two-stage HIP amax kernel #369
Merged
Commits (37)
- c15d93b  Current scaling: two-stage amax kernel (matthiasdiener)
- 51fab36  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- ae35e4c  bugfix graph capture (matthiasdiener)
- 77a68a7  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- c0d8e73  outline workspace allocation (matthiasdiener)
- 6c3507d  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- 3c9de07  Proper allocation of workspace (matthiasdiener)
- 91249cc  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- be0e0c8  add a test to compare the accuracy of both amax implementations (matthiasdiener)
- bce34da  add possibility to force using previous (atomic) kernel (matthiasdiener)
- 8c388cc  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- 6388604  add copyrights (matthiasdiener)
- 9e6586f  don't add extra template to kernel (matthiasdiener)
- 18292bf  make amax_kernel_threads usable in pytorch (matthiasdiener)
- a389455  update remaining calls to nvte_compute_amax (matthiasdiener)
- d87ab8a  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- fd5dead  additional copyrights (matthiasdiener)
- 16d3bf9  avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set (matthiasdiener)
- 50b34aa  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- ef532b1  remove use_block_amax parameter, more cleanups (matthiasdiener)
- f933ef3  Factor workspace allocation into function (matthiasdiener)
- 7d4054e  expand test slightly (matthiasdiener)
- 63cff98  Revert "expand test slightly"
- c7d44a7  guard by HIP macro, address review comments (matthiasdiener)
- f92b926  bugfix workspace.data.dptr (matthiasdiener)
- eba552e  various cleanups (matthiasdiener)
- 0d6a177  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- 8eda427  simplify types in allocate_amax_workspace (matthiasdiener)
- 6990928  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- 9ee618f  fix indentation (matthiasdiener)
- 77b1bc3  Merge branch 'dev' into speedup-amax-kernel (matthiasdiener)
- 1357d4b  Use private implementation of DIVUP (matthiasdiener)
- 01b61b5  define amax_kernel_threads on non-AMD (matthiasdiener)
- ed16f8f  Revert "Use private implementation of DIVUP" (matthiasdiener)
- 95dcbdf  Factor out workspace size calculation (matthiasdiener)
- b07edf6  change name (matthiasdiener)
- 233eb0a  add copyright (matthiasdiener)
```diff
@@ -26,11 +26,29 @@ using bf16__ = __nv_bfloat16;
 using bf16__ = __hip_bfloat16;
 #endif //__HIP_PLATFORM_AMD__
 
 constexpr int amax_kernel_threads = 512;
 
+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float* __restrict__ block_amax,
+                                  float* __restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
+
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
-    void amax_kernel(const InputType *input, float *amax, const size_t N,
+    void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N,
                      const size_t num_aligned_elements) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
```
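`amax_final_reduce` is the new second stage: a single block strides over the per-block partial maxima and collapses them into the global amax, replacing global atomic traffic with one small extra kernel launch. End to end, both paths must agree with a plain reduction; a minimal CPU reference, assuming the absolute value is taken inside the elided vectorized loop of `amax_kernel`:

```cpp
// Hypothetical CPU reference (not part of the PR): the atomic path and the
// two-stage path must both agree with this direct reduction over |x|.
#include <cmath>
#include <cstddef>

float amax_reference(const float *input, std::size_t N) {
  float amax = 0.f;
  for (std::size_t i = 0; i < N; ++i) {
    amax = std::fmax(amax, std::fabs(input[i]));  // running max of |x|
  }
  return amax;
}
```

This is the property the PR's new accuracy test (commit be0e0c8) presumably checks for both implementations.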
```diff
@@ -65,12 +83,19 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    atomicMaxFloat(amax, max);
+    if (block_amax != nullptr) {
+      // 2-stage: write per-block result
+      block_amax[blockIdx.x] = max;
+    } else {
+      // Atomic path: directly update global amax
+      atomicMaxFloat(amax, max);
+    }
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
```
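When `block_amax` is null, the kernel keeps the old single-pass behavior, where thread 0 of every block contends on one global address through `atomicMaxFloat`. That helper is defined outside the shown hunks; a common construction for such a float atomic max, shown here only as a hedged sketch, is an integer `atomicMax` on the float's bit pattern, which is order-preserving for the non-negative values amax takes:

```cpp
// Sketch only: TE's actual atomicMaxFloat is not part of this diff.
// For non-negative IEEE-754 floats, comparing the raw bit patterns as
// integers gives the same ordering as comparing the floats themselves.
__device__ inline float atomic_max_float_sketch(float *addr, float value) {
  int old = atomicMax(reinterpret_cast<int *>(addr), __float_as_int(value));
  return __int_as_float(old);
}
```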
```diff
@@ -89,24 +114,38 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
+  const bool UseBlockAmax =
+      (block_amax != nullptr) &&
+      (block_capacity >= num_blocks) &&
+      !nvte_use_atomic_amax();
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
       amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
       break;
     case Alignment::SAME_UNALIGNED:
       amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
       break;
     }
   }
 
+  if (UseBlockAmax) {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
```
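`UseBlockAmax` gates the two-stage path on three conditions: a workspace was supplied, it can hold one float per launched block, and the atomic fallback was not forced via `NVTE_USE_ATOMIC_AMAX`. The required capacity follows directly from the launch math above; a hedged sketch of the sizing rule (the helper name here is assumed, though the commit log refers to a factored-out `allocate_amax_workspace`):

```cpp
// Hypothetical sizing helper mirroring the launch math above: one float
// per first-stage block, with the grid capped at max_blocks = 65535.
#include <algorithm>
#include <cstddef>

std::size_t amax_workspace_floats(std::size_t N, std::size_t nvec,
                                  std::size_t threads = 512) {
  const std::size_t vec_elems = (N + nvec - 1) / nvec;   // vectorized loads
  const std::size_t num_blocks = (vec_elems + threads - 1) / threads;
  return std::min<std::size_t>(num_blocks, 65535);       // grid-size cap
}
```

Because the grid is capped, 65535 floats (about 256 KiB) is always sufficient, which presumably lets callers allocate the workspace once and reuse it, including under graph capture (see commit ae35e4c).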
```diff
@@ -115,6 +154,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
 }  // namespace transformer_engine
 
+void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
+}
+
-void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
```
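The public entry point is preserved: `nvte_compute_amax` now forwards a null workspace, which selects the atomic path, so existing callers are unaffected. A hedged usage sketch of the two entry points; `make_fp32_workspace` is a hypothetical stand-in for however the caller wraps a device buffer in an `NVTETensor` (in this PR the real allocation lives on the PyTorch side):

```cpp
// Usage sketch; make_fp32_workspace() is hypothetical, the two entry
// points are the API touched by this PR.
void compute_amax_example(NVTETensor input, NVTETensor output,
                          cudaStream_t stream) {
  // Old API, still valid: nullptr workspace -> single-pass atomic kernel.
  nvte_compute_amax(input, output, stream);

  // New API: an FP32 workspace with >= num_blocks elements enables the
  // two-stage (per-block write + final reduce) path.
  NVTETensor workspace = make_fp32_workspace(/*numel=*/65535);  // hypothetical
  nvte_compute_amax_with_workspace(input, output, workspace, stream);
}
```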
```diff
@@ -150,11 +193,27 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
                to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
+  // Optional workspace
+  float* block_amax = nullptr;
+  size_t block_capacity = 0;
+
+  if (workspace_ != nullptr) {
+    auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
+    NVTE_CHECK(workspace.data.dptr != nullptr,
+               "Workspace tensor for amax computation has no data");
+    NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
+               "Workspace tensor for amax computation must be FP32, got dtype=",
+               to_string(workspace.data.dtype));
+    block_amax = reinterpret_cast<float*>(workspace.data.dptr);
+    block_capacity = workspace.data.numel();
+  }
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
-                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), block_amax,
+                               block_capacity,
                                stream););  // NOLINT(*)
 }
```
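The `nvte_use_atomic_amax()` gate referenced in the launcher is also defined outside the shown hunks. Based on the `NVTE_USE_ATOMIC_AMAX` variable named in commit 16d3bf9, a plausible reading, offered only as an assumption, is a cached environment-variable check:

```cpp
// Assumed shape of the opt-out gate; the real definition is not in this
// diff. Setting NVTE_USE_ATOMIC_AMAX=1 would force the previous kernel.
#include <cstdlib>

bool nvte_use_atomic_amax_sketch() {
  static const bool use_atomic = [] {
    const char *env = std::getenv("NVTE_USE_ATOMIC_AMAX");
    return env != nullptr && env[0] == '1';
  }();
  return use_atomic;
}
```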