diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index eb5006aea..5c19a79da 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -215,10 +215,20 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, } std::vector> test_cases = { - {71, 229}, - {29, 541}, - {768, 6144}, - {2048, 12288}, + // {71, 229}, + // {29, 541}, + // {768, 6144}, + //{2048, 12288}, + //{71,3571} + //{168,184} + // {768,1024}, + // {256,65536}, + // {128,6144}, + // {64,2304}, + // {229,541}, + // {71, 3571}, + {512,768} + //{76800,1600} }; } // namespace diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index 5f5603a7f..c949a0906 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -64,11 +64,45 @@ void compute_ref_stats(NormType norm_type, } } +// template +// inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { + +// using compute_t = float; + +// // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently +// // Remove the use_cudnn check here when it is supported by both backends. +// const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; + +// if constexpr (std::is_same_v || std::is_same_v){ +// compute_t g = static_cast(gamma); +// if (zero_centered_gamma) { +// g += static_cast(1.f); +// } +// return g; +// } else { +// if (zero_centered_gamma_in_weight_dtype){ +// compute_t g = static_cast(0.f); +// InputType gi = gamma; +// if (zero_centered_gamma) { +// gi = gi + static_cast(1.f); +// } +// g = static_cast(gi); +// return g; +// } else { +// compute_t g = static_cast(gamma); +// if (zero_centered_gamma) { +// g += static_cast(1.f); +// } +// return g; +// } +// } +// } + template -inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype){ using compute_t = float; - + // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently // Remove the use_cudnn check here when it is supported by both backends. 
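+// Note: with zero_centered_gamma the "+1" can be applied either after casting gamma to the
+// fp32 compute type (the default reference path below) or in the weight dtype itself (the
+// CuDNN path); for fp16/bf16 weights the two orderings may differ by one rounding step,
+// which is why the reference implementation keeps both variants.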
const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; @@ -80,6 +114,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const } return g; } else { +#ifdef __HIP_PLATFORM_AMD__ + (void)zero_centered_gamma_in_weight_dtype; // Parameter is unused on AMD platform +#else if (zero_centered_gamma_in_weight_dtype){ compute_t g = static_cast(0.f); InputType gi = gamma; @@ -88,7 +125,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const } g = static_cast(gi); return g; - } else { + } else +#endif + { compute_t g = static_cast(gamma); if (zero_centered_gamma) { g += static_cast(1.f); diff --git a/transformer_engine/common/normalization/kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h index 78d9212de..97e47c686 100644 --- a/transformer_engine/common/normalization/kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -67,6 +67,7 @@ struct Kernel_traits_finalize : public Base { template , typename Base = Kernel_traits_base > @@ -120,7 +121,7 @@ struct Kernel_traits : public Base { static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - using Stats = transformer_engine::Stats; + using Stats = StatsT; enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; }; diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index a13976e6f..1c5d95744 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -227,16 +227,15 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz const uint32_t c = bidn * THREADS_PER_WARP + lane; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; + const uint32_t COL_STRIDE = params.cols * THREADS_PER_WARP; + for (uint32_t col = c, col_out = c_out; col < params.cols; col += COL_STRIDE, col_out += COL_STRIDE / 2) { // Each thread sums over NUM_ELT columns. 
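+    // dgamma_part/dbeta_part hold one partial-reduction row per CTA of the main backward
+    // kernel (params.ctas_per_col rows of params.cols columns); the loop below folds those
+    // rows into a single dgamma/dbeta value for this column.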
Vec dbeta_local, dgamma_local; memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dbeta_local, 0, sizeof(dbeta_local)); for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { - index_t idx = row * Kernel_traits::COLS + col; - + index_t idx = row * params.cols + col; Vec dbeta_part, dgamma_part; dbeta_part.load_from(params.dbeta_part, idx); dgamma_part.load_from(params.dgamma_part, idx); @@ -391,7 +390,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne } Cvec dy[LDGS]; - Cvec y[LDGS]; + //Cvec y[LDGS]; compute_t mdy = 0.f; compute_t mdyy = 0.f; @@ -411,14 +410,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne const compute_t dz_ij = dz.data.elt[jt]; const compute_t dy_ij = g_ij * dz_ij; - y[it].data.elt[jt] = y_ij; + //y[it].data.elt[jt] = y_ij; dy[it].data.elt[jt] = dy_ij; mdy += dy_ij; mdyy += dy_ij * y_ij; - dz_sum[it].data.elt[jt] += dz_ij; - dzy_sum[it].data.elt[jt] += dz_ij * y_ij; + // dz_sum[it].data.elt[jt] += dz_ij; + // dzy_sum[it].data.elt[jt] += dz_ij * y_ij; } } @@ -432,11 +431,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; it++, col += gdimn * NUM_ELTS) { Ivec dx; + + Ivec x; + Ovec dz; + x.load_from_elts(params.x, row * params.cols + col, params.cols - col); + dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); + #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t dy_ij = dy[it].data.elt[jt]; - compute_t y_ij = y[it].data.elt[jt]; - dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy)); + const compute_t x_ij = x.data.elt[jt]; + const compute_t y_ij = rs * (x_ij - mu); + const compute_t dz_ij = dz.data.elt[jt]; + + dx.data.elt[jt] = rs * (dy[it].data.elt[jt] - (mdyy * y_ij + mdy)); + + dz_sum[it].data.elt[jt] += dz_ij; + dzy_sum[it].data.elt[jt] += dz_ij * y_ij; } dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); } diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 09618c58d..74d62468c 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -10,6 +10,7 @@ #include "../common.h" #include "../kernel_traits.h" #include "ln_bwd_kernels.cuh" +#include using namespace transformer_engine::normalization; @@ -39,7 +40,9 @@ static void launch_tuned_(LaunchParams &launch_params, launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); return; } - + // std::cout<<"bwd ctas_per_row:"<< CTAS_PER_ROW<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -106,7 +109,10 @@ static void launch_general_(LaunchParams &launch_params, launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t); return; } - + // std::cout<<"bwd cols:"< &launch_params, reinterpret_cast(¶ms_), 0, stream); } - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - &ln_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, 
WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); + // Decide which finalize kernel to launch based on column alignment + const bool cols_aligned = (cols % 32 == 0); + + if (cols_aligned) { + // Launch tuned finalize kernel + using Kernel_traits_f = Kernel_traits_finalize; + + auto kernel_f = &ln_bwd_finalize_tuned_kernel; + + + kernel_f<<>>( + launch_params.params); + + } else { + // Launch general finalize kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL) / + sizeof(compute_t); + + auto kernel_final = &ln_bwd_finalize_general_kernel; + + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + + kernel_final<<>>(launch_params.params); + } } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -157,7 +186,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 8); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); @@ -165,11 +194,11 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);// REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 2, 1, 1, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); @@ -223,7 +252,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); 
+REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 1, 1, 16, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); @@ -295,7 +324,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 4, 1, 16, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); @@ -317,13 +346,13 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 2, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 32, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 222994018..507bfcab0 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -9,7 +9,7 @@ #include "../common.h" #include "../kernel_traits.h" #include "ln_fwd_kernels.cuh" - +#include using namespace transformer_engine::normalization; template &launch_params, #endif return; } - + std::cout<<"tuned fwd ctas_per_row:"<< CTAS_PER_ROW<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -73,7 +76,8 @@ template &launch_params, const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; + 1, WARPS_M, WARPS_N, BYTES_PER_LDG, + transformer_engine::Stats_ge>; auto kernel = &ln_fwd_general_kernel; auto ceil_div = [](int x, int y) -> int { 
return (x + y - 1) / y; }; @@ -103,7 +107,9 @@ static void launch_general_(LaunchParams &launch_params, #endif return; } - + // std::cout<<"warps_m:"<; - constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; - __shared__ char smem_[SMEM_BYTES]; - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); - Sum sum; - const compute_t rn = 1.f / static_cast(params.cols); - - // Load weights - Cvec gamma[LDGS]; - Cvec beta[LDGS]; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; - ++it, col += gdimn * NUM_ELTS) { - Wvec gamma_in, beta_in; - gamma_in.load_from_elts(params.gamma, col, params.cols - col); - beta_in.load_from_elts(params.beta, col, params.cols - col); - gamma_in.to(gamma[it]); - beta_in.to(beta[it]); - } + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + extern __shared__ char smem[]; + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem); - // fp8 factors - compute_t scale; - if (params.fp8_out) { - scale = *reinterpret_cast(params.scale); - } + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + compute_t scale = params.fp8_out ? *reinterpret_cast(params.scale) : 1.f; compute_t amax = 0; for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { - const int row = cta_row + warp_m; - - // Load input - Cvec x[LDGS]; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - Ivec x_in; - x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col); - x_in.to(x[it]); - } + int row = cta_row + warp_m; + if (row >= params.rows) continue; - // Compute mean - compute_t mu = 0.f; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - mu += x[it].data.elt[jt]; - } - } - mu = reducer.allreduce(mu, sum) * rn; + compute_t mu = 0.f, m2 = 0.f; + int count = 0; - // Compute variance - compute_t sqsigma = 0.f; + // Step 1: mean and m2 #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS; + ++it, col += gdimn * NUM_ELTS) { + Ivec x_vec; + x_vec.load_from_elts(params.x, row * params.cols + col, params.cols - col); #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { + for (int jt = 0; jt < NUM_ELTS; ++jt) { if (col + jt < params.cols) { - compute_t diff = x[it].data.elt[jt] - mu; - sqsigma += diff * diff; + compute_t x = compute_t(x_vec.data.elt[jt]); + count += 1; + compute_t delta = x - mu; + mu += delta / count; + m2 += delta * (x - mu); } } } - sqsigma = reducer.allreduce(sqsigma, sum) * rn; - compute_t rs = rsqrtf(sqsigma + params.epsilon); - // Write statistics - if (gidn == 0 && row < params.rows) { - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); + + Vec3 stat = stats.reduce(Vec3(mu, m2, count)); + mu = stat.x; + m2 = stat.y; + + compute_t var = m2 / stat.z; + var = var < compute_t(0) ? 
compute_t(0) : var; + compute_t rs = rsqrtf(var + params.epsilon); + + // compute_t rs = rsqrtf((m2 / stat.z) + params.epsilon); + + if (gidn == 0) { mu_ptr[row] = mu; rs_ptr[row] = rs; } -// Compute output -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - // Compute output values - Cvec z; + // Step 2: store output (no need to store xf[]) #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t y_ij = rs * (x[it].data.elt[jt] - mu); - compute_t g_ij = gamma[it].data.elt[jt]; - if (params.zero_centered_gamma) { - g_ij += 1; - } - compute_t b_ij = beta[it].data.elt[jt]; - z.data.elt[jt] = g_ij * y_ij + b_ij; - } + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + ++it, col += gdimn * NUM_ELTS) { + Ivec x_vec; + x_vec.load_from_elts(params.x, row * params.cols + col, params.cols - col); + Wvec g_raw, b_raw; + g_raw.load_from_elts(params.gamma, col, params.cols - col); + b_raw.load_from_elts(params.beta, col, params.cols - col); - // Apply fp8 factors - if (params.fp8_out) { + Cvec z; #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - if (col + jt < params.cols) { - compute_t z_ij = z.data.elt[jt]; - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; + for (int jt = 0; jt < NUM_ELTS; ++jt) { + if (col + jt < params.cols) { + compute_t x = compute_t(x_vec.data.elt[jt]); + compute_t norm = rs * (x - mu); + compute_t g = compute_t(g_raw.data.elt[jt]) + (params.zero_centered_gamma ? 1.f : 0.f); + compute_t b = compute_t(b_raw.data.elt[jt]); + compute_t val = g * norm + b; + if (params.fp8_out) { + amax = fmaxf(amax, fabsf(val)); + val *= scale; } + z.data.elt[jt] = output_t(val); } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index a524fbbd4..96f606c65 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -109,6 +109,16 @@ struct uint8 { template struct BytesToType {}; +// 新增对 128 字节的支持:以 16 个 uint8x8 为例(16*8=128B) +struct uint8x8 { uint8_t data[8]; }; +struct uint8x8x16 { uint8x8 v[16]; }; + +template<> +struct BytesToType<128> { + using Type = uint8x8x16; + static_assert(sizeof(Type) == 128, "BytesToType<128> must be 128 bytes"); +}; + template <> struct BytesToType<64> { using Type = uint16; @@ -151,7 +161,26 @@ struct BytesToType<1> { static_assert(sizeof(Type) == 1); }; +template +struct Vec3 { + T x, y; + CountT z; + + __device__ Vec3() : x(0), y(0), z(0) {} + __device__ Vec3(T x_, T y_, CountT z_) : x(x_), y(y_), z(z_) {} + + __device__ Vec3 &operator+=(const Vec3 &rhs) { + x += rhs.x; + y += rhs.y; + z += rhs.z; + return *this; + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct TypeToVec3 { + using Type = Vec3; +}; template struct TypeToVec2 {}; @@ -859,6 +888,182 @@ struct Stats { }; //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void warp_chan_upd_dynamic_ge(Vec3 &stat, int num_active) { + int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + +#pragma unroll + for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { + T n_b = warp_shuffle_down(stat.z, step); + T m_b = warp_shuffle_down(stat.x, step); + T m2_b = warp_shuffle_down(stat.y, step); + + T n_a = stat.z; + T m_a = stat.x; + T m2_a = stat.y; + + if(n_b == 0){} + else + { + T 
n_ab = n_a + n_b; + T rn_ab = T(1.f) / n_ab; + T delta = m_a - m_b; + + T m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + + stat = Vec3(m_ab, m2_ab, n_ab); + } + } + + +#ifdef __HIP_PLATFORM_AMD__ + stat.x = __shfl(stat.x, 0, THREADS_PER_WARP); + stat.y = __shfl(stat.y, 0, THREADS_PER_WARP); + stat.z = __shfl(stat.z, 0, THREADS_PER_WARP); +#else + stat.x = __shfl_sync(static_cast(-1), stat.x, 0); + stat.y = __shfl_sync(static_cast(-1), stat.y, 0); + stat.z = __shfl_sync(static_cast(-1), stat.z, 0); +#endif +} + +template +struct Stats_ge; + + +// Warp-level Stats (Welford-based) +template +struct Stats_ge { + using stats_t = Vec3; // (mu, m2, count) + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats_ge(const Params ¶ms, uint32_t, uint32_t, + uint32_t, uint32_t warp_n, uint32_t lane, void *) + : warp_n_(warp_n), lane_(lane) {} + +// template +// inline __device__ stats_t compute(const T (&elts)[N], int valid_count) { +// T mean = 0, m2 = 0, count = 0; +// #pragma unroll +// for (int i = 0; i < N; ++i) { +// if (i < valid_count) { +// T x = elts[i]; +// count += 1; +// T delta = x - mean; +// mean += delta / count; +// T delta2 = x - mean; +// m2 += delta * delta2; +// } +// } +// return reduce(Vec3(mean, m2, count)); +// } + + inline __device__ stats_t reduce(Vec3 local_stat) { + warp_chan_upd_dynamic_ge(local_stat, THREADS_PER_WARP); + return local_stat; + } + + uint32_t warp_n_, lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Block-level Stats (intra CTA warp reduction) +template +struct Stats_ge { + using stats_t = Vec3; + using WarpStats = Stats_ge; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats_ge(const Params ¶ms, uint32_t bidm, uint32_t bidn, + uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + // template + // inline __device__ stats_t compute(const T (&elts)[N], int valid_count) { + // Vec3 local = warp_stats_.compute(elts, valid_count); + // return reduce(local); + // } + + inline __device__ stats_t reduce(Vec3 local_stat) { + local_stat=warp_stats_.reduce(local_stat); + + stats_t *smem = use0_ ? 
smem0_ : smem1_;
+    use0_ = !use0_;
+    if (warp_stats_.lane_ == 0) {
+      smem[warp_stats_.warp_n_] = local_stat;
+    }
+    __syncthreads();
+
+    stats_t result{Zeros::get(), Zeros::get(), Zeros::get()};
+    if (warp_stats_.lane_ < WARPS_N) {
+      result = smem[warp_stats_.lane_];
+    }
+
+    warp_chan_upd_dynamic_ge(result, WARPS_N);
+    return result;
+  }
+
+  WarpStats warp_stats_;
+  stats_t *smem0_, *smem1_;
+  bool use0_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Inter-CTA Stats
+template 
+struct Stats_ge {
+  using stats_t = Vec3;
+  using BlockStats = Stats_ge;
+
+  enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
+
+  template 
+  inline __device__ Stats_ge(const Params &params, uint32_t bidm, uint32_t bidn,
+                             uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem)
+      : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW),
+        block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem),
+        bidn_(bidn),
+        w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW),
+        w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW),
+        warp_n_(warp_n), lane_(lane) {}
+
+  // template 
+  // inline __device__ stats_t compute(const T (&elts)[N], int valid_count) {
+  //   Vec3 local = block_stats_.compute(elts, valid_count);
+  //   return reduce(local);
+  // }
+
+  inline __device__ stats_t reduce(Vec3 local_stat) {
+    local_stat = block_stats_.reduce(local_stat);
+    stats_t *workspace = (inter_cta_.phase_counter_ & 0x1) ? w1_ : w0_;
+    if (warp_n_ == 0 && lane_ == 0) {
+      workspace[bidn_] = local_stat;
+    }
+    inter_cta_.sync();
+
+    stats_t result{Zeros::get(), Zeros::get(), Zeros::get()};
+    if (lane_ < CTAS_PER_ROW) {
+      result = workspace[lane_];
+    }
+
+    warp_chan_upd_dynamic_ge(result, CTAS_PER_ROW);
+    return result;
+  }
+
+  InterCTASync inter_cta_;
+  BlockStats block_stats_;
+  stats_t *w0_, *w1_;
+  int bidn_, warp_n_, lane_;
+};
 template 
 __device__ __forceinline__ float warp_reduce_max(const float m) {
diff --git a/tuning_tools/abs_do_fwd.sh b/tuning_tools/abs_do_fwd.sh
new file mode 100644
index 000000000..55eb2cdd8
--- /dev/null
+++ b/tuning_tools/abs_do_fwd.sh
@@ -0,0 +1,34 @@
+set -euo pipefail
+
+cd ..
+pip install .
+
+# wait for the build to finish
+
+cd tests/cpp/build/
+rm -rf *
+cmake ..
+make
+
+# Run rocprof, printing its output to the screen while also saving it to a temporary log file
+ROCLOG=/tmp/rocprof.log
+rocprof --stats ./operator/test_operator | tee "$ROCLOG"
+
+# Extract the two shape numbers from the rocprof output, e.g. Dimension(2048,12288)
+shape_line=$(grep -m 1 'OperatorTest/NormTestSuite.TestNorm/LayerNorm_' "$ROCLOG")
+dim1=$(awk -F'X' '{print $3}' <<<"$shape_line")
+dim2=$(awk -F'X' '{print $4}' <<<"$shape_line")
+
+# Then extract ctas_per_row, warps_m, warps_n, bytes_per_load
+ctas=$(grep -m 1 'ctas_per_row:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}')
+wm=$(grep -m 1 'warps_m:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}')
+wn=$(grep -m 1 'warps_n:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}')
+bpl=$(grep -m 1 'bytes_per_load:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}')
+
+# Assemble the file name and create an empty marker file
+filename="${dim1}_${dim2}_${ctas}_${wm}_${wn}_${bpl}"
+# filename="${dim1}_${dim2}_${wm}_${wn}_${bpl}"
+touch "/home/tuned_fwd/768/f16f16/$filename"
+echo "→ Created file $filename"
+
+python /home/TransformerEngine/tuning_tools/abs_readall.py "/home/tuned_fwd/768/f16f16/${filename}"
\ No newline at end of file
diff --git a/tuning_tools/abs_readall.py b/tuning_tools/abs_readall.py
new file mode 100644
index 000000000..b8cff57c3
--- /dev/null
+++ b/tuning_tools/abs_readall.py
@@ -0,0 +1,66 @@
+import json
+import os
+import sys
+import argparse
+
+def extract_and_process_durations(input_file, output_file, kernel_keywords, num_warmup, num_iteration):
+    with open(input_file, "r") as f:
+        data = json.load(f)
+
+    keyword_to_durations = {k: [] for k in kernel_keywords}
+
+    for event in data.get("traceEvents", []):
+        args = event.get("args", {})
+        kernel_name = args.get("KernelName", "")
+        duration = args.get("DurationNs")
+
+        if duration is not None:
+            for keyword in kernel_keywords:
+                if keyword in kernel_name:
+                    keyword_to_durations[keyword].append(int(duration))
+                    break  # avoid counting the same event under more than one keyword
+
+    output_lines = []
+
+    for keyword in kernel_keywords:
+        durations = keyword_to_durations[keyword]
+        output_lines.append(f"== {keyword} ==")
+
+        if not durations:
+            output_lines.append("[no data]")
+            continue
+
+        i = 0
+        while i < len(durations):
+            i += num_warmup  # skip the warmup runs
+            batch = []
+            for _ in range(num_iteration):
+                if i < len(durations):
+                    batch.append(durations[i])
+                    i += 1
+            if batch:
+                avg = sum(batch) / len(batch)
+                output_lines.append(f"{avg:.2f}")
+        output_lines.append("")  # blank line between kernel sections
+
+    with open(output_file, "w") as f:
+        f.write("\n".join(output_lines))
+
+    print(f"Wrote the average duration of every kernel to {output_file}")
+
+input_json = "/home/TransformerEngine/tests/cpp/build/results.json"
+if len(sys.argv) > 1:
+    output_txt = sys.argv[1]
+else:
+    output_txt = "/home/bwdprofiles/tmp/heyi.txt"
+
+kernel_keywords = [
+    "ln_fwd_",
+    "ln_bwd_general_kernel",
+    "ln_bwd_finalize"
+]
+
+num_warmup = 5
+num_iteration = 10
+
+extract_and_process_durations(input_json, output_txt, kernel_keywords, num_warmup, num_iteration)
\ No newline at end of file
diff --git a/tuning_tools/find_fast.py b/tuning_tools/find_fast.py
new file mode 100644
index 000000000..5ccf12413
--- /dev/null
+++ b/tuning_tools/find_fast.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+"""
+Script: walk every file in the given directory and, for each file, parse
+- the summed time of the ln_fwd_ kernel
+- the summed times of ln_bwd_tuned_kernel and ln_bwd_finalize, merged into one value
+Then find, across all files, the minimum ln_fwd_ and merged bwd sums, together with the
+files they come from, and print the result.
+"""
+import os
+import sys
+import re
+
+def parse_file(filepath):
+    """Parse one file; return a dict: 'ln_fwd_' -> sum, 'ln_bwd_total' -> combined sum."""
+    sums = {}
+    current = None
+    times = []
+    header_pat = re.compile(r"^==\s*(.+?)\s*==$")
+    with open(filepath, 'r', encoding='utf-8') as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                if current and times:
+                    sums[current] = sum(times)
+                    times = []
+                continue
+            m = header_pat.match(line)
+            if m:
+                current = m.group(1)
+                times = []
+            else:
+                try:
+                    times.append(float(line))
+                except ValueError:
+                    pass
+    if current and times:
+        sums[current] = sum(times)
+    # merge the two bwd kernels into a single value
+    bwd_sum = sums.get('ln_bwd_tuned_kernel', 0) + sums.get('ln_bwd_finalize', 0)
+    # return only the two entries
+    return {
+        'ln_fwd_': sums.get('ln_fwd_', float('inf')),
+        'ln_bwd_total': bwd_sum
+    }
+
+def find_minimums(dirpath):
+    """Walk the files in the directory; return a dict: key -> (min_sum, filepath)."""
+    results = {}
+    for name in os.listdir(dirpath):
+        fp = os.path.join(dirpath, name)
+        if not os.path.isfile(fp):
+            continue
+        file_sums = parse_file(fp)
+        for key, val in file_sums.items():
+            if key not in results or val < results[key][0]:
+                results[key] = (val, fp)
+    return results
+
+def main():
+    if len(sys.argv) != 2:
+        print(f"Usage: {sys.argv[0]} <dir>")
+        sys.exit(1)
+    d = sys.argv[1]
+    if not os.path.isdir(d):
+        print(f"Error: {d} is not a directory")
+        sys.exit(1)
+    mins = find_minimums(d)
+    if not mins:
+        print("No valid files found.")
+        return
+    print("Minimum total times:")
+    for key in ['ln_fwd_', 'ln_bwd_total']:
+        val, fp = mins.get(key, (None, None))
+        if val is None:
+            print(f"- {key}: no data")
+        else:
+            print(f"- {key}: {val:.2f} file: {fp}")
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/tuning_tools/launcher_ge.py b/tuning_tools/launcher_ge.py
new file mode 100644
index 000000000..6a0f00357
--- /dev/null
+++ b/tuning_tools/launcher_ge.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+"""
+Script: for the given HIDDEN_SIZE/WTYPE/ITYPE/OTYPE/CTYPE, sweep combinations of the
+CTAS_PER_ROW, WARPS_M, WARPS_N and BYTES_PER_LDG parameters of the REGISTER_NORM_LAUNCHER
+macro in ln_fwd_cuda_kernel.cu.
+Only lines matching the prefix are rewritten; all other registration macros are left unchanged.
+"""
+import re, os
+import sys
+import subprocess
+
+# Path of the source file to rewrite
+SOURCE_FILE = '/home/TransformerEngine/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu'
+RESULTS_DIR = '/home/tuned_fwd/768/f16f16'
+# Hidden sizes to sweep
+hidden_sizes = [768]
+# Prefix template; the hidden_size is filled in via format()
+PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, {hs}, fp16, fp16, fp16, fp32,"
+# PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, {hs}, fp16, fp16, fp16, fp32,"
+
+# # Parameter combinations to test
+# ctas_per_row_list = [ 2]
+# warps_m_list = [1]
+# warps_n_list = [8]
+# bytes_per_ldg_list= [4,8,16,32]
+
+ctas_per_row_list = [1]
+warps_m_list = [2,1]
+warps_n_list = [2,4,8]
+bytes_per_ldg_list= [8,16]
+# Batch replacement
+for hs in hidden_sizes:
+    # build the prefix for this hidden_size
+    prefix = PREFIX_TMPL.format(hs=hs)
+    for ctas in ctas_per_row_list:
+        for wm in warps_m_list:
+            for wn in warps_n_list:
+                for bpl in bytes_per_ldg_list:
+                    if wm * wn < 2:
+                        continue
+                    lhs = hs // (bpl // 2)
+                    rhs = ctas * wn * 32 * (lhs // (ctas * wn * 32))
+                    # rhs = 1 * wn * 32 * (lhs // (1 * wn * 32))
+                    if lhs != rhs:
+                        continue
+                    # if not (ctas == 1 or wm == 1):
+                    #     continue
+                    # build the new, complete macro invocation line
+                    new_line = f"{prefix} {ctas}, {wm}, {wn}, {bpl});"  # bwd
+                    # read the source file
+                    with open(SOURCE_FILE, 'r', encoding='utf-8') as f:
+                        lines = f.readlines()
+                    # write it back, replacing lines that match the prefix
+                    with open(SOURCE_FILE, 'w', encoding='utf-8') as f:
+                        for line in lines:
+                            if line.strip().startswith(prefix):
+                                f.write(new_line + '\n')
+                            else:
+                                f.write(line)
+                    print(f"Updated {SOURCE_FILE} for hidden_size={hs} with: WARPS_M={wm}, WARPS_N={wn}, BYTES_PER_LDG={bpl}")
+
+                    result = subprocess.run(['bash', './abs_do_fwd.sh'])
+                    if result.returncode != 0:
+                        print(f"Warning: abs_do.sh failed with exit code {result.returncode}")
+
+
+proc = subprocess.run(
+    ['python3', 'find_fast.py', RESULTS_DIR],
+    stdout=subprocess.PIPE,
+    text=True,
+    check=True
+)
+
+best_fp = None
+for line in proc.stdout.splitlines():
+    if line.startswith('- ln_fwd_'):
+        # parse "file: /path/to/2048_12288_1_1_8_32"
+        parts = line.split('file:')
+        if len(parts) == 2:
+            best_fp = parts[1].strip()
+        break
+
+if not best_fp:
+    print("Error: no best ln_fwd_ result found, exiting.")
+    sys.exit(1)
+
+best_name = os.path.basename(best_fp)  # e.g. "2048_12288_1_1_8_32"
+print("Best ln_fwd file:", best_name)
+
+# -- 3. Split the parameters out of the file name and rewrite the macro line in the .cu file -- #
+tokens = best_name.split('_')
+if len(tokens) != 6:
+    print("Error: could not parse parameters from the file name:", best_name)
+    sys.exit(1)
+
+hs2, n2, ctas2, wm2, wn2, bpl2 = tokens
+prefix = PREFIX_TMPL.format(hs=hs2)
+new_line = f"{prefix} {ctas2}, {wm2}, {wn2}, {bpl2});"
+
+# read the source file and replace every line matching the prefix
+with open(SOURCE_FILE, 'r', encoding='utf-8') as f:
+    lines = f.readlines()
+with open(SOURCE_FILE, 'w', encoding='utf-8') as f:
+    for line in lines:
+        if line.strip().startswith(prefix):
+            f.write(new_line + '\n')
+        else:
+            f.write(line)
+
+print("Replaced all matching lines with the best combination:")
+print("  ", new_line)
\ No newline at end of file