@@ -47,9 +47,9 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax,
   }
 }

-template <int nvec, bool aligned, typename InputType>
+template <int nvec, bool aligned, typename InputType, bool UseBlockAmax>
 __launch_bounds__(amax_kernel_threads) __global__
-    void amax_kernel(const InputType *input, float * __restrict__ block_amax, const size_t N,
+    void amax_kernel(const InputType *input, float *amax, float * __restrict__ block_amax, const size_t N,
                      const size_t num_aligned_elements) {
   VectorizedLoader<InputType, nvec, aligned> loader(input, N);
   InputType max{0.f};
@@ -85,12 +85,17 @@ __launch_bounds__(amax_kernel_threads) __global__
   // Reduce amax over block
   max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
   if (threadIdx.x == 0) {
-    block_amax[blockIdx.x] = max;
+    if constexpr (UseBlockAmax) {
+      block_amax[blockIdx.x] = max;
+    } else {
+      atomicMaxFloat(amax, max);
+    }
   }
 }

 template <int nvec, typename InputType>
-void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
+void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax,
+                        size_t block_capacity, cudaStream_t stream) {
   // Zero out amax so we can update with atomic max
   (void)cudaMemsetAsync(amax, 0, sizeof(float), stream);

@@ -109,28 +114,43 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
   constexpr size_t max_blocks = 65535;
   num_blocks = std::min(num_blocks, max_blocks);

-  float *block_amax = nullptr;
-  NVTE_CHECK_CUDA(cudaMallocAsync(&block_amax, num_blocks * sizeof(float), stream));
+  const bool UseBlockAmax = (block_amax != nullptr);
+
+  if (UseBlockAmax) {
+    NVTE_CHECK(block_capacity >= num_blocks);
+  }

   // Launch kernel
   switch (align) {
     case Alignment::SAME_ALIGNED:
-      amax_kernel<nvec, true, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
+      // FIXME: this code is clumsy. Perhaps don't use the UseBlockAmax extra template argument
+      if (UseBlockAmax)
+        amax_kernel<nvec, true, InputType, true>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      else
+        amax_kernel<nvec, true, InputType, false>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
       break;
     case Alignment::SAME_UNALIGNED:
-      amax_kernel<nvec, false, InputType>
-          <<<num_blocks, threads, 0, stream>>>(input, block_amax, N, num_aligned_elements);
+      if (UseBlockAmax)
+        amax_kernel<nvec, false, InputType, true>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
+      else
+        amax_kernel<nvec, false, InputType, false>
+            <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
       break;
     case Alignment::DIFFERENT: {
       // This case is a logic error, since there is only one pointer (input)
       // in the alignment check. Still safe to process without vectorization.
-      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, block_amax, N, N);
+      if (UseBlockAmax)
+        amax_kernel<1, true, InputType, true><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
+      else
+        amax_kernel<1, true, InputType, false><<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, N);
       break;
     }
   }

-  {
+  if (UseBlockAmax) {
     constexpr int FINAL_REDUCE_THREADS = 256;
     dim3 fr_block(FINAL_REDUCE_THREADS);
     dim3 fr_grid(1);
@@ -141,7 +161,6 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {

   // Check results
   NVTE_CHECK_CUDA(cudaGetLastError());
-  NVTE_CHECK_CUDA(cudaFreeAsync(block_amax, stream));
 }

 }  // namespace
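
One way to resolve the FIXME in launch_amax_kernel would be to hoist the boolean dispatch into a single helper, so each Alignment case keeps one launch site. A minimal sketch, not part of this diff: the helper name dispatch_amax_kernel is hypothetical, everything else mirrors the kernel and launch configuration above.

// Sketch only: folds the UseBlockAmax branch into one helper so the switch
// in launch_amax_kernel needs a single call per Alignment case.
// dispatch_amax_kernel is a hypothetical name, not part of the change.
template <int nvec, bool aligned, typename InputType>
void dispatch_amax_kernel(const InputType *input, float *amax, float *block_amax,
                          size_t num_blocks, size_t threads, const size_t N,
                          const size_t num_aligned_elements, cudaStream_t stream) {
  if (block_amax != nullptr) {
    // Deterministic path: each block writes its own slot for a second-pass reduce.
    amax_kernel<nvec, aligned, InputType, true>
        <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
  } else {
    // Single-pass path: blocks race on a global atomic max.
    amax_kernel<nvec, aligned, InputType, false>
        <<<num_blocks, threads, 0, stream>>>(input, amax, block_amax, N, num_aligned_elements);
  }
}

With that in place, the SAME_ALIGNED case collapses to a single dispatch_amax_kernel<nvec, true> call; alternatively the template parameter could be dropped entirely by branching on block_amax != nullptr inside the kernel (the branch only runs in thread 0 of each block), as the FIXME hints.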
@@ -183,11 +202,20 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
               to_string(output.amax.dtype), ")");
   CheckOutputTensor(output, "output_compute_amax", true);

+  // Interpret output.data as workspace if present
+  float *block_amax = nullptr;
+  size_t block_capacity = 0;
+  if (output.data.dptr != nullptr) {
+    block_amax = reinterpret_cast<float *>(output.data.dptr);
+    block_capacity = output.data.numel();  // #floats in workspace
+  }
+
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
       launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
-                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
+                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), block_amax,
+                               block_capacity,
                                stream););  // NOLINT(*)
 }

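The contract introduced here for callers: leave output.data.dptr null to keep the single-pass atomicMaxFloat path, or point it at a float workspace with at least num_blocks elements to get the deterministic two-pass reduction, which NVTE_CHECK(block_capacity >= num_blocks) guards. A hedged sketch of a sufficient workspace size, mirroring the grid-size logic above; workspace_floats_for_amax is an illustrative helper, and the 512-thread block size is an assumed value of amax_kernel_threads.

#include <algorithm>
#include <cstddef>

// Sketch only: upper-bounds the per-block slots launch_amax_kernel can use.
// Assumes amax_kernel_threads == 512; max_blocks matches the 65535 cap above.
inline size_t workspace_floats_for_amax(size_t N, size_t nvec) {
  constexpr size_t threads = 512;       // assumed value of amax_kernel_threads
  constexpr size_t max_blocks = 65535;  // grid cap from launch_amax_kernel
  const size_t elems_per_block = threads * nvec;
  const size_t num_blocks = (N + elems_per_block - 1) / elems_per_block;
  return std::min(num_blocks + 1, max_blocks);  // +1 slack for unaligned input
}

Since the grid is capped at 65535 blocks, a caller can also just allocate 65535 floats (~256 KiB) once and reuse it for any N without ever tripping the capacity check.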