20 changes: 16 additions & 4 deletions tests/cpp/operator/test_normalization.cu
@@ -35,9 +35,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return;
}

#ifndef __HIP_PLATFORM_AMD__
if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
}
#endif

using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
@@ -112,7 +114,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);

nvte_layernorm_bwd(dz.data(), input.data(),
mu.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), dbeta.data(),
@@ -215,10 +216,21 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
}

std::vector<std::pair<size_t, size_t>> test_cases = {
{71, 229},
{29, 541},
{768, 6144},
// {71, 229},
// {29, 541},
// {768, 6144},
{2048, 12288},
{768, 1024},
{256, 65536},
{128, 6144},
{64, 2304},
{229, 541},
{71, 3571},
{29, 17389},
{76800, 1600}
// {512,768},
// {71,3571},
// {168,184}
};

} // namespace
45 changes: 42 additions & 3 deletions tests/cpp/operator/test_normalization.h
@@ -64,11 +64,45 @@ void compute_ref_stats(NormType norm_type,
}
}

// template <typename InputType>
// inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {

// using compute_t = float;

// // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
// // Remove the use_cudnn check here when it is supported by both backends.
// const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;

// if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
// compute_t g = static_cast<compute_t>(gamma);
// if (zero_centered_gamma) {
// g += static_cast<compute_t>(1.f);
// }
// return g;
// } else {
// if (zero_centered_gamma_in_weight_dtype){
// compute_t g = static_cast<compute_t>(0.f);
// InputType gi = gamma;
// if (zero_centered_gamma) {
// gi = gi + static_cast<InputType>(1.f);
// }
// g = static_cast<compute_t>(gi);
// return g;
// } else {
// compute_t g = static_cast<compute_t>(gamma);
// if (zero_centered_gamma) {
// g += static_cast<compute_t>(1.f);
// }
// return g;
// }
// }
// }

template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype){

using compute_t = float;

// Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
@@ -80,6 +114,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
}
return g;
} else {
#ifdef __HIP_PLATFORM_AMD__
(void)zero_centered_gamma_in_weight_dtype; // Parameter is unused on AMD platform
#else
if (zero_centered_gamma_in_weight_dtype){
compute_t g = static_cast<compute_t>(0.f);
InputType gi = gamma;
@@ -88,7 +125,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
}
g = static_cast<compute_t>(gi);
return g;
} else {
} else
#endif
{
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
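To make the behaviour of the updated compute_gamma concrete, here is a minimal usage sketch. It is an editor's illustration, not part of this PR: the standalone main, the include, and the chosen values are assumptions, and it presumes compute_gamma is reachable at file scope as shown in the hunk above.

// Illustrative sketch only -- not part of this diff.
#include "test_normalization.h"  // assumed include path for the test header

int main() {
  // float is not an FP8 type, so the same branch is taken on both the CUDA and
  // HIP paths: gamma is cast to fp32 first, then the optional +1 shift is applied.
  float g = compute_gamma<float>(1.0f,
                                 /*zero_centered_gamma=*/true,
                                 /*use_cudnn=*/false,
                                 /*cudnn_zero_centered_gamma_in_weight_dtype=*/false);
  return (g == 2.0f) ? 0 : 1;  // expect 2.0f: 1.0f cast to fp32, then +1
}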
3 changes: 2 additions & 1 deletion transformer_engine/common/normalization/kernel_traits.h
@@ -67,6 +67,7 @@ struct Kernel_traits_finalize : public Base {
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
typename StatsT = transformer_engine::Stats<compute_t_, CTAS_PER_ROW_, WARPS_M_, WARPS_N_>,
typename Base =
Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_,
WARPS_M_ * WARPS_N_ * THREADS_PER_WARP> >
@@ -120,7 +121,7 @@ struct Kernel_traits : public Base {
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
// static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");

using Stats = transformer_engine::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
using Stats = StatsT;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};

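The StatsT parameter added to Kernel_traits above defaults to the tuned Stats type, so existing instantiations are unchanged, while a caller can substitute a different statistics implementation (the general forward launcher further down passes Stats_ge). Below is a self-contained sketch of the same defaulted-template-parameter pattern; it uses simplified stand-in types, not the real Stats classes, and is an editor's illustration rather than code from this PR.

// Illustrative sketch only -- not part of this diff.
#include <cstdio>

template <typename compute_t>
struct StatsTuned   { enum { SMEM_BYTES = 256 }; };   // stand-in for Stats<...>
template <typename compute_t>
struct StatsGeneral { enum { SMEM_BYTES = 512 }; };   // stand-in for Stats_ge<...>

// Defaulted StatsT: old instantiations keep the tuned type, new ones may override it.
template <typename compute_t, typename StatsT = StatsTuned<compute_t>>
struct Traits {
  using Stats = StatsT;
  enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};

int main() {
  std::printf("default:  %d\n", Traits<float>::SMEM_BYTES_FWD);                        // tuned path
  std::printf("override: %d\n", Traits<float, StatsGeneral<float>>::SMEM_BYTES_FWD);   // general path
  return 0;
}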
@@ -332,4 +332,4 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp32, fp32, fp32, fp3
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
@@ -9,7 +9,7 @@
#include "../common.h"
#include "../kernel_traits.h"
#include "ln_fwd_kernels.cuh"

#include <iostream>
using namespace transformer_engine::normalization;

template <typename weight_t, typename input_t, typename output_t, typename compute_t,
@@ -73,7 +73,8 @@ template <typename weight_t, typename input_t, typename output_t, typename compu
static void launch_general_(LaunchParams<ForwardKernelParams> &launch_params,
const bool configure_params) { // NOLINT(*)
using Kernel_traits = Kernel_traits<weight_t, input_t, output_t, compute_t, index_t, HIDDEN_SIZE,
1, WARPS_M, WARPS_N, BYTES_PER_LDG>;
1, WARPS_M, WARPS_N, BYTES_PER_LDG,
transformer_engine::Stats_ge<compute_t, 1, WARPS_M, WARPS_N, int>>;
auto kernel = &ln_fwd_general_kernel<Kernel_traits>;
auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };

@@ -218,13 +219,13 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 1, 2, 8);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 2, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
@@ -242,7 +243,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32,
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
@@ -272,7 +273,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, bf16, fp32,
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
@@ -290,7 +291,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, bf16, fp32,
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, fp32, 1, 1, 16, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
@@ -362,7 +363,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, bf16, fp32,
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 2, 1, 8, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16);
@@ -379,7 +380,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, f
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
@@ -401,26 +402,32 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, bf16, fp32,
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 2, 4, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16);

// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16);
// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16);
// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp32, fp32, fp16, fp32, 1, 4, 16);
// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16);
// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp32, fp32, bf16, fp32, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 20480, fp16, fp16, fp16, fp32, 1, 16, 16);

#ifdef __HIP_PLATFORM_AMD__
// ROCM uses TE normalization for e5m2

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16);
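For readers decoding the retuned registrations above: judging from the Kernel_traits instantiations in this diff, the trailing integers of each REGISTER_NORM_LAUNCHER appear to be launch-geometry parameters, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG for tuned launchers and WARPS_M, WARPS_N, BYTES_PER_LDG for general ones (CTAS_PER_ROW fixed at 1). The sketch below is an editor's illustration of how such numbers map to a thread-block size, following the WARPS_M_ * WARPS_N_ * THREADS_PER_WARP expression in kernel_traits.h; the parameter meanings are inferred, not taken from the launcher code.

// Illustrative sketch only -- not part of this diff.
#include <cstdio>

constexpr int THREADS_PER_WARP = 32;  // warp size; AMD wavefronts are typically 64

struct LaunchGeom {
  int ctas_per_row;    // CTAs cooperating on one row (tuned launchers only, assumed)
  int warps_m;         // warps along the row/batch dimension per CTA (assumed)
  int warps_n;         // warps cooperating across the hidden dimension (assumed)
  int bytes_per_ldg;   // bytes moved per vectorized load (assumed)
  int threads_per_cta() const { return warps_m * warps_n * THREADS_PER_WARP; }
};

int main() {
  // Example: the retuned fp16, hidden=768 tuned entry above uses 1, 1, 2, 8.
  LaunchGeom g{1, 1, 2, 8};
  std::printf("threads per CTA: %d, bytes per load: %d\n",
              g.threads_per_cta(), g.bytes_per_ldg);
  return 0;
}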