Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
c15d93b
Current scaling: two-stage amax kernel
matthiasdiener Nov 12, 2025
51fab36
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 13, 2025
ae35e4c
bugfix graph capture
matthiasdiener Nov 13, 2025
77a68a7
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 17, 2025
c0d8e73
outline workspace allocation
matthiasdiener Nov 17, 2025
6c3507d
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 18, 2025
3c9de07
Proper allocation of workspace
matthiasdiener Nov 18, 2025
91249cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
be0e0c8
add a test to compare the accuracy of both amax implementations
matthiasdiener Nov 19, 2025
bce34da
add possibility to force using previous (atomic) kernel
matthiasdiener Nov 19, 2025
8c388cc
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 19, 2025
6388604
add copyrights
matthiasdiener Nov 20, 2025
9e6586f
don't add extra template to kernel
matthiasdiener Nov 20, 2025
18292bf
make amax_kernel_threads usable in pytorch
matthiasdiener Nov 21, 2025
a389455
update remaining calls to nvte_compute_amax
matthiasdiener Nov 21, 2025
d87ab8a
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 24, 2025
fd5dead
additional copyrights
matthiasdiener Nov 24, 2025
16d3bf9
avoid workspace allocations if NVTE_USE_ATOMIC_AMAX is set
matthiasdiener Nov 24, 2025
50b34aa
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 25, 2025
ef532b1
remove use_block_amax parameter, more cleanups
matthiasdiener Nov 25, 2025
f933ef3
Factor workspace allocation into function
matthiasdiener Nov 25, 2025
7d4054e
expand test slightly
matthiasdiener Nov 25, 2025
63cff98
Revert "expand test slightly"
Nov 25, 2025
c7d44a7
guard by HIP macro, address review comments
matthiasdiener Nov 26, 2025
f92b926
bugfix workspace.data.dptr
matthiasdiener Nov 26, 2025
eba552e
various cleanups
matthiasdiener Nov 26, 2025
0d6a177
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Nov 26, 2025
8eda427
simplify types in allocate_amax_workspace
matthiasdiener Nov 26, 2025
6990928
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Dec 1, 2025
9ee618f
fix indentation
matthiasdiener Dec 1, 2025
77b1bc3
Merge branch 'dev' into speedup-amax-kernel
matthiasdiener Dec 1, 2025
1357d4b
Use private implementation of DIVUP
matthiasdiener Dec 2, 2025
01b61b5
define amax_kernel_threads on non-AMD
matthiasdiener Dec 2, 2025
ed16f8f
Revert "Use private implementation of DIVUP"
matthiasdiener Dec 2, 2025
95dcbdf
Factor out workspace size calculation
matthiasdiener Dec 2, 2025
b07edf6
change name
matthiasdiener Dec 2, 2025
233eb0a
add copyright
matthiasdiener Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t s

#ifdef __HIP_PLATFORM_AMD__

size_t nvte_amax_workspace_size(size_t N);

/*! \brief Compute an FP8 tensor's amax.
*
* The amax (maximum absolute value) of the input tensor is computed
Expand Down
21 changes: 17 additions & 4 deletions transformer_engine/common/recipe/current_scaling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ using bf16__ = __nv_bfloat16;
using bf16__ = __hip_bfloat16;
#endif //__HIP_PLATFORM_AMD__

#ifndef __HIP_PLATFORM_AMD__
// Defined in include/transformer_engine/recipe.h for AMD
constexpr int amax_kernel_threads = 512;
#endif

#ifdef __HIP_PLATFORM_AMD__

Expand Down Expand Up @@ -125,13 +122,16 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
auto align = CheckAlignment(N, nvec, input);
size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));

#ifndef __HIP_PLATFORM_AMD__
// Figure out CUDA blocks
constexpr size_t threads = amax_kernel_threads;
size_t num_blocks = DIVUP(num_aligned_elements, threads);
constexpr size_t max_blocks = 65535;
num_blocks = std::min(num_blocks, max_blocks);

#ifdef __HIP_PLATFORM_AMD__
#else
constexpr size_t threads = amax_kernel_threads;
size_t num_blocks = nvte_amax_workspace_size(num_aligned_elements);
if (block_capacity < num_blocks)
block_amax = nullptr;
#endif
Expand Down Expand Up @@ -186,6 +186,19 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo
} // namespace
} // namespace transformer_engine


#ifdef __HIP_PLATFORM_AMD__

size_t nvte_amax_workspace_size(size_t N) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ambiguous workspace_size - in TE it is usually byte size but here number of float32 elements is returned.
It should either return bytes and cast to float only when launch kernels, or method should be renamed to indicate it is float elements number.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I changed it to nvte_amax_workspace_num_blocks in b07edf6.

constexpr size_t max_blocks_hw = 65535;

size_t max_blocks = transformer_engine::DIVUP(N, static_cast<size_t>(amax_kernel_threads));
size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
return workspace_blocks;
}

#endif

void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
Expand Down
5 changes: 1 addition & 4 deletions transformer_engine/pytorch/csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,7 @@ TensorWrapper allocate_amax_workspace(const TensorWrapper& input_tensor) {
}

const auto N = input_tensor.numel();
constexpr size_t max_blocks_hw = 65535;

size_t max_blocks = DIVUP(N, static_cast<size_t>(amax_kernel_threads));
size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
size_t workspace_blocks = nvte_amax_workspace_size(N);

at::Tensor ws = at::empty(workspace_blocks, at::CUDA(at::kFloat));

Expand Down
Loading