Commit 653b5b4

[Fix] Added dbias and dgelu kernels for ROCm (#333)

* Added guards around `getDeviceComputeCapability()`, which is NVIDIA-only
* Implemented `fp8_quantize_rocm`, reusing `CastVectorizedUnaryKernelLauncher` and `CastVectorizedUnaryGradKernelLauncher`
* Implemented `partial_reduce_kernel` and `reduce_dbias_rocm` to efficiently reduce large inputs in two stages (a standalone sketch follows below)
* Enabled the `test_cast_dbias` and `test_cast_dbias_dgelu` cpp tests for ROCm

1 parent 9eaaf4c commit 653b5b4
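The dbias reduction added in `transformer_engine/common/util/rocm_cast_kernels.cuh` (diff below) works in two stages: `partial_reduce_kernel` collapses each 32-row slab of the [rows, cols] input into one row of per-column partial sums, and the existing `reduce_dbias` kernel finishes the reduction over the remaining DIVUP(rows, 32) partial rows. A minimal standalone sketch of the same idea follows; names and sizes here are illustrative, not the commit's code, and the final reduction is done on the host for brevity:

// Two-stage column reduction sketch (illustrative; not the commit's kernels).
// Stage 1 collapses each 32-row slab to one partial row per column on the GPU;
// stage 2 sums the partial rows (done on the host here for simplicity).
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int TILE = 32;

__global__ void partial_reduce(const float* in, float* partial, int rows, int cols) {
  __shared__ float tile[TILE][TILE];
  int col = blockIdx.x * TILE + threadIdx.x;
  int row = blockIdx.y * TILE + threadIdx.y;
  // Load one tile element per thread; out-of-range threads contribute zero.
  tile[threadIdx.y][threadIdx.x] =
      (row < rows && col < cols) ? in[row * cols + col] : 0.0f;
  __syncthreads();
  // Tree-reduce the tile along its row dimension in shared memory.
  for (int s = TILE / 2; s > 0; s /= 2) {
    if (threadIdx.y < s)
      tile[threadIdx.y][threadIdx.x] += tile[threadIdx.y + s][threadIdx.x];
    __syncthreads();
  }
  // One partial-sum row per block row.
  if (threadIdx.y == 0 && col < cols)
    partial[blockIdx.y * cols + col] = tile[0][threadIdx.x];
}

int main() {
  const int rows = 100, cols = 64;             // arbitrary test size
  const int prows = (rows + TILE - 1) / TILE;  // partial rows after stage 1
  std::vector<float> h(rows * cols, 1.0f);     // column sums should equal rows
  float *d_in, *d_partial;
  cudaMalloc(&d_in, rows * cols * sizeof(float));
  cudaMalloc(&d_partial, prows * cols * sizeof(float));
  cudaMemcpy(d_in, h.data(), rows * cols * sizeof(float), cudaMemcpyHostToDevice);
  dim3 block(TILE, TILE), grid((cols + TILE - 1) / TILE, prows);
  partial_reduce<<<grid, block>>>(d_in, d_partial, rows, cols);
  std::vector<float> p(prows * cols);
  cudaMemcpy(p.data(), d_partial, p.size() * sizeof(float), cudaMemcpyDeviceToHost);
  float sum0 = 0.0f;  // stage 2: finish column 0 on the host
  for (int i = 0; i < prows; ++i) sum0 += p[i * cols];
  printf("column 0 sum = %.1f (expect %d)\n", sum0, rows);
  cudaFree(d_in);
  cudaFree(d_partial);
  return 0;
}

Shrinking the row count by 32x before the final reduction keeps the last kernel's input small even for very tall inputs, which is the point of the `reduce_dbias_rocm` helper.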

8 files changed: +183 additions, -11 deletions

tests/cpp/operator/CMakeLists.txt

Lines changed: 3 additions & 3 deletions

@@ -19,6 +19,8 @@ list(APPEND test_cuda_sources
     test_cast_transpose_dbias.cu
     test_cast_transpose_dbias_dgelu.cu
     test_cast_transpose_dgeglu.cu
+    test_cast_dbias.cu
+    test_cast_dbias_dgelu.cu
     test_act.cu
     test_normalization.cu
     test_normalization_mxfp8.cu
@@ -29,9 +31,7 @@ list(APPEND test_cuda_sources
     ../test_common.cu)
 if(USE_CUDA)
   list(APPEND test_cuda_sources
-    test_cast_float8blockwise.cu
-    test_cast_dbias.cu
-    test_cast_dbias_dgelu.cu)
+    test_cast_float8blockwise.cu)
 else()
   list(APPEND test_cuda_sources
     test_cublaslt_gemm.cu)

tests/cpp/operator/test_cast_dbias.cu

Lines changed: 2 additions & 0 deletions

@@ -149,10 +149,12 @@ class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transforme
 TEST_P(CastDBiasTestSuite, TestCastDBias) {
   using namespace transformer_engine;
   using namespace test;
+#ifndef __HIP_PLATFORM_AMD__
   // Skip tests for pre-Blackwell architectures
   if (getDeviceComputeCapability() < blackwellComputeCapability) {
     GTEST_SKIP();
   }
+#endif

   const DType input_type = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());

tests/cpp/operator/test_cast_dbias_dgelu.cu

Lines changed: 3 additions & 0 deletions

@@ -164,10 +164,13 @@ class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<trans
 TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
   using namespace transformer_engine;
   using namespace test;
+
+#ifndef __HIP_PLATFORM_AMD__
   // Skip tests for pre-Blackwell architectures
   if (getDeviceComputeCapability() < blackwellComputeCapability) {
     GTEST_SKIP();
   }
+#endif

   const DType input_type = std::get<0>(GetParam());
   const DType output_type = std::get<1>(GetParam());

tests/cpp/operator/test_cast_float8blockwise.cu

Lines changed: 4 additions & 0 deletions

@@ -478,9 +478,11 @@ class FusedCastFloat8VectorwiseTestSuite
 }

 TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
+#ifndef __HIP_PLATFORM_AMD__
   if (getDeviceComputeCapability() < hopperComputeCapability) {
     GTEST_SKIP();
   }
+#endif

   using namespace transformer_engine;
   using namespace test;
@@ -529,9 +531,11 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
 }

 TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
+#ifndef __HIP_PLATFORM_AMD__
   if (getDeviceComputeCapability() < hopperComputeCapability) {
     GTEST_SKIP();
   }
+#endif

   using namespace transformer_engine;
   using namespace test;

tests/cpp/operator/test_normalization.cu

Lines changed: 2 additions & 0 deletions

@@ -35,9 +35,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     return;
   }

+#ifndef __HIP_PLATFORM_AMD__
   if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
     GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
   }
+#endif

   using WeightType = InputType;
   DType itype = TypeInfo<InputType>::dtype;

tests/cpp/test_common.cu

Lines changed: 3 additions & 3 deletions (whitespace-only: the preprocessor guards are re-indented)

@@ -531,13 +531,13 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
   const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
   const T *ref_data = reinterpret_cast<const T*>(ref);
   for (size_t i = 0; i < N; ++i) {
-#ifndef __HIP_PLATFORM_AMD__
+    #ifndef __HIP_PLATFORM_AMD__
     double t = static_cast<double>(test_data[i]);
     double r = static_cast<double>(ref_data[i]);
-#else
+    #else
     double t = static_cast<double>(static_cast<float>(test_data[i]));
     double r = static_cast<double>(static_cast<float>(ref_data[i]));
-#endif
+    #endif
     bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
     /* For Float32 the floating point comparison is enough to error out */
     bool assertion = mismatch && test.dtype() == DType::kFloat32;

transformer_engine/common/util/cast_kernels.cuh

Lines changed: 7 additions & 5 deletions

@@ -1206,7 +1206,6 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
       NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
   }
 }
-#endif //#ifndef __HIP_PLATFORM_AMD__

 // Supported by the Arch < 10.0
 template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
@@ -1232,6 +1231,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const
       NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
   }
 }
+#endif //#ifndef __HIP_PLATFORM_AMD__

 template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
           float (*OP)(float, const ParamOP &)>
@@ -1256,17 +1256,19 @@ void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *no
   NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

 #ifndef __HIP_PLATFORM_AMD__
+  // NVIDIA
   // Supported by the Arch >= 10.0
   if (is_supported_by_CC_100()) {
     fp8_quantize_arch_ge_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
                                                                      dbias, workspace, stream);
-  } else {
-#endif //#ifndef __HIP_PLATFORM_AMD__
-  // Supported by the Arch < 10.0
+  } else {  // Supported by the Arch < 10.0
     fp8_quantize_arch_l_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
                                                                     dbias, workspace, stream);
-#ifndef __HIP_PLATFORM_AMD__
   }
+#else
+  // AMD
+  fp8_quantize_rocm<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
+                                                            dbias, workspace, stream);
 #endif //#ifndef __HIP_PLATFORM_AMD__
 }

transformer_engine/common/util/rocm_cast_kernels.cuh

Lines changed: 159 additions & 0 deletions

@@ -381,4 +381,163 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
   }
 }

+// Forward declarations of functions defined in `cast_kernels.cuh`
+template <typename IType>
+void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
+                  cudaStream_t stream);
+
+template <typename ParamOP, float (*OP)(float, const ParamOP &)>
+void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
+                                       cudaStream_t stream);
+
+template <typename ParamOP, float (*OP)(float, const ParamOP &)>
+void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
+                                           cudaStream_t stream);
+
+constexpr size_t TILE_DIM = 32;
+
+// Stage 1 of the dbias reduction: each block loads a TILE_DIM x TILE_DIM tile,
+// tree-reduces it along the row dimension in shared memory, and writes one row
+// of per-column partial sums per block row.
+template <typename DTypeReduce>
+__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows,
+                                      int cols) {
+  __shared__ float tile[TILE_DIM][TILE_DIM];
+
+  int tile_start_col = blockIdx.x * TILE_DIM;
+  int tile_start_row = blockIdx.y * TILE_DIM;
+  int thread_col_in_tile = threadIdx.x;
+  int thread_row_in_tile = threadIdx.y;
+
+  int global_col = tile_start_col + thread_col_in_tile;
+  int global_row = tile_start_row + thread_row_in_tile;
+
+  if (global_row < rows && global_col < cols) {
+    tile[thread_row_in_tile][thread_col_in_tile] =
+        static_cast<float>(input[global_row * cols + global_col]);
+  } else {
+    tile[thread_row_in_tile][thread_col_in_tile] = 0.0f;
+  }
+  __syncthreads();
+
+  for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) {
+    if (thread_row_in_tile < stride) {
+      tile[thread_row_in_tile][thread_col_in_tile] +=
+          tile[thread_row_in_tile + stride][thread_col_in_tile];
+    }
+    __syncthreads();
+  }
+
+  if (thread_row_in_tile == 0 && global_col < cols) {
+    partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile];
+  }
+}
+
+// Stage 2: shrink the [rows, cols] input to [DIVUP(rows, TILE_DIM), cols]
+// partial sums, then finish with the existing reduce_dbias kernel.
+template <typename DTypeReduce, typename DBiasTypeOut>
+void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows,
+                       const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) {
+  dim3 block_dim_partial(TILE_DIM, TILE_DIM);
+  dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM));
+
+  const size_t partial_rows = grid_dim_partial.y;
+  float* partial_workspace = reinterpret_cast<float*>(partial_sum_workspace->data.dptr);
+
+  partial_reduce_kernel<DTypeReduce><<<grid_dim_partial, block_dim_partial, 0, stream>>>(
+      workspace_ptr, partial_workspace, rows, cols);
+
+  reduce_dbias<DBiasTypeOut>(partial_workspace, dbias, partial_rows, cols, stream);
+}
+
+template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
+          float (*OP)(float, const ParamOP &)>
+void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop,
+                       Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) {
+  switch (output->scaling_mode) {
+    case NVTE_DELAYED_TENSOR_SCALING: {
+      const size_t rows = input.flat_first_dim();
+      const size_t cols = input.flat_last_dim();
+
+      if constexpr (IS_DBIAS) {
+        NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true.");
+        NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true.");
+        // First call: report the required workspace size and return.
+        if (workspace->data.dptr == nullptr) {
+          if constexpr (IS_DACT) {
+            const size_t partial_rows = DIVUP(rows, TILE_DIM);
+            size_t total_elements = (rows * cols) + (partial_rows * cols);
+            workspace->data.shape = {total_elements};
+            workspace->data.dtype = DType::kFloat32;
+          } else {
+            workspace->data.shape = {rows, cols};
+            workspace->data.dtype = DType::kFloat32;
+          }
+          return;
+        }
+
+        const void *ptr_to_reduce = nullptr;
+        DType dtype_to_reduce;
+
+        workspace->amax = {};
+        workspace->scale = {};
+        workspace->scale_inv = {};
+
+        Tensor workspace_buffer;
+        Tensor partial_sum_buffer;
+
+        if constexpr (IS_DACT) {
+          // The values to reduce are the result of the dAct function.
+          NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT.");
+
+          // Carve the workspace into the full-size dAct result followed by the
+          // partial-sum buffer.
+          const size_t partial_rows = DIVUP(rows, TILE_DIM);
+          const size_t full_size_bytes = rows * cols * sizeof(float);
+          workspace_buffer = *workspace;
+          workspace_buffer.data.shape = {rows, cols};
+          partial_sum_buffer.data.dptr =
+              reinterpret_cast<char*>(workspace->data.dptr) + full_size_bytes;
+          partial_sum_buffer.data.shape = {partial_rows, cols};
+          partial_sum_buffer.data.dtype = DType::kFloat32;
+          workspace = &partial_sum_buffer;
+
+          CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, &workspace_buffer,
+                                                             stream);
+          if (output && output->data.dptr) {
+            CastVectorizedUnaryKernelLauncher<transformer_engine::Empty, nullptr>(
+                workspace_buffer, noop, output, stream);
+          }
+          ptr_to_reduce = workspace_buffer.data.dptr;
+          dtype_to_reduce = workspace_buffer.data.dtype;
+        } else {
+          if (output && output->data.dptr) {
+            CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
+          }
+          // The values to reduce are just the input values.
+          ptr_to_reduce = input.data.dptr;
+          dtype_to_reduce = input.data.dtype;
+        }
+
+        NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias tensor.");
+
+        TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+            dbias->data.dtype, DBiasTypeOut,
+            TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+                dtype_to_reduce, DTypeReduce,
+                reduce_dbias_rocm<DTypeReduce, DBiasTypeOut>(
+                    reinterpret_cast<const DTypeReduce *>(ptr_to_reduce),
+                    dbias, rows, cols, stream, workspace);
+            );
+        );
+      } else {
+        if (output && output->data.dptr) {
+          if constexpr (IS_DACT) {
+            NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output.");
+            CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
+          } else {
+            CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
+          }
+        }
+      }
+      break;
+    }
+    case NVTE_MXFP8_1D_SCALING: {
+      mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output, dbias,
+                                                             workspace, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
+  }
+}
+
 } // namespace transformer_engine
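For reference, a sketch of how the two-phase workspace query in `fp8_quantize_rocm` above plays out on the dbias + dgelu path: the first call with a null workspace pointer reports the required element count, and the allocated buffer is then split into the full-size dGeLU result followed by the per-slab partial sums. Sizes below are hypothetical examples, and the local DIVUP helper stands in for the library macro:

// Workspace sizing sketch for the IS_DBIAS + IS_DACT path (hypothetical sizes;
// TILE_DIM = 32 as in the commit).
#include <cstddef>
#include <cstdio>

constexpr size_t TILE_DIM = 32;
constexpr size_t DIVUP(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  const size_t rows = 2048, cols = 4096;  // example input shape
  const size_t partial_rows = DIVUP(rows, TILE_DIM);
  // Region 1: rows*cols floats holding the dGeLU result to be cast and reduced.
  // Region 2: partial_rows*cols floats holding per-slab partial column sums.
  const size_t total_floats = rows * cols + partial_rows * cols;
  printf("workspace: %zu floats (%.1f MiB), partial region offset %zu bytes\n",
         total_floats, total_floats * sizeof(float) / (1024.0 * 1024.0),
         rows * cols * sizeof(float));
  return 0;
}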
