diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu
index eb5006aea..f81727572 100644
--- a/tests/cpp/operator/test_normalization.cu
+++ b/tests/cpp/operator/test_normalization.cu
@@ -35,9 +35,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     return;
   }
 
+#ifndef __HIP_PLATFORM_AMD__
   if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
     GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
   }
+#endif
 
   using WeightType = InputType;
   DType itype = TypeInfo<InputType>::dtype;
@@ -112,7 +114,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
       nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon,
                          z.data(), mu.data(), rsigma.data(), workspace_fwd.data(),
                          prop.multiProcessorCount, zero_centered_gamma, 0);
-
       nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(),
                          dx.data(), dgamma.data(), dbeta.data(),
@@ -215,10 +216,21 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
 }
 
 std::vector<std::pair<size_t, size_t>> test_cases = {
-  {71, 229},
-  {29, 541},
-  {768, 6144},
+  // {71, 229},
+  // {29, 541},
+  // {768, 6144},
   {2048, 12288},
+  {768,1024},
+  {256,65536},
+  {128,6144},
+  {64,2304},
+  {229,541},
+  {71, 3571},
+  {29,17389},
+  {76800,1600}
+  // {512,768},
+  // {71,3571},
+  // {168,184}
 };
 
 }  // namespace

diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h
index 5f5603a7f..c949a0906 100644
--- a/tests/cpp/operator/test_normalization.h
+++ b/tests/cpp/operator/test_normalization.h
@@ -64,11 +64,45 @@ void compute_ref_stats(NormType norm_type,
     }
   }
 }
 
+// template <typename InputType>
+// inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
+
+//   using compute_t = float;
+
+//   // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
+//   // Remove the use_cudnn check here when it is supported by both backends.
+//   const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
+
+//   if constexpr (std::is_same_v || std::is_same_v){
+//     compute_t g = static_cast<compute_t>(gamma);
+//     if (zero_centered_gamma) {
+//       g += static_cast<compute_t>(1.f);
+//     }
+//     return g;
+//   } else {
+//     if (zero_centered_gamma_in_weight_dtype){
+//       compute_t g = static_cast<compute_t>(0.f);
+//       InputType gi = gamma;
+//       if (zero_centered_gamma) {
+//         gi = gi + static_cast<InputType>(1.f);
+//       }
+//       g = static_cast<compute_t>(gi);
+//       return g;
+//     } else {
+//       compute_t g = static_cast<compute_t>(gamma);
+//       if (zero_centered_gamma) {
+//         g += static_cast<compute_t>(1.f);
+//       }
+//       return g;
+//     }
+//   }
+// }
+
 template <typename InputType>
-inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) {
+inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype){
 
   using compute_t = float;
-
+
   // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
   // Remove the use_cudnn check here when it is supported by both backends.
   const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
@@ -80,6 +114,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
     }
     return g;
   } else {
+#ifdef __HIP_PLATFORM_AMD__
+    (void)zero_centered_gamma_in_weight_dtype;  // Parameter is unused on AMD platform
+#else
     if (zero_centered_gamma_in_weight_dtype){
       compute_t g = static_cast<compute_t>(0.f);
       InputType gi = gamma;
@@ -88,7 +125,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
       }
       g = static_cast<compute_t>(gi);
       return g;
-    } else {
+    } else
+#endif
+    {
       compute_t g = static_cast<compute_t>(gamma);
       if (zero_centered_gamma) {
         g += static_cast<compute_t>(1.f);

diff --git a/transformer_engine/common/normalization/kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h
index 78d9212de..97e47c686 100644
--- a/transformer_engine/common/normalization/kernel_traits.h
+++ b/transformer_engine/common/normalization/kernel_traits.h
@@ -67,6 +67,7 @@ struct Kernel_traits_finalize : public Base {
 
 template ,
           typename Base = Kernel_traits_base >
@@ -120,7 +121,7 @@ struct Kernel_traits : public Base {
   static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
   // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
 
-  using Stats = transformer_engine::Stats;
+  using Stats = StatsT;
 
   enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
 };

diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
index 09618c58d..b17d864f1 100644
--- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
+++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu
@@ -332,4 +332,4 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp32, fp32, fp32, fp3
 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4);
 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4);
 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4);
-REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
+REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4);
\ No newline at end of file

diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu
index 222994018..1f2045684 100644
--- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu
+++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu
@@ -9,7 +9,7 @@
 #include "../common.h"
 #include "../kernel_traits.h"
 #include "ln_fwd_kernels.cuh"
-
+#include
 using namespace transformer_engine::normalization;
 
 template &launch_params, const bool configure_params) {  // NOLINT(*)
   using Kernel_traits = Kernel_traits;
+                                      1, WARPS_M, WARPS_N, BYTES_PER_LDG,
+                                      transformer_engine::Stats_ge>;
   auto kernel = &ln_fwd_general_kernel<Kernel_traits>;
   auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; };
@@ -218,13 +219,13 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 1, 2, 8);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 2, 4);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
@@ -242,7 +243,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32,
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16);
@@ -272,7 +273,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, bf16, fp32,
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
@@ -290,7 +291,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, bf16, fp32,
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, fp32, 1, 1, 16, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16);
@@ -362,7 +363,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, bf16, fp32,
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 2, 1, 8, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16);
@@ -379,7 +380,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, f
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
+ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16);
@@ -401,26 +402,32 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, bf16, fp32,
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 2, 4, 4);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16);
 
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16);
+ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16);
-REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16);
+ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16);
 
+// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16);
+// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16);
+// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp32, fp32, fp16, fp32, 1, 4, 16);
+// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16);
+// REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 4096, fp32, fp32, bf16, fp32, 1, 4, 16);
+
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16);
+REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 20480, fp16, fp16, fp16, fp32, 1, 16, 16);
 
 #ifdef __HIP_PLATFORM_AMD__
 // ROCM uses TE normalization for e5m2
-
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16);
 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e5m2, fp32, 1, 4, 1, 16);

diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh
index 679fda32c..b328061b5 100644
--- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh
+++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh
@@ -215,110 +215,82 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
   const index_t gidn = (bidn * THREADS_PER_WARP + warp_n * params.ctas_per_row * THREADS_PER_WARP +
                         lane);  // Order threads by warp x cta x lane
 
-  // Objects for stats reductions
-  using Reducer = DynamicReducer<reduce_t, WARPS_M, WARPS_N>;
-  constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1;
-  __shared__ char smem_[SMEM_BYTES];
-  Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_);
-  Sum<reduce_t> sum;
-  const compute_t rn = 1.f / static_cast<compute_t>(params.cols);
-
-  // Load weights
-  Cvec gamma[LDGS];
-  Cvec beta[LDGS];
-#pragma unroll
-  for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
-       ++it, col += gdimn * NUM_ELTS) {
-    Wvec gamma_in, beta_in;
-    gamma_in.load_from_elts(params.gamma, col, params.cols - col);
-    beta_in.load_from_elts(params.beta, col, params.cols - col);
-    gamma_in.to(gamma[it]);
-    beta_in.to(beta[it]);
-  }
+  using Stats = typename Ktraits::Stats;
+  using stats_t = typename Stats::stats_t;
+  extern __shared__ char smem[];
+  Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem);
 
-  // fp8 factors
-  compute_t scale;
-  if (params.fp8_out) {
-    scale = *reinterpret_cast<compute_t *>(params.scale);
-  }
+  compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
+  compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
+
+  compute_t scale = params.fp8_out ? *reinterpret_cast<compute_t *>(params.scale) : 1.f;
   compute_t amax = 0;
 
   for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
-    const int row = cta_row + warp_m;
-
-    // Load input
-    Cvec x[LDGS];
-#pragma unroll
-    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
-         it++, col += gdimn * NUM_ELTS) {
-      Ivec x_in;
-      x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col);
-      x_in.to(x[it]);
-    }
+    int row = cta_row + warp_m;
+    if (row >= params.rows) continue;
-    // Compute mean
-    compute_t mu = 0.f;
-#pragma unroll
-    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
-         it++, col += gdimn * NUM_ELTS) {
-#pragma unroll
-      for (int jt = 0; jt < NUM_ELTS; jt++) {
-        mu += x[it].data.elt[jt];
-      }
-    }
-    mu = reducer.allreduce(mu, sum) * rn;
+    compute_t mu = 0.f, m2 = 0.f;
+    int count = 0;
-    // Compute variance
-    compute_t sqsigma = 0.f;
+    // Step 1: mean and m2
 #pragma unroll
-    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
-         it++, col += gdimn * NUM_ELTS) {
+    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS;
+         ++it, col += gdimn * NUM_ELTS) {
+      Ivec x_vec;
+      x_vec.load_from_elts(params.x, row * params.cols + col, params.cols - col);
 #pragma unroll
-      for (int jt = 0; jt < NUM_ELTS; jt++) {
+      for (int jt = 0; jt < NUM_ELTS; ++jt) {
         if (col + jt < params.cols) {
-          compute_t diff = x[it].data.elt[jt] - mu;
-          sqsigma += diff * diff;
+          compute_t x = compute_t(x_vec.data.elt[jt]);
+          count += 1;
+          compute_t delta = x - mu;
+          mu += delta / count;
+          m2 += delta * (x - mu);
        }
      }
    }
-    sqsigma = reducer.allreduce(sqsigma, sum) * rn;
-    compute_t rs = rsqrtf(sqsigma + params.epsilon);
-    // Write statistics
-    if (gidn == 0 && row < params.rows) {
-      compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
-      compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
+
+    Vec3<compute_t> stat = stats.reduce(Vec3<compute_t>(mu, m2, count));
+    mu = stat.x;
+    m2 = stat.y;
+
+    compute_t var = m2 / stat.z;
+    var = var < compute_t(0) ? compute_t(0) : var;
+    compute_t rs = rsqrtf(var + params.epsilon);
+
+    // compute_t rs = rsqrtf((m2 / stat.z) + params.epsilon);
+
+    if (gidn == 0) {
       mu_ptr[row] = mu;
       rs_ptr[row] = rs;
     }
 
-    // Compute output
-#pragma unroll
-    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols;
-         it++, col += gdimn * NUM_ELTS) {
-      // Compute output values
-      Cvec z;
+    // Step 2: store output (no need to store xf[])
 #pragma unroll
-      for (int jt = 0; jt < NUM_ELTS; jt++) {
-        compute_t y_ij = rs * (x[it].data.elt[jt] - mu);
-        compute_t g_ij = gamma[it].data.elt[jt];
-        if (params.zero_centered_gamma) {
-          g_ij += 1;
-        }
-        compute_t b_ij = beta[it].data.elt[jt];
-        z.data.elt[jt] = g_ij * y_ij + b_ij;
-      }
+    for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols;
+         ++it, col += gdimn * NUM_ELTS) {
+      Ivec x_vec;
+      x_vec.load_from_elts(params.x, row * params.cols + col, params.cols - col);
+      Wvec g_raw, b_raw;
+      g_raw.load_from_elts(params.gamma, col, params.cols - col);
+      b_raw.load_from_elts(params.beta, col, params.cols - col);
-      // Apply fp8 factors
-      if (params.fp8_out) {
+      Cvec z;
 #pragma unroll
-        for (int jt = 0; jt < NUM_ELTS; jt++) {
-          if (col + jt < params.cols) {
-            compute_t z_ij = z.data.elt[jt];
-            __builtin_assume(amax >= 0);
-            amax = fmaxf(amax, fabsf(z_ij));
-            z.data.elt[jt] = z_ij * scale;
+      for (int jt = 0; jt < NUM_ELTS; ++jt) {
+        if (col + jt < params.cols) {
+          compute_t x = compute_t(x_vec.data.elt[jt]);
+          compute_t norm = rs * (x - mu);
+          compute_t g = compute_t(g_raw.data.elt[jt]) + (params.zero_centered_gamma ? 1.f : 0.f);
+          compute_t b = compute_t(b_raw.data.elt[jt]);
+          compute_t val = g * norm + b;
+          if (params.fp8_out) {
+            amax = fmaxf(amax, fabsf(val));
+            val *= scale;
          }
+          z.data.elt[jt] = output_t(val);
        }
      }
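
Editor's note on the hunk above (not part of the patch): the kernel now declares extern __shared__ char smem[] instead of a statically sized __shared__ array, so whichever launcher configures it must pass the Stats type's SMEM_BYTES as the dynamic shared-memory size at launch, and opt in via cudaFuncSetAttribute when that exceeds the 48 KiB default. A minimal, self-contained sketch of that general pattern follows; DemoTraits and demo_kernel are purely illustrative names, not code from this patch.

#include <cuda_runtime.h>
#include <cstdio>

struct DemoTraits {                                   // stands in for Kernel_traits
  static constexpr int SMEM_BYTES_FWD = 64 * 1024;    // e.g. Stats::SMEM_BYTES
};

__global__ void demo_kernel() {
  extern __shared__ char smem[];  // same declaration style as the rewritten kernel
  smem[threadIdx.x] = 0;
}

int main() {
  constexpr int smem_bytes = DemoTraits::SMEM_BYTES_FWD;
  if (smem_bytes > 48 * 1024) {   // request more than the default per-block limit (supported devices only)
    cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
  }
  demo_kernel<<<1, 128, smem_bytes>>>();  // third launch parameter = dynamic shared memory
  std::printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}
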
diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh
index a524fbbd4..96f606c65 100644
--- a/transformer_engine/common/utils.cuh
+++ b/transformer_engine/common/utils.cuh
@@ -109,6 +109,16 @@ struct uint8 {
 template <int BYTES>
 struct BytesToType {};
 
+// Add 128-byte support: built from 16 uint8x8 structs (16 * 8 = 128 B)
+struct uint8x8 { uint8_t data[8]; };
+struct uint8x8x16 { uint8x8 v[16]; };
+
+template<>
+struct BytesToType<128> {
+  using Type = uint8x8x16;
+  static_assert(sizeof(Type) == 128, "BytesToType<128> must be 128 bytes");
+};
+
 template <>
 struct BytesToType<64> {
   using Type = uint16;
   static_assert(sizeof(Type) == 64);
@@ -151,7 +161,26 @@ struct BytesToType<1> {
   static_assert(sizeof(Type) == 1);
 };
 
+template <typename T, typename CountT = T>
+struct Vec3 {
+  T x, y;
+  CountT z;
+
+  __device__ Vec3() : x(0), y(0), z(0) {}
+  __device__ Vec3(T x_, T y_, CountT z_) : x(x_), y(y_), z(z_) {}
+
+  __device__ Vec3 &operator+=(const Vec3 &rhs) {
+    x += rhs.x;
+    y += rhs.y;
+    z += rhs.z;
+    return *this;
+  }
+};
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename T>
+struct TypeToVec3 {
+  using Type = Vec3<T>;
+};
 
 template
 struct TypeToVec2 {};
@@ -859,6 +888,182 @@ struct Stats {
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename T>
+inline __device__ void warp_chan_upd_dynamic_ge(Vec3<T> &stat, int num_active) {
+  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
+
+#pragma unroll
+  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {
+    T n_b = warp_shuffle_down(stat.z, step);
+    T m_b = warp_shuffle_down(stat.x, step);
+    T m2_b = warp_shuffle_down(stat.y, step);
+
+    T n_a = stat.z;
+    T m_a = stat.x;
+    T m2_a = stat.y;
+
+    if(n_b == 0){}
+    else
+    {
+      T n_ab = n_a + n_b;
+      T rn_ab = T(1.f) / n_ab;
+      T delta = m_a - m_b;
+
+      T m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
+      T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
+
+      stat = Vec3<T>(m_ab, m2_ab, n_ab);
+    }
+  }
+
+
+#ifdef __HIP_PLATFORM_AMD__
+  stat.x = __shfl(stat.x, 0, THREADS_PER_WARP);
+  stat.y = __shfl(stat.y, 0, THREADS_PER_WARP);
+  stat.z = __shfl(stat.z, 0, THREADS_PER_WARP);
+#else
+  stat.x = __shfl_sync(static_cast<uint32_t>(-1), stat.x, 0);
+  stat.y = __shfl_sync(static_cast<uint32_t>(-1), stat.y, 0);
+  stat.z = __shfl_sync(static_cast<uint32_t>(-1), stat.z, 0);
+#endif
+}
+
+template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Stats_ge;
+
+
+// Warp-level Stats (Welford-based)
+template <typename T, uint32_t WARPS_M>
+struct Stats_ge<T, 1, WARPS_M, 1> {
+  using stats_t = Vec3<T>;  // (mu, m2, count)
+  enum { SMEM_BYTES = 0 };
+
+  template <typename Params>
+  inline __device__ Stats_ge(const Params &params, uint32_t, uint32_t,
+                             uint32_t, uint32_t warp_n, uint32_t lane, void *)
+      : warp_n_(warp_n), lane_(lane) {}
+
+// template <int N>
+// inline __device__ stats_t compute(const T (&elts)[N], int valid_count) {
+//   T mean = 0, m2 = 0, count = 0;
+// #pragma unroll
+//   for (int i = 0; i < N; ++i) {
+//     if (i < valid_count) {
+//       T x = elts[i];
+//       count += 1;
+//       T delta = x - mean;
+//       mean += delta / count;
+//       T delta2 = x - mean;
+//       m2 += delta * delta2;
+//     }
+//   }
+//   return reduce(Vec3<T>(mean, m2, count));
+// }
+
+  inline __device__ stats_t reduce(Vec3<T> local_stat) {
+    warp_chan_upd_dynamic_ge(local_stat, THREADS_PER_WARP);
+    return local_stat;
+  }
+
+  uint32_t warp_n_, lane_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Block-level Stats (intra CTA warp reduction)
+template <typename T, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Stats_ge<T, 1, WARPS_M, WARPS_N> {
+  using stats_t = Vec3<T>;
+  using WarpStats = Stats_ge<T, 1, WARPS_M, 1>;
+
+  enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
+
+  template <typename Params>
+  inline __device__ Stats_ge(const Params &params, uint32_t bidm, uint32_t bidn,
+                             uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
+      : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) {
+    smem0_ = static_cast<stats_t *>(smem) + warp_m * WARPS_N;
+    smem1_ = smem0_ + WARPS_M * WARPS_N;
+  }
+
+  // template <int N>
+  // inline __device__ stats_t compute(const T (&elts)[N], int valid_count) {
+  //   Vec3<T> local = warp_stats_.compute(elts, valid_count);
+  //   return reduce(local);
+  // }
+
+  inline __device__ stats_t reduce(Vec3<T> local_stat) {
+    local_stat = warp_stats_.reduce(local_stat);
+
+    stats_t *smem = use0_ ? smem0_ : smem1_;
+    use0_ = !use0_;
+    if (warp_stats_.lane_ == 0) {
+      smem[warp_stats_.warp_n_] = local_stat;
+    }
+    __syncthreads();
+
+    stats_t result{Zeros<T>::get(), Zeros<T>::get(), Zeros<T>::get()};
+    if (warp_stats_.lane_ < WARPS_N) {
+      result = smem[warp_stats_.lane_];
+    }
+
+    warp_chan_upd_dynamic_ge(result, WARPS_N);
+    return result;
+  }
+
+  WarpStats warp_stats_;
+  stats_t *smem0_, *smem1_;
+  bool use0_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Inter-CTA Stats
+template <typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
+struct Stats_ge {
+  using stats_t = Vec3<T>;
+  using BlockStats = Stats_ge<T, 1, WARPS_M, WARPS_N>;
+
+  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
+
+  template <typename Params>
+  inline __device__ Stats_ge(const Params &params, uint32_t bidm, uint32_t bidn,
+                             uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
+      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
+        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
+        bidn_(bidn),
+        w0_(static_cast<stats_t *>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
+        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
+        warp_n_(warp_n), lane_(lane) {}
+
+  // template <int N>
+  // inline __device__ stats_t compute(const T (&elts)[N], int valid_count) {
+  //   Vec3<T> local = block_stats_.compute(elts, valid_count);
+  //   return reduce(local);
+  // }
+
+  inline __device__ stats_t reduce(Vec3<T> local_stat) {
+    local_stat = block_stats_.reduce(local_stat);
+    stats_t *workspace = (inter_cta_.phase_counter_ & 0x1) ? w1_ : w0_;
+    if (warp_n_ == 0 && lane_ == 0) {
+      workspace[bidn_] = local_stat;
+    }
+    inter_cta_.sync();
+
+    stats_t result{Zeros<T>::get(), Zeros<T>::get(), Zeros<T>::get()};
+    if (lane_ < CTAS_PER_ROW) {
+      result = workspace[lane_];
+    }
+
+    warp_chan_upd_dynamic_ge(result, CTAS_PER_ROW);
+    return result;
+  }
+
+  InterCTASync inter_cta_;
+  BlockStats block_stats_;
+  stats_t *w0_, *w1_;
+  int bidn_, warp_n_, lane_;
+};
 
 template <int num_elems>
 __device__ __forceinline__ float warp_reduce_max(const float m) {
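
Appendix (editor's sketch, not part of the patch): the Stats_ge path above relies on two standard identities — Welford's single-pass update of a (mean, M2, count) triple, and the Chan et al. pairwise combine that warp_chan_upd_dynamic_ge applies at each shuffle step. The host-side program below checks that combining per-partition Welford states reproduces a plain two-pass mean/variance; all names are illustrative and not taken from the Transformer Engine sources.

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

struct WelfordStat {  // mirrors the (mean, m2, count) triple carried by Vec3
  double mean = 0.0, m2 = 0.0, count = 0.0;
};

// One-element update (what each thread does per loaded element in Step 1).
void accumulate(WelfordStat &s, double x) {
  s.count += 1.0;
  double delta = x - s.mean;
  s.mean += delta / s.count;
  s.m2 += delta * (x - s.mean);
}

// Chan et al. combine (the update applied per shuffle step in warp_chan_upd_dynamic_ge).
WelfordStat combine(const WelfordStat &a, const WelfordStat &b) {
  if (b.count == 0.0) return a;
  WelfordStat r;
  r.count = a.count + b.count;
  double delta = a.mean - b.mean;
  r.mean = (a.count * a.mean + b.count * b.mean) / r.count;
  r.m2 = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
  return r;
}

int main() {
  std::vector<double> row(6144);
  for (size_t i = 0; i < row.size(); ++i) row[i] = std::sin(0.1 * i) * (i % 7);

  // Partition the row across 4 "threads", accumulate locally, then combine.
  WelfordStat parts[4];
  for (size_t i = 0; i < row.size(); ++i) accumulate(parts[i % 4], row[i]);
  WelfordStat total = parts[0];
  for (int t = 1; t < 4; ++t) total = combine(total, parts[t]);

  // Two-pass reference: mean, then biased variance (what rsigma uses before epsilon).
  double mean = 0.0;
  for (double x : row) mean += x;
  mean /= row.size();
  double var = 0.0;
  for (double x : row) var += (x - mean) * (x - mean);
  var /= row.size();

  assert(std::fabs(total.mean - mean) < 1e-9);
  assert(std::fabs(total.m2 / total.count - var) < 1e-9);
  std::printf("mean=%f var=%f\n", mean, var);
  return 0;
}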