-
Notifications
You must be signed in to change notification settings - Fork 23
Current scaling: two-stage HIP amax kernel #369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
c15d93b
51fab36
ae35e4c
77a68a7
c0d8e73
6c3507d
3c9de07
91249cc
be0e0c8
bce34da
8c388cc
6388604
9e6586f
18292bf
a389455
d87ab8a
fd5dead
16d3bf9
50b34aa
ef532b1
f933ef3
7d4054e
63cff98
c7d44a7
f92b926
eba552e
0d6a177
8eda427
6990928
9ee618f
77b1bc3
1357d4b
01b61b5
ed16f8f
95dcbdf
b07edf6
233eb0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,10 +26,7 @@ using bf16__ = __nv_bfloat16; | |
| using bf16__ = __hip_bfloat16; | ||
| #endif //__HIP_PLATFORM_AMD__ | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| // Defined in include/transformer_engine/recipe.h for AMD | ||
| constexpr int amax_kernel_threads = 512; | ||
wangye805 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| #endif | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
|
|
||
wangye805 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
@@ -125,13 +122,16 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo | |
| auto align = CheckAlignment(N, nvec, input); | ||
| size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); | ||
|
|
||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| // Figure out CUDA blocks | ||
| constexpr size_t threads = amax_kernel_threads; | ||
| size_t num_blocks = DIVUP(num_aligned_elements, threads); | ||
| constexpr size_t max_blocks = 65535; | ||
| num_blocks = std::min(num_blocks, max_blocks); | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| #else | ||
| constexpr size_t threads = amax_kernel_threads; | ||
| size_t num_blocks = nvte_amax_workspace_size(num_aligned_elements); | ||
| if (block_capacity < num_blocks) | ||
| block_amax = nullptr; | ||
| #endif | ||
|
|
@@ -186,6 +186,19 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo | |
| } // namespace | ||
| } // namespace transformer_engine | ||
|
|
||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
|
|
||
| size_t nvte_amax_workspace_size(size_t N) { | ||
|
||
| constexpr size_t max_blocks_hw = 65535; | ||
|
|
||
| size_t max_blocks = transformer_engine::DIVUP(N, static_cast<size_t>(amax_kernel_threads)); | ||
| size_t workspace_blocks = std::min(max_blocks, max_blocks_hw); | ||
| return workspace_blocks; | ||
| } | ||
|
|
||
| #endif | ||
|
|
||
| void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.