Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions tests/cpp/operator/test_normalization.cu
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,20 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
}

std::vector<std::pair<size_t, size_t>> test_cases = {
{71, 229},
{29, 541},
{768, 6144},
{2048, 12288},
// {71, 229},
// {29, 541},
// {768, 6144},
//{2048, 12288},
//{71,3571}
//{168,184}
// {768,1024},
// {256,65536},
// {128,6144},
// {64,2304},
// {229,541},
// {71, 3571},
{512,768}
//{76800,1600}
};

} // namespace
Expand Down
45 changes: 42 additions & 3 deletions tests/cpp/operator/test_normalization.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,45 @@ void compute_ref_stats(NormType norm_type,
}
}

// template <typename InputType>
// inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {

// using compute_t = float;

// // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
// // Remove the use_cudnn check here when it is supported by both backends.
// const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;

// if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
// compute_t g = static_cast<compute_t>(gamma);
// if (zero_centered_gamma) {
// g += static_cast<compute_t>(1.f);
// }
// return g;
// } else {
// if (zero_centered_gamma_in_weight_dtype){
// compute_t g = static_cast<compute_t>(0.f);
// InputType gi = gamma;
// if (zero_centered_gamma) {
// gi = gi + static_cast<InputType>(1.f);
// }
// g = static_cast<compute_t>(gi);
// return g;
// } else {
// compute_t g = static_cast<compute_t>(gamma);
// if (zero_centered_gamma) {
// g += static_cast<compute_t>(1.f);
// }
// return g;
// }
// }
// }

template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype){

using compute_t = float;

// Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
// Remove the use_cudnn check here when it is supported by both backends.
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
Expand All @@ -80,6 +114,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
}
return g;
} else {
#ifdef __HIP_PLATFORM_AMD__
(void)zero_centered_gamma_in_weight_dtype; // Parameter is unused on AMD platform
#else
if (zero_centered_gamma_in_weight_dtype){
compute_t g = static_cast<compute_t>(0.f);
InputType gi = gamma;
Expand All @@ -88,7 +125,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
}
g = static_cast<compute_t>(gi);
return g;
} else {
} else
#endif
{
compute_t g = static_cast<compute_t>(gamma);
if (zero_centered_gamma) {
g += static_cast<compute_t>(1.f);
Expand Down
3 changes: 2 additions & 1 deletion transformer_engine/common/normalization/kernel_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ struct Kernel_traits_finalize : public Base {
template <typename weight_t_, typename input_t_, typename output_t_, typename compute_t_,
typename index_t_, uint32_t HIDDEN_SIZE_, uint32_t CTAS_PER_ROW_, uint32_t WARPS_M_,
uint32_t WARPS_N_, uint32_t BYTES_PER_LDG_ = 16,
typename StatsT = transformer_engine::Stats<compute_t_, CTAS_PER_ROW_, WARPS_M_, WARPS_N_>,
typename Base =
Kernel_traits_base<HIDDEN_SIZE_, weight_t_, input_t_, output_t_, compute_t_, index_t_,
WARPS_M_ * WARPS_N_ * THREADS_PER_WARP> >
Expand Down Expand Up @@ -120,7 +121,7 @@ struct Kernel_traits : public Base {
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
// static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");

using Stats = transformer_engine::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
using Stats = StatsT;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,15 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz

const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS;
const uint32_t COL_STRIDE = params.cols * THREADS_PER_WARP;
for (uint32_t col = c, col_out = c_out; col < params.cols;
col += COL_STRIDE, col_out += COL_STRIDE / 2) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) {
index_t idx = row * Kernel_traits::COLS + col;

index_t idx = row * params.cols + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
Expand Down Expand Up @@ -391,7 +390,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne
}

Cvec dy[LDGS];
Cvec y[LDGS];
//Cvec y[LDGS];
compute_t mdy = 0.f;
compute_t mdyy = 0.f;

Expand All @@ -411,14 +410,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne
const compute_t dz_ij = dz.data.elt[jt];
const compute_t dy_ij = g_ij * dz_ij;

y[it].data.elt[jt] = y_ij;
//y[it].data.elt[jt] = y_ij;
dy[it].data.elt[jt] = dy_ij;

mdy += dy_ij;
mdyy += dy_ij * y_ij;

dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
// dz_sum[it].data.elt[jt] += dz_ij;
// dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
}

Expand All @@ -432,11 +431,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne
for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
it++, col += gdimn * NUM_ELTS) {
Ivec dx;

Ivec x;
Ovec dz;
x.load_from_elts(params.x, row * params.cols + col, params.cols - col);
dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col);

#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
compute_t dy_ij = dy[it].data.elt[jt];
compute_t y_ij = y[it].data.elt[jt];
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy));
const compute_t x_ij = x.data.elt[jt];
const compute_t y_ij = rs * (x_ij - mu);
const compute_t dz_ij = dz.data.elt[jt];

dx.data.elt[jt] = rs * (dy[it].data.elt[jt] - (mdyy * y_ij + mdy));

dz_sum[it].data.elt[jt] += dz_ij;
dzy_sum[it].data.elt[jt] += dz_ij * y_ij;
}
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "../common.h"
#include "../kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include <iostream>

using namespace transformer_engine::normalization;

Expand Down Expand Up @@ -39,7 +40,9 @@ static void launch_tuned_(LaunchParams<BackwardKernelParams> &launch_params,
launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t);
return;
}

// std::cout<<"bwd ctas_per_row:"<< CTAS_PER_ROW<<std::endl;
// std::cout<<"bwd warps_n:"<< WARPS_N<<std::endl;
// std::cout<<"bwd bytes_per_load:"<<BYTES_PER_LDG_MAIN <<std::endl;
#ifndef __HIP_PLATFORM_AMD__
if (Kernel_traits::SMEM_BYTES >= 48 * 1024) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
Expand Down Expand Up @@ -106,7 +109,10 @@ static void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t);
return;
}

// std::cout<<"bwd cols:"<<Kernel_traits::COLS<<std::endl;
// std::cout<<"warps_m:"<<WARPS_M<<std::endl;
// std::cout<<"warps_n:"<<WARPS_N<<std::endl;
// std::cout<<"bytes_per_load:"<<BYTES_PER_LDG_MAIN<<std::endl;
// Launch kernel
auto stream = launch_params.stream;
dim3 grid(ctas_per_row * ctas_per_col);
Expand All @@ -119,17 +125,40 @@ static void launch_general_(LaunchParams<BackwardKernelParams> &launch_params,
reinterpret_cast<void **>(&params_), 0, stream);
}

// Launch finalization kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL =
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t));
auto kernel_final =
&ln_bwd_finalize_general_kernel<weight_t, compute_t, WARPS_M_FINAL, WARPS_N_FINAL,
BYTES_PER_LDG_FINAL, Kernel_traits::THREADS_PER_WARP>;
dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);
kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
// Decide which finalize kernel to launch based on column alignment
const bool cols_aligned = (cols % 32 == 0);

if (cols_aligned) {
// Launch tuned finalize kernel
using Kernel_traits_f = Kernel_traits_finalize<HIDDEN_SIZE, weight_t, input_t, output_t,
compute_t, index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;

auto kernel_f = &ln_bwd_finalize_tuned_kernel<Kernel_traits_f>;


kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(
launch_params.params);

} else {
// Launch general finalize kernel
constexpr uint32_t WARPS_M_FINAL = 4;
constexpr uint32_t WARPS_N_FINAL = 1;
constexpr uint32_t ELTS_N_PER_CTA_FINAL =
(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL) /
sizeof(compute_t);

auto kernel_final = &ln_bwd_finalize_general_kernel<weight_t, compute_t,
WARPS_M_FINAL, WARPS_N_FINAL,
BYTES_PER_LDG_FINAL,
Kernel_traits::THREADS_PER_WARP>;

dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL);
dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1);

kernel_final<<<grid_final, block_final, 0, stream>>>(launch_params.params);
}
}

#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \
Expand Down Expand Up @@ -157,19 +186,19 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32,
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 8);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);//
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 2, 1, 1, 8, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
Expand Down Expand Up @@ -223,7 +252,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 1, 1, 16, 8, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
Expand Down Expand Up @@ -295,7 +324,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 4, 1, 16, 8, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
Expand All @@ -317,13 +346,13 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 2, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 32, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
Expand Down
Loading