@@ -28,10 +28,38 @@ using bf16__ = __hip_bfloat16;
 
 constexpr int amax_kernel_threads = 512;
 
+#ifdef __HIP_PLATFORM_AMD__
+
+template <int BLOCK_THREADS>
+__global__ void amax_final_reduce(const float * __restrict__ block_amax,
+                                  float * __restrict__ global_amax,
+                                  int num_blocks) {
+  float val = 0.f;
+
+  for (int i = threadIdx.x; i < num_blocks; i += BLOCK_THREADS) {
+    val = fmaxf(val, block_amax[i]);
+  }
+
+  const int warp_id = threadIdx.x / THREADS_PER_WARP;
+  const float block_max =
+      reduce_max<BLOCK_THREADS / THREADS_PER_WARP>(val, warp_id);
+
+  if (threadIdx.x == 0) {
+    *global_amax = block_max;
+  }
+}
+
+#endif
+
 template <int nvec, bool aligned, typename InputType>
 __launch_bounds__(amax_kernel_threads) __global__
+#ifdef __HIP_PLATFORM_AMD__
+    void amax_kernel(const InputType *input, float *amax, float * __restrict__ block_amax, const size_t N,
+                     const size_t num_aligned_elements) {
+#else
     void amax_kernel(const InputType *input, float *amax, const size_t N,
                      const size_t num_aligned_elements) {
+#endif
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
   const int warp_id = threadIdx.x / THREADS_PER_WARP;
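
Note on the hunk above: the HIP path replaces a single atomic update with a two-stage reduction. Stage one has every block write its partial maximum into a per-block workspace slot, and stage two launches amax_final_reduce as one block that folds those partials into the global amax, so no float atomics are needed. Below is a minimal, self-contained sketch of the same pattern, assuming nothing from this file: a plain shared-memory tree reduction stands in for TE's reduce_max and THREADS_PER_WARP helpers, which are defined elsewhere.

// Standalone sketch of the two-stage amax pattern (not TE code).
#include <cmath>
#include <cstdio>
#include <vector>
#include <hip/hip_runtime.h>

constexpr int kThreads = 256;

__device__ float block_tree_max(float v, float *smem) {
  // Shared-memory tree reduction; stands in for TE's warp-based reduce_max.
  smem[threadIdx.x] = v;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    __syncthreads();
  }
  return smem[0];
}

// Stage 1: each block reduces a grid-stride slice and writes one partial max.
__global__ void block_amax_stage(const float *in, float *partials, size_t n) {
  __shared__ float smem[kThreads];
  float v = 0.f;
  for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < n;
       i += (size_t)gridDim.x * blockDim.x)
    v = fmaxf(v, fabsf(in[i]));
  v = block_tree_max(v, smem);
  if (threadIdx.x == 0) partials[blockIdx.x] = v;
}

// Stage 2: a single block folds the partials, so no float atomics are needed.
__global__ void final_amax_stage(const float *partials, float *amax, int num_blocks) {
  __shared__ float smem[kThreads];
  float v = 0.f;
  for (int i = threadIdx.x; i < num_blocks; i += blockDim.x) v = fmaxf(v, partials[i]);
  v = block_tree_max(v, smem);
  if (threadIdx.x == 0) *amax = v;
}

int main() {
  const size_t n = 1 << 20;
  const int blocks = 128;
  std::vector<float> h(n);
  for (size_t i = 0; i < n; ++i) h[i] = std::sin(static_cast<float>(i)) * 3.f;
  float *d_in, *d_partials, *d_amax;
  hipMalloc(&d_in, n * sizeof(float));
  hipMalloc(&d_partials, blocks * sizeof(float));
  hipMalloc(&d_amax, sizeof(float));
  hipMemcpy(d_in, h.data(), n * sizeof(float), hipMemcpyHostToDevice);
  block_amax_stage<<<blocks, kThreads>>>(d_in, d_partials, n);
  final_amax_stage<<<1, kThreads>>>(d_partials, d_amax, blocks);
  float result = 0.f;
  hipMemcpy(&result, d_amax, sizeof(float), hipMemcpyDeviceToHost);
  std::printf("amax = %f\n", result);  // ~3.0 for this input
  return 0;
}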
@@ -65,12 +93,23 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
+#ifdef __HIP_PLATFORM_AMD__
+    if (block_amax != nullptr) {
+      // 2-stage: write per-block result
+      block_amax[blockIdx.x] = max;
+    } else {
+      // Atomic path: directly update global amax
+      atomicMaxFloat(amax, max);
+    }
+#else
     atomicMaxFloat(amax, max);
+#endif
   }
 }
 
 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);
 
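
The hunk above keeps atomicMaxFloat as the CUDA path and as the HIP fallback when no workspace is available. That helper is defined elsewhere in TE; for reference, a common way to emulate an atomic float max for non-negative values (amax is always >= 0, since it starts at 0 and accumulates absolute values) is an integer atomicMax on the bit pattern. A sketch, not necessarily TE's implementation:

// Sketch only: for non-negative IEEE-754 floats, bit patterns order the same
// as signed integers, so an integer atomicMax implements a float max.
__device__ inline float atomicMaxFloatNonNeg(float *addr, float value) {
  int old = atomicMax(reinterpret_cast<int *>(addr),
                      __float_as_int(value));
  return __int_as_float(old);
}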
@@ -83,38 +122,90 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
   auto align = CheckAlignment(N, nvec, input);
   size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));
 
+#ifndef __HIP_PLATFORM_AMD__
   // Figure out CUDA blocks
   constexpr size_t threads = amax_kernel_threads;
   size_t num_blocks = DIVUP(num_aligned_elements, threads);
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);
 
+#else
+  constexpr size_t threads = amax_kernel_threads;
+  size_t num_blocks = nvte_amax_workspace_num_blocks(num_aligned_elements);
+  if (block_capacity < num_blocks)
+    block_amax = nullptr;
+#endif
+
   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<nvec, true, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+#else
       amax_kernel<nvec, true, InputType>
           <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+#endif
       break;
     case Alignment::SAME_UNALIGNED:
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<nvec, false, InputType>
+          <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+#else
       amax_kernel<nvec, false, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
+#endif
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
+#ifdef __HIP_PLATFORM_AMD__
+      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+#else
       amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
+#endif
       break;
     }
   }
 
+#ifdef __HIP_PLATFORM_AMD__
+  if (block_amax != nullptr) {
+    constexpr int FINAL_REDUCE_THREADS = 256;
+    dim3 fr_block(FINAL_REDUCE_THREADS);
+    dim3 fr_grid(1);
+
+    amax_final_reduce<FINAL_REDUCE_THREADS>
+        <<<fr_grid, fr_block, 0, stream>>>(block_amax, amax, static_cast<int>(num_blocks));
+  }
+#endif
+
   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 }  // namespace
 }  // namespace transformer_engine
 
+
+#ifdef __HIP_PLATFORM_AMD__
+
+size_t nvte_amax_workspace_num_blocks(size_t N) {
+  constexpr size_t max_blocks_hw = 65535;
+
+  size_t max_blocks = transformer_engine::DIVUP(N, static_cast<size_t>(amax_kernel_threads));
+  size_t workspace_blocks = std::min(max_blocks, max_blocks_hw);
+  return workspace_blocks;
+}
+
+#endif
+
 void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  nvte_compute_amax_with_workspace(input_, output_, /*workspace=*/nullptr, stream);
+}
+
+void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) {
+#endif
   NVTE_API_CALL(nvte_compute_amax);
   using namespace transformer_engine;
 
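
A note on sizing in the hunk above: nvte_amax_workspace_num_blocks takes the raw element count, while the launch passes num_aligned_elements, which is smaller by the vectorization factor. Sizing the workspace from the raw count therefore always satisfies the block_capacity check, and the atomic fallback only triggers if the caller under-allocates. Some worked numbers, assuming amax_kernel_threads = 512 as defined above:

// Worked sizing examples (threads = 512):
//   N = 1,000,000   -> DIVUP(1,000,000, 512)   = 1,954 blocks   -> 1,954 floats (~7.6 KiB)
//   N = 100,000,000 -> DIVUP(100,000,000, 512) = 195,313 blocks -> clamped to 65,535 floats (256 KiB)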
@@ -150,11 +241,31 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
              to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);
 
+#ifdef __HIP_PLATFORM_AMD__
+  // Optional workspace
+  float *block_amax = nullptr;
+  size_t block_capacity = 0;
+
+  if (workspace_ != nullptr) {
+    auto &workspace = *reinterpret_cast<Tensor *>(workspace_);
+    if (workspace.data.dptr != nullptr) {
+      NVTE_CHECK(workspace.data.dtype == DType::kFloat32,
+                 "Workspace tensor for amax computation must be FP32, got dtype=",
+                 to_string(workspace.data.dtype));
+      block_amax = reinterpret_cast<float *>(workspace.data.dptr);
+      block_capacity = workspace.data.numel();
+    }
+  }
+#endif
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                                reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+#ifdef __HIP_PLATFORM_AMD__
+                               block_amax, block_capacity,
+#endif
                                stream););  // NOLINT(*)
 }
 
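Taken together, a HIP caller that wants the atomic-free path sizes an FP32 workspace from the element count, wraps it in an NVTETensor, and calls the new entry point; on HIP, plain nvte_compute_amax now forwards with a null workspace and keeps the atomic behavior. A sketch, where make_fp32_tensor is a hypothetical stand-in for however the caller constructs an NVTETensor around a device buffer (that API is not part of this diff):

// Hypothetical caller sketch; make_fp32_tensor is a stand-in, not a TE API.
size_t ws_elems = nvte_amax_workspace_num_blocks(num_elements);
float *ws_ptr = nullptr;
hipMalloc(&ws_ptr, ws_elems * sizeof(float));
NVTETensor workspace = make_fp32_tensor(ws_ptr, ws_elems);

#ifdef __HIP_PLATFORM_AMD__
nvte_compute_amax_with_workspace(input, output, workspace, stream);
#else
nvte_compute_amax(input, output, stream);  // CUDA path is unchanged
#endif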