@@ -381,4 +381,163 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
   }
 }
 
+// Forward declaration of functions defined in `cast_kernels.cuh`
+template <typename IType>
+void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
+                  cudaStream_t stream);
+
+template <typename ParamOP, float (*OP)(float, const ParamOP &)>
+void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
+                                       cudaStream_t stream);
+
+template <typename ParamOP, float (*OP)(float, const ParamOP &)>
+void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
+                                           cudaStream_t stream);
+
+constexpr size_t TILE_DIM = 32;
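+
+// Partial column-wise reduction used by the dbias path: each block loads a
+// TILE_DIM x TILE_DIM tile into shared memory, sums it along the row dimension,
+// and writes one partial row per block-row; reduce_dbias finishes the reduction.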
+template <typename DTypeReduce>
+__global__ void partial_reduce_kernel(const DTypeReduce *input, float *partial_output, int rows,
+                                      int cols) {
+  __shared__ float tile[TILE_DIM][TILE_DIM];
+
+  int tile_start_col = blockIdx.x * TILE_DIM;
+  int tile_start_row = blockIdx.y * TILE_DIM;
+  int thread_col_in_tile = threadIdx.x;
+  int thread_row_in_tile = threadIdx.y;
+
+  int global_col = tile_start_col + thread_col_in_tile;
+  int global_row = tile_start_row + thread_row_in_tile;
+
+  if (global_row < rows && global_col < cols) {
+    tile[thread_row_in_tile][thread_col_in_tile] =
+        static_cast<float>(input[global_row * cols + global_col]);
+  } else {
+    tile[thread_row_in_tile][thread_col_in_tile] = 0.0f;
+  }
+  __syncthreads();
+
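+  // Tree reduction down the rows of the tile: each step halves the number of active
+  // rows until row 0 holds the column sums for this tile.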
+  for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) {
+    if (thread_row_in_tile < stride) {
+      tile[thread_row_in_tile][thread_col_in_tile] +=
+          tile[thread_row_in_tile + stride][thread_col_in_tile];
+    }
+    __syncthreads();
+  }
+
+  if (thread_row_in_tile == 0 && global_col < cols) {
+    partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile];
+  }
+}
+
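+// Two-stage dbias reduction for the ROCm path: partial_reduce_kernel collapses each
+// group of TILE_DIM rows into a single partial row, then the generic reduce_dbias
+// kernel reduces the partial rows and casts to the dbias output type.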
+template <typename DTypeReduce, typename DBiasTypeOut>
+void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows,
+                       const size_t cols, cudaStream_t stream, Tensor *partial_sum_workspace) {
+  dim3 block_dim_partial(TILE_DIM, TILE_DIM);
+  dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM));
+
+  const size_t partial_rows = grid_dim_partial.y;
+  float *partial_workspace = reinterpret_cast<float *>(partial_sum_workspace->data.dptr);
+
+  partial_reduce_kernel<DTypeReduce><<<grid_dim_partial, block_dim_partial, 0, stream>>>(
+      workspace_ptr, partial_workspace, rows, cols);
+
+  reduce_dbias<DBiasTypeOut>(partial_workspace, dbias, partial_rows, cols, stream);
+}
+
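+// ROCm quantization entry point: for delayed tensor scaling it runs the vectorized
+// cast (and optional dACT) kernels plus, when IS_DBIAS is set, the two-stage dbias
+// reduction above; MXFP8 1D scaling is forwarded to mxfp8_quantize.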
+template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
+          float (*OP)(float, const ParamOP &)>
+void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop,
+                       Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
+  switch (output->scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      const size_t rows = input.flat_first_dim();
+      const size_t cols = input.flat_last_dim();
+
+      if constexpr (IS_DBIAS) {
+        NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true.");
+        NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true.");
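+        // A null workspace pointer marks the sizing query: report the required
+        // shape/dtype and return without launching anything.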
+        if (workspace->data.dptr == nullptr) {
+          if constexpr (IS_DACT) {
+            const size_t partial_rows = DIVUP(rows, TILE_DIM);
+            size_t total_elements = (rows * cols) + (partial_rows * cols);
+            workspace->data.shape = {total_elements};
+            workspace->data.dtype = DType::kFloat32;
+          } else {
+            workspace->data.shape = {rows, cols};
+            workspace->data.dtype = DType::kFloat32;
+          }
+          return;
+        }
+
+        const void *ptr_to_reduce = nullptr;
+        DType dtype_to_reduce;
+
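+        // Drop any quantization metadata on the workspace; it is used here only as a
+        // plain float buffer.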
+        workspace->amax = {};
+        workspace->scale = {};
+        workspace->scale_inv = {};
+
+        Tensor workspace_buffer;
+        Tensor partial_sum_buffer;
+
+        if constexpr (IS_DACT) {
+          // The values to reduce are the result of the dAct function.
+          NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT.");
+
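+          // Split the float32 workspace: the first rows * cols elements receive the
+          // dACT result, the trailing partial_rows * cols elements hold the per-block
+          // partial sums consumed by reduce_dbias_rocm.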
+          const size_t partial_rows = DIVUP(rows, TILE_DIM);
+          const size_t full_size_bytes = rows * cols * sizeof(float);
+          workspace_buffer = *workspace;
+          workspace_buffer.data.shape = {rows, cols};
+          partial_sum_buffer.data.dptr =
+              reinterpret_cast<char *>(workspace->data.dptr) + full_size_bytes;
+          partial_sum_buffer.data.shape = {partial_rows, cols};
+          partial_sum_buffer.data.dtype = DType::kFloat32;
+          workspace = &partial_sum_buffer;
+
+          CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, &workspace_buffer,
+                                                             stream);
+          if (output && output->data.dptr) {
+            CastVectorizedUnaryKernelLauncher<transformer_engine::Empty, nullptr>(
+                workspace_buffer, noop, output, stream);
+          }
+          ptr_to_reduce = workspace_buffer.data.dptr;
+          dtype_to_reduce = workspace_buffer.data.dtype;
+        } else {
+          if (output && output->data.dptr) {
+            CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
+          }
+          // The values to reduce are just the input values.
+          ptr_to_reduce = input.data.dptr;
+          dtype_to_reduce = input.data.dtype;
+        }
+
+        NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias tensor.");
+
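+        // Dispatch on both the dbias output dtype and the dtype of the buffer being
+        // reduced, then run the two-stage reduction.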
+        TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+            dbias->data.dtype, DBiasTypeOut,
+            TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+                dtype_to_reduce, DTypeReduce,
+                reduce_dbias_rocm<DTypeReduce, DBiasTypeOut>(
+                    reinterpret_cast<const DTypeReduce *>(ptr_to_reduce), dbias, rows, cols,
+                    stream, workspace);
+            );
+        );
+      } else {
+        if (output && output->data.dptr) {
+          if constexpr (IS_DACT) {
+            NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output.");
+            CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
+          } else {
+            CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
+          }
+        }
+      }
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
+      mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output, dbias,
+                                                             workspace, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
+  }
+}
+
 }  // namespace transformer_engine