From 799f77679eb3e27d8d9d182cdb168ce4020a3b2d Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Jun 2025 09:11:07 +0000 Subject: [PATCH 01/10] initial commit --- abs_do.sh | 34 +++++++++++ launcher_loop.py | 58 +++++++++++++++++++ tests/cpp/CMakeLists.txt | 2 +- tests/cpp/do.sh | 6 ++ tests/cpp/operator/CMakeLists.txt | 30 +++++----- tests/cpp/operator/test_normalization.cu | 43 +++++++++++--- .../layernorm/ln_fwd_cuda_kernel.cu | 13 +++-- 7 files changed, 158 insertions(+), 28 deletions(-) create mode 100644 abs_do.sh create mode 100644 launcher_loop.py create mode 100755 tests/cpp/do.sh diff --git a/abs_do.sh b/abs_do.sh new file mode 100644 index 000000000..f70362577 --- /dev/null +++ b/abs_do.sh @@ -0,0 +1,34 @@ +set -euo pipefail + +pip install . + +#等待编译结束 + +cd tests/cpp/build/ +rm -rf * +cmake .. +make + +# 运行 rocprof 并把输出既打印到屏幕又保存到临时文件 +ROCLOG=/tmp/rocprof.log +rocprof --stats ./operator/test_operator | tee "$ROCLOG" + +# 从 rocprof 输出中提取两组数字 Dimension(2048,12288) +shape_line=$(grep -m 1 'OperatorTest/NormTestSuite.TestNorm/LayerNorm_' "$ROCLOG") +dim1=$(awk -F'X' '{print $3}' <<<"$shape_line") +dim2=$(awk -F'X' '{print $4}' <<<"$shape_line") + +# 再提取 ctas_per_row, warps_n, bytes_per_load +ctas=$(grep -m 1 'ctas_per_row:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') +wm=$(grep -m 1 'warps_m:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') +wn=$(grep -m 1 'warps_n:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') +bpl=$(grep -m 1 'bytes_per_load:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') + +# 拼成文件名并创建空文件 +filename="${dim1}_${dim2}_${ctas}_${wm}_${wn}_${bpl}" +touch "/home/tuned_fwd/1024/f16f16/$filename" +echo "→ Created file $filename" + +python /home/tools/abs_readall.py "/home/tuned_fwd/1024/f16f16/${filename}" + + diff --git a/launcher_loop.py b/launcher_loop.py new file mode 100644 index 000000000..80366aa87 --- /dev/null +++ b/launcher_loop.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +""" +脚本:针对指定 HIDDEN_SIZE/WTYPE/ITYPE/OTYPE/CTYPE,在 ln_fwd_cuda_kernel.cu 中批量替换 REGISTER_NORM_LAUNCHER 宏的 +CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 四个参数组合。 +只替换匹配该前缀的行,保留其他注册宏不变。 +""" +import re +import subprocess + +# 需要替换的源文件路径 +SOURCE_FILE = '/home/TransformerEngine/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu' + +# 隐藏大小列表 +hidden_sizes = [1024] +# 构造前缀模板,format 时填入 hidden_size +PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, {hs}, fp16, fp16, fp16, fp32," + +# # 要测试的参数组合 +# ctas_per_row_list = [ 2] +# warps_m_list = [1] +# warps_n_list = [4] +# bytes_per_ldg_list= [16,64] + +ctas_per_row_list = [1,2] +warps_m_list = [1,4] +warps_n_list = [1,4] +bytes_per_ldg_list= [4,8,16,32] +# 批量替换 +for hs in hidden_sizes: + # 每个 hidden_size 生成对应前缀 + prefix = PREFIX_TMPL.format(hs=hs) + for ctas in ctas_per_row_list: + for wm in warps_m_list: + for wn in warps_n_list: + for bpl in bytes_per_ldg_list: + if wm * wn != 4: + continue + lhs = hs // (bpl // 2) + rhs = ctas * wn * 32 * (lhs // (ctas * wn * 32)) + if lhs != rhs: + continue + if not (ctas == 1 or wm == 1): + continue + # 构造新的完整宏调用行 + new_line = f"{prefix} {ctas}, {wm}, {wn}, {bpl});"#bwd + # 读取源文件 + with open(SOURCE_FILE, 'r', encoding='utf-8') as f: + lines = f.readlines() + # 写回时替换匹配前缀的行 + with open(SOURCE_FILE, 'w', encoding='utf-8') as f: + for line in lines: + if line.strip().startswith(prefix): + f.write(new_line + '\n') + else: + f.write(line) + print(f"Updated {SOURCE_FILE} for hidden_size={hs} with: CTAS_PER_ROW={ctas}, WARPS_M={wm}, WARPS_N={wn}, BYTES_PER_LDG={bpl}") + 
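+    # abs_do.sh rebuilds the library, profiles the operator test with rocprof, and records
+    # the timing of the parameter combination substituted above under /home/tuned_fwd/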
+ subprocess.run(['bash', './abs_do.sh'], check=True) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 1f510f34e..92a6c4538 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -96,4 +96,4 @@ else() endif() add_subdirectory(operator) -add_subdirectory(util) +# add_subdirectory(util) diff --git a/tests/cpp/do.sh b/tests/cpp/do.sh new file mode 100755 index 000000000..ab548c876 --- /dev/null +++ b/tests/cpp/do.sh @@ -0,0 +1,6 @@ +rm -rf * + +cmake .. +make + +rocprof --stats ./operator/test_operator diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 728d37a17..3d1077914 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -5,22 +5,22 @@ # See LICENSE for license information. list(APPEND test_cuda_sources - test_qdq.cu - test_cast_transpose.cu - test_transpose.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_causal_softmax.cu + # test_qdq.cu + # test_cast_transpose.cu + # test_transpose.cu + # test_cast_transpose_dbias.cu + # test_cast_transpose_dbias_dgelu.cu + # test_cast_transpose_dgeglu.cu + # test_act.cu + test_normalization.cu + # test_multi_cast_transpose.cu + # test_multi_padding.cu + # test_causal_softmax.cu ../test_common.cu) -if(USE_ROCM) - list(APPEND test_cuda_sources - test_cublaslt_gemm.cu) -endif() +# if(USE_ROCM) +# list(APPEND test_cuda_sources +# test_cublaslt_gemm.cu) +# endif() if(USE_CUDA) add_executable(test_operator ${test_cuda_sources}) diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 616da6f22..e8b8a59c0 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -242,6 +242,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); + for(int i=0;i<5;i++) + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + for(int i=0;i<10;i++) nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -252,6 +257,13 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); + for(int i=0;i<5;i++) + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + for(int i=0;i<10;i++) nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -340,10 +352,27 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, } std::vector> test_cases = { - {71, 229}, - {29, 541}, - {768, 6144}, - {2048, 12288}, + // {71, 229}, + // {29, 541}, + // {768, 6144}, + //{2048, 12288}, + // {6800,928}, + // {16000,1472} + // {46000,2240}, + // {8200,2011} + // {14000,1485}, + // 
{34003,3957} + //{71,3571} + //{168,184} + {768,1024}, + // {256,65536}, + // {128,6144}, + // {64,2304}, + // {229,541}, + // {71, 3571}, + // {29, 17389}, + //{6048,16320} + //{76800,1600} }; } // namespace @@ -382,9 +411,9 @@ INSTANTIATE_TEST_SUITE_P( #else ::testing::Values(false), //TODO: enabling tests for cudnn backend #endif - ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), + ::testing::Values(NormType::LayerNorm), + ::testing::Values(DType::kFloat16), + ::testing::Values(DType::kFloat16), ::testing::ValuesIn(test_cases), ::testing::Values(false, true)), [](const testing::TestParamInfo& info) { diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 42a4874f4..62e8e31d9 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -9,7 +9,7 @@ #include "../common.h" #include "../kernel_traits.h" #include "ln_fwd_kernels.cuh" - +#include using namespace transformer_engine::normalization; template &launch_params, } return; } - + std::cout<<"ctas_per_row:"<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -204,7 +207,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1 REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 2, 1, 4, 8); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); @@ -222,7 +225,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); @@ -252,7 +255,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 2, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, 1, 1, 
4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); From d39a9e2138e60d04110797eada852474bfccd6dc Mon Sep 17 00:00:00 2001 From: root Date: Thu, 3 Jul 2025 07:09:52 +0000 Subject: [PATCH 02/10] final pr v1 --- abs_do.sh => abs_do_fwd.sh | 6 +- find_fast.py | 79 +++++++ launcher_ge.py | 109 ++++++++++ launcher_loop.py | 58 ----- tests/cpp/operator/test_normalization.cu | 13 +- .../common/normalization/kernel_traits.h | 3 +- .../layernorm/ln_bwd_kernels.cuh | 32 ++- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 69 ++++-- .../layernorm/ln_fwd_cuda_kernel.cu | 39 ++-- .../layernorm/ln_fwd_kernels.cuh | 132 +++++------- transformer_engine/common/utils.cuh | 200 ++++++++++++++++++ 11 files changed, 542 insertions(+), 198 deletions(-) rename abs_do.sh => abs_do_fwd.sh (85%) create mode 100644 find_fast.py create mode 100644 launcher_ge.py delete mode 100644 launcher_loop.py diff --git a/abs_do.sh b/abs_do_fwd.sh similarity index 85% rename from abs_do.sh rename to abs_do_fwd.sh index f70362577..ad16266f2 100644 --- a/abs_do.sh +++ b/abs_do_fwd.sh @@ -26,9 +26,11 @@ bpl=$(grep -m 1 'bytes_per_load:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print # 拼成文件名并创建空文件 filename="${dim1}_${dim2}_${ctas}_${wm}_${wn}_${bpl}" -touch "/home/tuned_fwd/1024/f16f16/$filename" +# filename="${dim1}_${dim2}_${wm}_${wn}_${bpl}" +touch "/home/tuned_fwd/768/f16f16/$filename" echo "→ Created file $filename" -python /home/tools/abs_readall.py "/home/tuned_fwd/1024/f16f16/${filename}" +python /home/tools/abs_readall.py "/home/tuned_fwd/768/f16f16/${filename}" + diff --git a/find_fast.py b/find_fast.py new file mode 100644 index 000000000..5ccf12413 --- /dev/null +++ b/find_fast.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +""" +脚本:遍历指定目录下所有文件,解析每个文件中 +- ln_fwd_ kernel 的时间之和 +- 将 ln_bwd_tuned_kernel 和 ln_bwd_finalize 两个 kernel 的时间之和合并为一个值 +然后在所有文件中分别找出 ln_fwd_ 和合并后的 bwd 的最小值及对应文件,输出结果。 +""" +import os +import sys +import re + +def parse_file(filepath): + """解析单个文件,返回 dict: 'ln_fwd_' -> sum, 'ln_bwd_total' -> combined sum""" + sums = {} + current = None + times = [] + header_pat = re.compile(r"^==\s*(.+?)\s*==$") + with open(filepath, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + if current and times: + sums[current] = sum(times) + times = [] + continue + m = header_pat.match(line) + if m: + current = m.group(1) + times = [] + else: + try: + times.append(float(line)) + except ValueError: + pass + if current and times: + sums[current] = sum(times) + # 合并后两个 bwd kernels + bwd_sum = sums.get('ln_bwd_tuned_kernel', 0) + sums.get('ln_bwd_finalize', 0) + # 返回只有两项 + return { + 'ln_fwd_': sums.get('ln_fwd_', float('inf')), + 'ln_bwd_total': bwd_sum + } + +def find_minimums(dirpath): + """遍历目录文件,返回 dict: key -> (min_sum, filepath)""" + results = {} + for name in os.listdir(dirpath): + fp = os.path.join(dirpath, name) + if not os.path.isfile(fp): + continue + file_sums = parse_file(fp) + for key, val in file_sums.items(): + if key not in results or val < results[key][0]: + results[key] = (val, fp) + return results + +def main(): + if len(sys.argv) != 2: + print(f"Usage: {sys.argv[0]} ") + sys.exit(1) + d = sys.argv[1] + if not os.path.isdir(d): + print(f"Error: {d} is not a directory") + sys.exit(1) + mins = find_minimums(d) + if not mins: + print("No valid files found.") + return + print("最小时间和结果:") + for key in ['ln_fwd_', 'ln_bwd_total']: + val, fp = 
mins.get(key, (None, None)) + if val is None: + print(f"- {key}: 无数据") + else: + print(f"- {key}: {val:.2f} 文件: {fp}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/launcher_ge.py b/launcher_ge.py new file mode 100644 index 000000000..6a0f00357 --- /dev/null +++ b/launcher_ge.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python3 +""" +脚本:针对指定 HIDDEN_SIZE/WTYPE/ITYPE/OTYPE/CTYPE,在 ln_fwd_cuda_kernel.cu 中批量替换 REGISTER_NORM_LAUNCHER 宏的 +CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 四个参数组合。 +只替换匹配该前缀的行,保留其他注册宏不变。 +""" +import re,os +import subprocess + +# 需要替换的源文件路径 +SOURCE_FILE = '/home/TransformerEngine/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu' +RESULTS_DIR = '/home/tuned_fwd/768/f16f16' +# 隐藏大小列表 +hidden_sizes = [768] +# 构造前缀模板,format 时填入 hidden_size +PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, {hs}, fp16, fp16, fp16, fp32," +# PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, {hs}, fp16, fp16, fp16, fp32," + +# # 要测试的参数组合 +# ctas_per_row_list = [ 2] +# warps_m_list = [1] +# warps_n_list = [8] +# bytes_per_ldg_list= [4,8,16,32] + +ctas_per_row_list = [1] +warps_m_list = [2,1] +warps_n_list = [2,4,8] +bytes_per_ldg_list= [8,16] +# 批量替换 +for hs in hidden_sizes: + # 每个 hidden_size 生成对应前缀 + prefix = PREFIX_TMPL.format(hs=hs) + for ctas in ctas_per_row_list: + for wm in warps_m_list: + for wn in warps_n_list: + for bpl in bytes_per_ldg_list: + if wm * wn < 2: + continue + lhs = hs // (bpl // 2) + rhs = ctas * wn * 32 * (lhs // (ctas * wn * 32)) + # rhs = 1 * wn * 32 * (lhs // (1 * wn * 32)) + if lhs != rhs: + continue + # if not (ctas == 1 or wm == 1): + # continue + # 构造新的完整宏调用行 + new_line = f"{prefix} {ctas}, {wm}, {wn}, {bpl});"#bwd + # 读取源文件 + with open(SOURCE_FILE, 'r', encoding='utf-8') as f: + lines = f.readlines() + # 写回时替换匹配前缀的行 + with open(SOURCE_FILE, 'w', encoding='utf-8') as f: + for line in lines: + if line.strip().startswith(prefix): + f.write(new_line + '\n') + else: + f.write(line) + print(f"Updated {SOURCE_FILE} for hidden_size={hs} with: WARPS_M={wm}, WARPS_N={wn}, BYTES_PER_LDG={bpl}") + + result=subprocess.run(['bash', './abs_do_fwd.sh']) + if result.returncode != 0: + print(f"Warning: abs_do.sh failed with exit code {result.returncode}") + + +proc = subprocess.run( + ['python3', 'find_fast.py', RESULTS_DIR], + stdout=subprocess.PIPE, + text=True, + check=True +) + +best_fp = None +for line in proc.stdout.splitlines(): + if line.startswith('- ln_fwd_'): + # 解析 “文件: /path/to/2048_12288_1_1_8_32” + parts = line.split('文件:') + if len(parts) == 2: + best_fp = parts[1].strip() + break + +if not best_fp: + print("Error: 没有找到最佳 ln_fwd_ 结果,退出。") + sys.exit(1) + +best_name = os.path.basename(best_fp) # e.g. "2048_12288_1_1_8_32" +print("Best ln_fwd file:", best_name) + +# —— 3. 
从文件名拆出参数,并在 .cu 中替换宏行 —— # +tokens = best_name.split('_') +if len(tokens) != 6: + print("Error: 无法解析文件名参数:", best_name) + sys.exit(1) + +hs2, n2, ctas2, wm2, wn2, bpl2 = tokens +prefix = PREFIX_TMPL.format(hs=hs2) +new_line = f"{prefix} {ctas2}, {wm2}, {wn2}, {bpl2});" + +# 读源文件、替换所有匹配 prefix 的行 +with open(SOURCE_FILE, 'r', encoding='utf-8') as f: + lines = f.readlines() +with open(SOURCE_FILE, 'w', encoding='utf-8') as f: + for line in lines: + if line.strip().startswith(prefix): + f.write(new_line + '\n') + else: + f.write(line) + +print("已将所有前缀行替换为最佳组合:") +print(" ", new_line) \ No newline at end of file diff --git a/launcher_loop.py b/launcher_loop.py deleted file mode 100644 index 80366aa87..000000000 --- a/launcher_loop.py +++ /dev/null @@ -1,58 +0,0 @@ -#!/usr/bin/env python3 -""" -脚本:针对指定 HIDDEN_SIZE/WTYPE/ITYPE/OTYPE/CTYPE,在 ln_fwd_cuda_kernel.cu 中批量替换 REGISTER_NORM_LAUNCHER 宏的 -CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 四个参数组合。 -只替换匹配该前缀的行,保留其他注册宏不变。 -""" -import re -import subprocess - -# 需要替换的源文件路径 -SOURCE_FILE = '/home/TransformerEngine/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu' - -# 隐藏大小列表 -hidden_sizes = [1024] -# 构造前缀模板,format 时填入 hidden_size -PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, {hs}, fp16, fp16, fp16, fp32," - -# # 要测试的参数组合 -# ctas_per_row_list = [ 2] -# warps_m_list = [1] -# warps_n_list = [4] -# bytes_per_ldg_list= [16,64] - -ctas_per_row_list = [1,2] -warps_m_list = [1,4] -warps_n_list = [1,4] -bytes_per_ldg_list= [4,8,16,32] -# 批量替换 -for hs in hidden_sizes: - # 每个 hidden_size 生成对应前缀 - prefix = PREFIX_TMPL.format(hs=hs) - for ctas in ctas_per_row_list: - for wm in warps_m_list: - for wn in warps_n_list: - for bpl in bytes_per_ldg_list: - if wm * wn != 4: - continue - lhs = hs // (bpl // 2) - rhs = ctas * wn * 32 * (lhs // (ctas * wn * 32)) - if lhs != rhs: - continue - if not (ctas == 1 or wm == 1): - continue - # 构造新的完整宏调用行 - new_line = f"{prefix} {ctas}, {wm}, {wn}, {bpl});"#bwd - # 读取源文件 - with open(SOURCE_FILE, 'r', encoding='utf-8') as f: - lines = f.readlines() - # 写回时替换匹配前缀的行 - with open(SOURCE_FILE, 'w', encoding='utf-8') as f: - for line in lines: - if line.strip().startswith(prefix): - f.write(new_line + '\n') - else: - f.write(line) - print(f"Updated {SOURCE_FILE} for hidden_size={hs} with: CTAS_PER_ROW={ctas}, WARPS_M={wm}, WARPS_N={wn}, BYTES_PER_LDG={bpl}") - - subprocess.run(['bash', './abs_do.sh'], check=True) diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index e8b8a59c0..91a280fa2 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -346,6 +346,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, double atol_bwd = 5e-4; double rtol_bwd = 5e-4; + // double atol_bwd = 2e-3; + // double rtol_bwd = 2e-3; compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); @@ -356,22 +358,15 @@ std::vector> test_cases = { // {29, 541}, // {768, 6144}, //{2048, 12288}, - // {6800,928}, - // {16000,1472} - // {46000,2240}, - // {8200,2011} - // {14000,1485}, - // {34003,3957} //{71,3571} //{168,184} - {768,1024}, + // {768,1024}, // {256,65536}, // {128,6144}, // {64,2304}, // {229,541}, // {71, 3571}, - // {29, 17389}, - //{6048,16320} + {512,768} //{76800,1600} }; diff --git a/transformer_engine/common/normalization/kernel_traits.h 
b/transformer_engine/common/normalization/kernel_traits.h index 78d9212de..97e47c686 100644 --- a/transformer_engine/common/normalization/kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -67,6 +67,7 @@ struct Kernel_traits_finalize : public Base { template , typename Base = Kernel_traits_base > @@ -120,7 +121,7 @@ struct Kernel_traits : public Base { static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); // static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - using Stats = transformer_engine::Stats; + using Stats = StatsT; enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; }; diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index a13976e6f..1c5d95744 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -227,16 +227,15 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz const uint32_t c = bidn * THREADS_PER_WARP + lane; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; + const uint32_t COL_STRIDE = params.cols * THREADS_PER_WARP; + for (uint32_t col = c, col_out = c_out; col < params.cols; col += COL_STRIDE, col_out += COL_STRIDE / 2) { // Each thread sums over NUM_ELT columns. Vec dbeta_local, dgamma_local; memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dbeta_local, 0, sizeof(dbeta_local)); for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { - index_t idx = row * Kernel_traits::COLS + col; - + index_t idx = row * params.cols + col; Vec dbeta_part, dgamma_part; dbeta_part.load_from(params.dbeta_part, idx); dgamma_part.load_from(params.dgamma_part, idx); @@ -391,7 +390,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne } Cvec dy[LDGS]; - Cvec y[LDGS]; + //Cvec y[LDGS]; compute_t mdy = 0.f; compute_t mdyy = 0.f; @@ -411,14 +410,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne const compute_t dz_ij = dz.data.elt[jt]; const compute_t dy_ij = g_ij * dz_ij; - y[it].data.elt[jt] = y_ij; + //y[it].data.elt[jt] = y_ij; dy[it].data.elt[jt] = dy_ij; mdy += dy_ij; mdyy += dy_ij * y_ij; - dz_sum[it].data.elt[jt] += dz_ij; - dzy_sum[it].data.elt[jt] += dz_ij * y_ij; + // dz_sum[it].data.elt[jt] += dz_ij; + // dzy_sum[it].data.elt[jt] += dz_ij * y_ij; } } @@ -432,11 +431,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; it++, col += gdimn * NUM_ELTS) { Ivec dx; + + Ivec x; + Ovec dz; + x.load_from_elts(params.x, row * params.cols + col, params.cols - col); + dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); + #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t dy_ij = dy[it].data.elt[jt]; - compute_t y_ij = y[it].data.elt[jt]; - dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy)); + const compute_t x_ij = x.data.elt[jt]; + const compute_t y_ij = rs * (x_ij - mu); + const compute_t dz_ij = dz.data.elt[jt]; + + dx.data.elt[jt] = rs * (dy[it].data.elt[jt] - (mdyy * y_ij + mdy)); + + dz_sum[it].data.elt[jt] += dz_ij; + dzy_sum[it].data.elt[jt] += dz_ij * y_ij; } 
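+      // In this rewritten loop, y is recomputed from x and (mu, rs) rather than cached in
+      // the y[LDGS] buffer (commented out above); x and dz are reloaded here instead.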
dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); } diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 09618c58d..74d62468c 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -10,6 +10,7 @@ #include "../common.h" #include "../kernel_traits.h" #include "ln_bwd_kernels.cuh" +#include using namespace transformer_engine::normalization; @@ -39,7 +40,9 @@ static void launch_tuned_(LaunchParams &launch_params, launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); return; } - + // std::cout<<"bwd ctas_per_row:"<< CTAS_PER_ROW<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -106,7 +109,10 @@ static void launch_general_(LaunchParams &launch_params, launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t); return; } - + // std::cout<<"bwd cols:"< &launch_params, reinterpret_cast(¶ms_), 0, stream); } - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - &ln_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); + // Decide which finalize kernel to launch based on column alignment + const bool cols_aligned = (cols % 32 == 0); + + if (cols_aligned) { + // Launch tuned finalize kernel + using Kernel_traits_f = Kernel_traits_finalize; + + auto kernel_f = &ln_bwd_finalize_tuned_kernel; + + + kernel_f<<>>( + launch_params.params); + + } else { + // Launch general finalize kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL) / + sizeof(compute_t); + + auto kernel_final = &ln_bwd_finalize_general_kernel; + + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + + kernel_final<<>>(launch_params.params); + } } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -157,7 +186,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 8); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); @@ -165,11 +194,11 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 
1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);// REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 2, 1, 1, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); @@ -223,7 +252,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 1, 1, 16, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); @@ -295,7 +324,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 4, 1, 16, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); @@ -317,13 +346,13 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 2, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 
2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 32, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 62e8e31d9..a4723163d 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -35,10 +35,10 @@ static void launch_tuned_(LaunchParams &launch_params, } return; } - std::cout<<"ctas_per_row:"<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -66,7 +66,8 @@ template &launch_params, const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; + 1, WARPS_M, WARPS_N, BYTES_PER_LDG, + transformer_engine::Stats_ge>; auto kernel = &ln_fwd_general_kernel; auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; @@ -91,7 +92,9 @@ static void launch_general_(LaunchParams &launch_params, } return; } - + // std::cout<<"warps_m:"<; - constexpr int SMEM_BYTES = Reducer::SMEM_BYTES > 0 ? Reducer::SMEM_BYTES : 1; - __shared__ char smem_[SMEM_BYTES]; - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); - Sum sum; - const compute_t rn = 1.f / static_cast(params.cols); - - // Load weights - Cvec gamma[LDGS]; - Cvec beta[LDGS]; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; - ++it, col += gdimn * NUM_ELTS) { - Wvec gamma_in, beta_in; - gamma_in.load_from_elts(params.gamma, col, params.cols - col); - beta_in.load_from_elts(params.beta, col, params.cols - col); - gamma_in.to(gamma[it]); - beta_in.to(beta[it]); - } + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + extern __shared__ char smem[]; + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem); - // fp8 factors - compute_t scale; - if (params.fp8_out) { - scale = *reinterpret_cast(params.scale); - } + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + compute_t scale = params.fp8_out ? 
*reinterpret_cast(params.scale) : 1.f; compute_t amax = 0; for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { - const int row = cta_row + warp_m; + int row = cta_row + warp_m; + if (row >= params.rows) continue; - // Load input - Cvec x[LDGS]; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - Ivec x_in; - x_in.load_from_elts(params.x, row * params.cols + col, params.cols - col); - x_in.to(x[it]); - } + compute_t mu = 0.f, m2 = 0.f; + int count = 0; - // Compute mean - compute_t mu = 0.f; + // Step 1: mean and m2 #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS; + ++it, col += gdimn * NUM_ELTS) { + Ivec x_vec; + x_vec.load_from_elts(params.x, row * params.cols + col, params.cols - col); #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - mu += x[it].data.elt[jt]; - } - } - mu = reducer.allreduce(mu, sum) * rn; - - // Compute variance - compute_t sqsigma = 0.f; -#pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { + for (int jt = 0; jt < NUM_ELTS; ++jt) { if (col + jt < params.cols) { - compute_t diff = x[it].data.elt[jt] - mu; - sqsigma += diff * diff; + compute_t x = compute_t(x_vec.data.elt[jt]); + count += 1; + compute_t delta = x - mu; + mu += delta / count; + m2 += delta * (x - mu); } } } - sqsigma = reducer.allreduce(sqsigma, sum) * rn; - compute_t rs = rsqrtf(sqsigma + params.epsilon); - // Write statistics - if (gidn == 0 && row < params.rows) { - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); + Vec3 stat = stats.reduce(Vec3(mu, m2, count)); + mu = stat.x; + m2 = stat.y; + compute_t rs = rsqrtf((m2 / stat.z) + params.epsilon); + + if (gidn == 0) { mu_ptr[row] = mu; rs_ptr[row] = rs; } -// Compute output + // Step 2: store output (no need to store xf[]) #pragma unroll - for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; - it++, col += gdimn * NUM_ELTS) { - // Compute output values - Cvec z; -#pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - compute_t y_ij = rs * (x[it].data.elt[jt] - mu); - compute_t g_ij = gamma[it].data.elt[jt]; - if (params.zero_centered_gamma) { - g_ij += 1; - } - compute_t b_ij = beta[it].data.elt[jt]; - z.data.elt[jt] = g_ij * y_ij + b_ij; - } + for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && col < params.cols; + ++it, col += gdimn * NUM_ELTS) { + Ivec x_vec; + x_vec.load_from_elts(params.x, row * params.cols + col, params.cols - col); + Wvec g_raw, b_raw; + g_raw.load_from_elts(params.gamma, col, params.cols - col); + b_raw.load_from_elts(params.beta, col, params.cols - col); - // Apply fp8 factors - if (params.fp8_out) { + Cvec z; #pragma unroll - for (int jt = 0; jt < NUM_ELTS; jt++) { - if (col + jt < params.cols) { - compute_t z_ij = z.data.elt[jt]; - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; + for (int jt = 0; jt < NUM_ELTS; ++jt) { + if (col + jt < params.cols) { + compute_t x = compute_t(x_vec.data.elt[jt]); + compute_t norm = rs * (x - mu); + compute_t g = compute_t(g_raw.data.elt[jt]) + (params.zero_centered_gamma ? 
1.f : 0.f); + compute_t b = compute_t(b_raw.data.elt[jt]); + compute_t val = g * norm + b; + if (params.fp8_out) { + amax = fmaxf(amax, fabsf(val)); + val *= scale; } + z.data.elt[jt] = output_t(val); } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index def7fb08f..922a116fa 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -124,6 +124,16 @@ struct uint8 { template struct BytesToType {}; +// 新增对 128 字节的支持:以 16 个 uint8x8 为例(16*8=128B) +struct uint8x8 { uint8_t data[8]; }; +struct uint8x8x16 { uint8x8 v[16]; }; + +template<> +struct BytesToType<128> { + using Type = uint8x8x16; + static_assert(sizeof(Type) == 128, "BytesToType<128> must be 128 bytes"); +}; + template <> struct BytesToType<64> { using Type = uint16; @@ -166,7 +176,26 @@ struct BytesToType<1> { static_assert(sizeof(Type) == 1); }; +template +struct Vec3 { + T x, y; + CountT z; + + __device__ Vec3() : x(0), y(0), z(0) {} + __device__ Vec3(T x_, T y_, CountT z_) : x(x_), y(y_), z(z_) {} + + __device__ Vec3 &operator+=(const Vec3 &rhs) { + x += rhs.x; + y += rhs.y; + z += rhs.z; + return *this; + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct TypeToVec3 { + using Type = Vec3; +}; template struct TypeToVec2 {}; @@ -874,6 +903,177 @@ struct Stats { }; //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ void warp_chan_upd_dynamic_ge(Vec3 &stat, int num_active) { + int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + +#pragma unroll + for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { + T n_b = warp_shuffle_down(stat.z, step); + T m_b = warp_shuffle_down(stat.x, step); + T m2_b = warp_shuffle_down(stat.y, step); + + T n_a = stat.z; + T m_a = stat.x; + T m2_a = stat.y; + + T n_ab = n_a + n_b; + T rn_ab = T(1.f) / n_ab; + T delta = m_a - m_b; + + T m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + + stat = Vec3(m_ab, m2_ab, n_ab); + } + +#ifdef __HIP_PLATFORM_AMD__ + stat.x = __shfl(stat.x, 0, THREADS_PER_WARP); + stat.y = __shfl(stat.y, 0, THREADS_PER_WARP); + stat.z = __shfl(stat.z, 0, THREADS_PER_WARP); +#else + stat.x = __shfl_sync(static_cast(-1), stat.x, 0); + stat.y = __shfl_sync(static_cast(-1), stat.y, 0); + stat.z = __shfl_sync(static_cast(-1), stat.z, 0); +#endif +} + +template +struct Stats_ge; + + +// Warp-level Stats (Welford-based) +template +struct Stats_ge { + using stats_t = Vec3; // (mu, m2, count) + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats_ge(const Params ¶ms, uint32_t, uint32_t, + uint32_t, uint32_t warp_n, uint32_t lane, void *) + : warp_n_(warp_n), lane_(lane) {} + +// template +// inline __device__ stats_t compute(const T (&elts)[N], int valid_count) { +// T mean = 0, m2 = 0, count = 0; +// #pragma unroll +// for (int i = 0; i < N; ++i) { +// if (i < valid_count) { +// T x = elts[i]; +// count += 1; +// T delta = x - mean; +// mean += delta / count; +// T delta2 = x - mean; +// m2 += delta * delta2; +// } +// } +// return reduce(Vec3(mean, m2, count)); +// } + + inline __device__ stats_t reduce(Vec3 local_stat) { + warp_chan_upd_dynamic_ge(local_stat, THREADS_PER_WARP); + return local_stat; + } + + uint32_t warp_n_, lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Block-level Stats (intra 
CTA warp reduction) +template +struct Stats_ge { + using stats_t = Vec3; + using WarpStats = Stats_ge; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats_ge(const Params ¶ms, uint32_t bidm, uint32_t bidn, + uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), use0_(true) { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + // template + // inline __device__ stats_t compute(const T (&elts)[N], int valid_count) { + // Vec3 local = warp_stats_.compute(elts, valid_count); + // return reduce(local); + // } + + inline __device__ stats_t reduce(Vec3 local_stat) { + local_stat=warp_stats_.reduce(local_stat); + + stats_t *smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + if (warp_stats_.lane_ == 0) { + smem[warp_stats_.warp_n_] = local_stat; + } + __syncthreads(); + + stats_t result{Zeros::get(), Zeros::get(), Zeros::get()}; + if (warp_stats_.lane_ < WARPS_N) { + result = smem[warp_stats_.lane_]; + } + + warp_chan_upd_dynamic_ge(result, WARPS_N); + return result; + } + + WarpStats warp_stats_; + stats_t *smem0_, *smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Inter-CTA Stats +template +struct Stats_ge { + using stats_t = Vec3; + using BlockStats = Stats_ge; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats_ge(const Params ¶ms, uint32_t bidm, uint32_t bidn, + uint32_t warp_m, uint32_t warp_n, uint32_t lane, void *smem) + : inter_cta_(params.barrier, bidm, params.ctas_per_col, CTAS_PER_ROW), + block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem), + bidn_(bidn), + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), + warp_n_(warp_n), lane_(lane) {} + + // template + // inline __device__ stats_t compute(const T (&elts)[N], int valid_count) { + // Vec3 local = block_stats_.compute(elts, valid_count); + // return reduce(local); + // } + + inline __device__ stats_t reduce(Vec3 local_stat) { + local_stat=block_stats_.reduce(local_stat); + stats_t *workspace = (inter_cta_.phase_counter_ & 0x1) ? 
w1_ : w0_; + if (warp_n_ == 0 && lane_ == 0) { + workspace[bidn_] = local_stat; + } + inter_cta_.sync(); + + stats_t result{Zeros::get(), Zeros::get(), Zeros::get()}; + if (lane_ < CTAS_PER_ROW) { + result = workspace[lane_]; + } + + warp_chan_upd_dynamic_ge(result, CTAS_PER_ROW); + return result; + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + stats_t *w0_, *w1_; + int bidn_, warp_n_, lane_; +}; template __device__ __forceinline__ float warp_reduce_max(const float m) { From 1cdac1b9059b3cdc120472066e4e2babb0580be6 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 3 Jul 2025 08:14:30 +0000 Subject: [PATCH 03/10] pr v2 --- tests/cpp/do.sh | 6 -- .../layernorm/ln_fwd_cuda_kernel.cu | 2 +- abs_do_fwd.sh => tuning_tools/abs_do_fwd.sh | 6 +- tuning_tools/abs_readall.py | 66 +++++++++++++++++++ find_fast.py => tuning_tools/find_fast.py | 0 launcher_ge.py => tuning_tools/launcher_ge.py | 0 6 files changed, 69 insertions(+), 11 deletions(-) delete mode 100755 tests/cpp/do.sh rename abs_do_fwd.sh => tuning_tools/abs_do_fwd.sh (91%) create mode 100644 tuning_tools/abs_readall.py rename find_fast.py => tuning_tools/find_fast.py (100%) rename launcher_ge.py => tuning_tools/launcher_ge.py (100%) diff --git a/tests/cpp/do.sh b/tests/cpp/do.sh deleted file mode 100755 index ab548c876..000000000 --- a/tests/cpp/do.sh +++ /dev/null @@ -1,6 +0,0 @@ -rm -rf * - -cmake .. -make - -rocprof --stats ./operator/test_operator diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index a4723163d..3f84e9fc4 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -204,7 +204,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 1, 2, 8); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); diff --git a/abs_do_fwd.sh b/tuning_tools/abs_do_fwd.sh similarity index 91% rename from abs_do_fwd.sh rename to tuning_tools/abs_do_fwd.sh index ad16266f2..55eb2cdd8 100644 --- a/abs_do_fwd.sh +++ b/tuning_tools/abs_do_fwd.sh @@ -1,5 +1,6 @@ set -euo pipefail +cd .. pip install . 
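+# Rebuild the package so the launcher parameters substituted by launcher_ge.py are compiled in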
#等待编译结束 @@ -30,7 +31,4 @@ filename="${dim1}_${dim2}_${ctas}_${wm}_${wn}_${bpl}" touch "/home/tuned_fwd/768/f16f16/$filename" echo "→ Created file $filename" -python /home/tools/abs_readall.py "/home/tuned_fwd/768/f16f16/${filename}" - - - +python /home/TransformerEngine/tuning_tools/abs_readall.py "/home/tuned_fwd/768/f16f16/${filename}" \ No newline at end of file diff --git a/tuning_tools/abs_readall.py b/tuning_tools/abs_readall.py new file mode 100644 index 000000000..b8cff57c3 --- /dev/null +++ b/tuning_tools/abs_readall.py @@ -0,0 +1,66 @@ +import json +import os +import sys +import argparse + +def extract_and_process_durations(input_file, output_file, kernel_keywords, num_warmup, num_iteration): + with open(input_file, "r") as f: + data = json.load(f) + + keyword_to_durations = {k: [] for k in kernel_keywords} + + for event in data.get("traceEvents", []): + args = event.get("args", {}) + kernel_name = args.get("KernelName", "") + duration = args.get("DurationNs") + + if duration is not None: + for keyword in kernel_keywords: + if keyword in kernel_name: + keyword_to_durations[keyword].append(int(duration)) + break # 防止同一个event被多个keyword重复统计 + + output_lines = [] + + for keyword in kernel_keywords: + durations = keyword_to_durations[keyword] + output_lines.append(f"== {keyword} ==") + + if not durations: + output_lines.append("[无数据]") + continue + + i = 0 + while i < len(durations): + i += num_warmup # 跳过warmup + batch = [] + for _ in range(num_iteration): + if i < len(durations): + batch.append(durations[i]) + i += 1 + if batch: + avg = sum(batch) / len(batch) + output_lines.append(f"{avg:.2f}") + output_lines.append("") # 空行分隔 + + with open(output_file, "w") as f: + f.write("\n".join(output_lines)) + + print(f"已将所有 kernel 的平均耗时写入 {output_file}") + +input_json = "/home/TransformerEngine/tests/cpp/build/results.json" +if len(sys.argv) > 1: + output_txt = sys.argv[1] +else: + output_txt = "/home/bwdprofiles/tmp/heyi.txt" + +kernel_keywords = [ + "ln_fwd_", + "ln_bwd_general_kernel", + "ln_bwd_finalize" +] + +num_warmup = 5 +num_iteration = 10 + +extract_and_process_durations(input_json, output_txt, kernel_keywords, num_warmup, num_iteration) \ No newline at end of file diff --git a/find_fast.py b/tuning_tools/find_fast.py similarity index 100% rename from find_fast.py rename to tuning_tools/find_fast.py diff --git a/launcher_ge.py b/tuning_tools/launcher_ge.py similarity index 100% rename from launcher_ge.py rename to tuning_tools/launcher_ge.py From 06149f5f0fc3a4666ff1dfe3708280816836a8bc Mon Sep 17 00:00:00 2001 From: root Date: Mon, 7 Jul 2025 02:48:50 +0000 Subject: [PATCH 04/10] clean code v1 --- tests/cpp/operator/CMakeLists.txt | 30 ++++++++++----------- tests/cpp/operator/test_normalization.cu | 33 ++++++++++++------------ 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 3d1077914..728d37a17 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -5,22 +5,22 @@ # See LICENSE for license information. 
list(APPEND test_cuda_sources - # test_qdq.cu - # test_cast_transpose.cu - # test_transpose.cu - # test_cast_transpose_dbias.cu - # test_cast_transpose_dbias_dgelu.cu - # test_cast_transpose_dgeglu.cu - # test_act.cu - test_normalization.cu - # test_multi_cast_transpose.cu - # test_multi_padding.cu - # test_causal_softmax.cu + test_qdq.cu + test_cast_transpose.cu + test_transpose.cu + test_cast_transpose_dbias.cu + test_cast_transpose_dbias_dgelu.cu + test_cast_transpose_dgeglu.cu + test_act.cu + test_normalization.cu + test_multi_cast_transpose.cu + test_multi_padding.cu + test_causal_softmax.cu ../test_common.cu) -# if(USE_ROCM) -# list(APPEND test_cuda_sources -# test_cublaslt_gemm.cu) -# endif() +if(USE_ROCM) + list(APPEND test_cuda_sources + test_cublaslt_gemm.cu) +endif() if(USE_CUDA) add_executable(test_operator ${test_cuda_sources}) diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 91a280fa2..3bf7fe8a5 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -242,11 +242,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); - for(int i=0;i<5;i++) - nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), - prop.multiProcessorCount, zero_centered_gamma, 0); - for(int i=0;i<10;i++) + // for(int i=0;i<5;i++) + // nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + // z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + // prop.multiProcessorCount, zero_centered_gamma, 0); + // for(int i=0;i<10;i++) nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -257,13 +257,13 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); - for(int i=0;i<5;i++) - nvte_layernorm_bwd(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - workspace_bwd.data(), - prop.multiProcessorCount, zero_centered_gamma, 0); - for(int i=0;i<10;i++) + // for(int i=0;i<5;i++) + // nvte_layernorm_bwd(dz.data(), input.data(), + // mu.data(), rsigma.data(), gamma.data(), + // dx.data(), dgamma.data(), dbeta.data(), + // workspace_bwd.data(), + // prop.multiProcessorCount, zero_centered_gamma, 0); + // for(int i=0;i<10;i++) nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -346,8 +346,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, double atol_bwd = 5e-4; double rtol_bwd = 5e-4; - // double atol_bwd = 2e-3; - // double rtol_bwd = 2e-3; + compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); @@ -406,9 +405,9 @@ INSTANTIATE_TEST_SUITE_P( #else ::testing::Values(false), //TODO: enabling tests for cudnn backend #endif - ::testing::Values(NormType::LayerNorm), - ::testing::Values(DType::kFloat16), - 
::testing::Values(DType::kFloat16), + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), ::testing::ValuesIn(test_cases), ::testing::Values(false, true)), [](const testing::TestParamInfo& info) { From b0d9050a5cd5fe804a489a83d9dd822998941c0e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 7 Jul 2025 03:17:50 +0000 Subject: [PATCH 05/10] clean code v2 --- tests/cpp/operator/test_normalization.cu | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 3bf7fe8a5..b77f3b8ff 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -242,11 +242,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); workspace_fwd = Tensor(workspace_fwd.shape(), workspace_fwd.dtype()); - // for(int i=0;i<5;i++) - // nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, - // z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), - // prop.multiProcessorCount, zero_centered_gamma, 0); - // for(int i=0;i<10;i++) nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -257,13 +252,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); workspace_bwd = Tensor(workspace_bwd.shape(), workspace_bwd.dtype()); - // for(int i=0;i<5;i++) - // nvte_layernorm_bwd(dz.data(), input.data(), - // mu.data(), rsigma.data(), gamma.data(), - // dx.data(), dgamma.data(), dbeta.data(), - // workspace_bwd.data(), - // prop.multiProcessorCount, zero_centered_gamma, 0); - // for(int i=0;i<10;i++) nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -346,7 +334,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, double atol_bwd = 5e-4; double rtol_bwd = 5e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); From 1e3519751644fceee82504903a712b7c487825f4 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Nov 2025 04:11:40 +0000 Subject: [PATCH 06/10] Resolve merge conflicts --- tests/cpp/operator/test_normalization.h | 45 +++++++++++++++++-- .../layernorm/ln_fwd_kernels.cuh | 8 +++- transformer_engine/common/utils.cuh | 35 ++++++++++++--- 3 files changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index 5f5603a7f..c949a0906 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -64,11 +64,45 @@ void compute_ref_stats(NormType norm_type, } } +// template +// inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { + +// using compute_t = float; + +// // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently +// // Remove the use_cudnn check here 
when it is supported by both backends. +// const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; + +// if constexpr (std::is_same_v || std::is_same_v){ +// compute_t g = static_cast(gamma); +// if (zero_centered_gamma) { +// g += static_cast(1.f); +// } +// return g; +// } else { +// if (zero_centered_gamma_in_weight_dtype){ +// compute_t g = static_cast(0.f); +// InputType gi = gamma; +// if (zero_centered_gamma) { +// gi = gi + static_cast(1.f); +// } +// g = static_cast(gi); +// return g; +// } else { +// compute_t g = static_cast(gamma); +// if (zero_centered_gamma) { +// g += static_cast(1.f); +// } +// return g; +// } +// } +// } + template -inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype) { +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn, const bool cudnn_zero_centered_gamma_in_weight_dtype){ using compute_t = float; - + // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently // Remove the use_cudnn check here when it is supported by both backends. const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype; @@ -80,6 +114,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const } return g; } else { +#ifdef __HIP_PLATFORM_AMD__ + (void)zero_centered_gamma_in_weight_dtype; // Parameter is unused on AMD platform +#else if (zero_centered_gamma_in_weight_dtype){ compute_t g = static_cast(0.f); InputType gi = gamma; @@ -88,7 +125,9 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const } g = static_cast(gi); return g; - } else { + } else +#endif + { compute_t g = static_cast(gamma); if (zero_centered_gamma) { g += static_cast(1.f); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 3f1870233..b328061b5 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -251,10 +251,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } } + Vec3 stat = stats.reduce(Vec3(mu, m2, count)); mu = stat.x; m2 = stat.y; - compute_t rs = rsqrtf((m2 / stat.z) + params.epsilon); + + compute_t var = m2 / stat.z; + var = var < compute_t(0) ? 
compute_t(0) : var; + compute_t rs = rsqrtf(var + params.epsilon); + + // compute_t rs = rsqrtf((m2 / stat.z) + params.epsilon); if (gidn == 0) { mu_ptr[row] = mu; diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 7dc0a15e6..145d65341 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -893,6 +893,24 @@ inline __device__ void warp_chan_upd_dynamic_ge(Vec3 &stat, int num_ac int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); #pragma unroll + // for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { + // T n_b = warp_shuffle_down(stat.z, step); + // T m_b = warp_shuffle_down(stat.x, step); + // T m2_b = warp_shuffle_down(stat.y, step); + + // T n_a = stat.z; + // T m_a = stat.x; + // T m2_a = stat.y; + + // T n_ab = n_a + n_b; + // T rn_ab = T(1.f) / n_ab; + // T delta = m_a - m_b; + + // T m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + // T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + + // stat = Vec3(m_ab, m2_ab, n_ab); + // } for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { T n_b = warp_shuffle_down(stat.z, step); T m_b = warp_shuffle_down(stat.x, step); @@ -902,16 +920,21 @@ inline __device__ void warp_chan_upd_dynamic_ge(Vec3 &stat, int num_ac T m_a = stat.x; T m2_a = stat.y; - T n_ab = n_a + n_b; - T rn_ab = T(1.f) / n_ab; - T delta = m_a - m_b; + if(n_b == 0){} + else + { + T n_ab = n_a + n_b; + T rn_ab = T(1.f) / n_ab; + T delta = m_a - m_b; - T m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + T m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - stat = Vec3(m_ab, m2_ab, n_ab); + stat = Vec3(m_ab, m2_ab, n_ab); + } } + #ifdef __HIP_PLATFORM_AMD__ stat.x = __shfl(stat.x, 0, THREADS_PER_WARP); stat.y = __shfl(stat.y, 0, THREADS_PER_WARP); From a4d1f8e27f58bf4f4690128700c2b09ab94e9a26 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Nov 2025 03:32:35 +0000 Subject: [PATCH 07/10] Merge branch 'dev' into dev --- transformer_engine/common/utils.cuh | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 145d65341..96f606c65 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -893,24 +893,6 @@ inline __device__ void warp_chan_upd_dynamic_ge(Vec3 &stat, int num_ac int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); #pragma unroll - // for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { - // T n_b = warp_shuffle_down(stat.z, step); - // T m_b = warp_shuffle_down(stat.x, step); - // T m2_b = warp_shuffle_down(stat.y, step); - - // T n_a = stat.z; - // T m_a = stat.x; - // T m2_a = stat.y; - - // T n_ab = n_a + n_b; - // T rn_ab = T(1.f) / n_ab; - // T delta = m_a - m_b; - - // T m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - // T m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - - // stat = Vec3(m_ab, m2_ab, n_ab); - // } for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) { T n_b = warp_shuffle_down(stat.z, step); T m_b = warp_shuffle_down(stat.x, step); From a2c49e77c1d080428d02cff9479fdcaa8245e734 Mon Sep 17 00:00:00 2001 From: eliotwang <46883838+eliotwang@users.noreply.github.com> Date: Tue, 18 Nov 2025 11:34:35 +0800 Subject: [PATCH 08/10] Update CMakeLists.txt --- tests/cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 1705b428b..4ab5fd237 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -94,4 +94,4 @@ else() endif() add_subdirectory(operator) -# add_subdirectory(util) +add_subdirectory(util) From ab5bee75dddf56bb3d33bba70b594ca614f7f309 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 24 Nov 2025 03:20:19 +0000 Subject: [PATCH 09/10] Optimize te_ln_fwd kernel --- tests/cpp/operator/test_normalization.cu | 26 +++---- .../layernorm/ln_bwd_kernels.cuh | 32 +++------ .../layernorm/ln_bwd_semi_cuda_kernel.cu | 71 ++++++------------- .../layernorm/ln_fwd_cuda_kernel.cu | 9 +-- 4 files changed, 48 insertions(+), 90 deletions(-) diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index 5c19a79da..f81727572 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -35,9 +35,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, return; } +#ifndef __HIP_PLATFORM_AMD__ if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) { GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!"; } +#endif using WeightType = InputType; DType itype = TypeInfo::dtype; @@ -112,7 +114,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -218,17 +219,18 @@ std::vector> test_cases = { // {71, 229}, // {29, 541}, // {768, 6144}, - //{2048, 12288}, - //{71,3571} - //{168,184} - // {768,1024}, - // {256,65536}, - // {128,6144}, - // {64,2304}, - // {229,541}, - // {71, 3571}, - {512,768} - //{76800,1600} + {2048, 12288}, + {768,1024}, + {256,65536}, + {128,6144}, + {64,2304}, + {229,541}, + {71, 3571}, + {29,17389}, + {76800,1600} + // {512,768}, + // {71,3571}, + // {168,184} }; } // namespace diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index 1c5d95744..a13976e6f 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -227,15 +227,16 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz const uint32_t c = bidn * THREADS_PER_WARP + lane; const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - const uint32_t COL_STRIDE = params.cols * THREADS_PER_WARP; - for (uint32_t col = c, col_out = c_out; col < params.cols; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for (uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2) { // Each thread sums over NUM_ELT columns. 
Vec dbeta_local, dgamma_local; memset(&dgamma_local, 0, sizeof(dgamma_local)); memset(&dbeta_local, 0, sizeof(dbeta_local)); for (uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA) { - index_t idx = row * params.cols + col; + index_t idx = row * Kernel_traits::COLS + col; + Vec dbeta_part, dgamma_part; dbeta_part.load_from(params.dbeta_part, idx); dgamma_part.load_from(params.dgamma_part, idx); @@ -390,7 +391,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne } Cvec dy[LDGS]; - //Cvec y[LDGS]; + Cvec y[LDGS]; compute_t mdy = 0.f; compute_t mdyy = 0.f; @@ -410,14 +411,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne const compute_t dz_ij = dz.data.elt[jt]; const compute_t dy_ij = g_ij * dz_ij; - //y[it].data.elt[jt] = y_ij; + y[it].data.elt[jt] = y_ij; dy[it].data.elt[jt] = dy_ij; mdy += dy_ij; mdyy += dy_ij * y_ij; - // dz_sum[it].data.elt[jt] += dz_ij; - // dzy_sum[it].data.elt[jt] += dz_ij * y_ij; + dz_sum[it].data.elt[jt] += dz_ij; + dzy_sum[it].data.elt[jt] += dz_ij * y_ij; } } @@ -431,22 +432,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; it++, col += gdimn * NUM_ELTS) { Ivec dx; - - Ivec x; - Ovec dz; - x.load_from_elts(params.x, row * params.cols + col, params.cols - col); - dz.load_from_elts(params.dz, row * params.cols + col, params.cols - col); - #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { - const compute_t x_ij = x.data.elt[jt]; - const compute_t y_ij = rs * (x_ij - mu); - const compute_t dz_ij = dz.data.elt[jt]; - - dx.data.elt[jt] = rs * (dy[it].data.elt[jt] - (mdyy * y_ij + mdy)); - - dz_sum[it].data.elt[jt] += dz_ij; - dzy_sum[it].data.elt[jt] += dz_ij * y_ij; + compute_t dy_ij = dy[it].data.elt[jt]; + compute_t y_ij = y[it].data.elt[jt]; + dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij + mdy)); } dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); } diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 74d62468c..b17d864f1 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -10,7 +10,6 @@ #include "../common.h" #include "../kernel_traits.h" #include "ln_bwd_kernels.cuh" -#include using namespace transformer_engine::normalization; @@ -40,9 +39,7 @@ static void launch_tuned_(LaunchParams &launch_params, launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); return; } - // std::cout<<"bwd ctas_per_row:"<< CTAS_PER_ROW<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -109,10 +106,7 @@ static void launch_general_(LaunchParams &launch_params, launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t); return; } - // std::cout<<"bwd cols:"< &launch_params, reinterpret_cast(¶ms_), 0, stream); } - // Decide which finalize kernel to launch based on column alignment - const bool cols_aligned = (cols % 32 == 0); - - if (cols_aligned) { - // Launch tuned finalize kernel - using Kernel_traits_f = Kernel_traits_finalize; - - auto kernel_f = &ln_bwd_finalize_tuned_kernel; - - - kernel_f<<>>( - launch_params.params); - - } else { - // Launch general finalize kernel - constexpr uint32_t 
WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL) / - sizeof(compute_t); - - auto kernel_final = &ln_bwd_finalize_general_kernel; - - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - - kernel_final<<>>(launch_params.params); - } + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &ln_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -186,7 +157,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); @@ -194,11 +165,11 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);// +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 2, 1, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); @@ -252,7 +223,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 1, 1, 16, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, 
tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); @@ -324,7 +295,7 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 4, 1, 16, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); @@ -346,13 +317,13 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 2, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 32, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); @@ -361,4 +332,4 @@ REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp32, fp32, fp32, fp3 REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); \ No newline at end of file diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 507bfcab0..1f2045684 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -40,10 +40,7 @@ static void launch_tuned_(LaunchParams &launch_params, 
#endif return; } - std::cout<<"tuned fwd ctas_per_row:"<< CTAS_PER_ROW<= 48 * 1024) { NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -107,9 +104,7 @@ static void launch_general_(LaunchParams &launch_params, #endif return; } - // std::cout<<"warps_m:"< Date: Mon, 24 Nov 2025 04:34:28 +0000 Subject: [PATCH 10/10] Delete files --- tuning_tools/abs_do_fwd.sh | 34 ----------- tuning_tools/abs_readall.py | 66 ---------------------- tuning_tools/find_fast.py | 79 -------------------------- tuning_tools/launcher_ge.py | 109 ------------------------------------ 4 files changed, 288 deletions(-) delete mode 100644 tuning_tools/abs_do_fwd.sh delete mode 100644 tuning_tools/abs_readall.py delete mode 100644 tuning_tools/find_fast.py delete mode 100644 tuning_tools/launcher_ge.py diff --git a/tuning_tools/abs_do_fwd.sh b/tuning_tools/abs_do_fwd.sh deleted file mode 100644 index 55eb2cdd8..000000000 --- a/tuning_tools/abs_do_fwd.sh +++ /dev/null @@ -1,34 +0,0 @@ -set -euo pipefail - -cd .. -pip install . - -#等待编译结束 - -cd tests/cpp/build/ -rm -rf * -cmake .. -make - -# 运行 rocprof 并把输出既打印到屏幕又保存到临时文件 -ROCLOG=/tmp/rocprof.log -rocprof --stats ./operator/test_operator | tee "$ROCLOG" - -# 从 rocprof 输出中提取两组数字 Dimension(2048,12288) -shape_line=$(grep -m 1 'OperatorTest/NormTestSuite.TestNorm/LayerNorm_' "$ROCLOG") -dim1=$(awk -F'X' '{print $3}' <<<"$shape_line") -dim2=$(awk -F'X' '{print $4}' <<<"$shape_line") - -# 再提取 ctas_per_row, warps_n, bytes_per_load -ctas=$(grep -m 1 'ctas_per_row:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') -wm=$(grep -m 1 'warps_m:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') -wn=$(grep -m 1 'warps_n:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') -bpl=$(grep -m 1 'bytes_per_load:' "$ROCLOG" | awk -F: '{gsub(/ /,"",$2); print $2}') - -# 拼成文件名并创建空文件 -filename="${dim1}_${dim2}_${ctas}_${wm}_${wn}_${bpl}" -# filename="${dim1}_${dim2}_${wm}_${wn}_${bpl}" -touch "/home/tuned_fwd/768/f16f16/$filename" -echo "→ Created file $filename" - -python /home/TransformerEngine/tuning_tools/abs_readall.py "/home/tuned_fwd/768/f16f16/${filename}" \ No newline at end of file diff --git a/tuning_tools/abs_readall.py b/tuning_tools/abs_readall.py deleted file mode 100644 index b8cff57c3..000000000 --- a/tuning_tools/abs_readall.py +++ /dev/null @@ -1,66 +0,0 @@ -import json -import os -import sys -import argparse - -def extract_and_process_durations(input_file, output_file, kernel_keywords, num_warmup, num_iteration): - with open(input_file, "r") as f: - data = json.load(f) - - keyword_to_durations = {k: [] for k in kernel_keywords} - - for event in data.get("traceEvents", []): - args = event.get("args", {}) - kernel_name = args.get("KernelName", "") - duration = args.get("DurationNs") - - if duration is not None: - for keyword in kernel_keywords: - if keyword in kernel_name: - keyword_to_durations[keyword].append(int(duration)) - break # 防止同一个event被多个keyword重复统计 - - output_lines = [] - - for keyword in kernel_keywords: - durations = keyword_to_durations[keyword] - output_lines.append(f"== {keyword} ==") - - if not durations: - output_lines.append("[无数据]") - continue - - i = 0 - while i < len(durations): - i += num_warmup # 跳过warmup - batch = [] - for _ in range(num_iteration): - if i < len(durations): - batch.append(durations[i]) - i += 1 - if batch: - avg = sum(batch) / len(batch) - output_lines.append(f"{avg:.2f}") - output_lines.append("") # 空行分隔 - - with open(output_file, "w") as f: - f.write("\n".join(output_lines)) - - print(f"已将所有 
kernel 的平均耗时写入 {output_file}") - -input_json = "/home/TransformerEngine/tests/cpp/build/results.json" -if len(sys.argv) > 1: - output_txt = sys.argv[1] -else: - output_txt = "/home/bwdprofiles/tmp/heyi.txt" - -kernel_keywords = [ - "ln_fwd_", - "ln_bwd_general_kernel", - "ln_bwd_finalize" -] - -num_warmup = 5 -num_iteration = 10 - -extract_and_process_durations(input_json, output_txt, kernel_keywords, num_warmup, num_iteration) \ No newline at end of file diff --git a/tuning_tools/find_fast.py b/tuning_tools/find_fast.py deleted file mode 100644 index 5ccf12413..000000000 --- a/tuning_tools/find_fast.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3 -""" -脚本:遍历指定目录下所有文件,解析每个文件中 -- ln_fwd_ kernel 的时间之和 -- 将 ln_bwd_tuned_kernel 和 ln_bwd_finalize 两个 kernel 的时间之和合并为一个值 -然后在所有文件中分别找出 ln_fwd_ 和合并后的 bwd 的最小值及对应文件,输出结果。 -""" -import os -import sys -import re - -def parse_file(filepath): - """解析单个文件,返回 dict: 'ln_fwd_' -> sum, 'ln_bwd_total' -> combined sum""" - sums = {} - current = None - times = [] - header_pat = re.compile(r"^==\s*(.+?)\s*==$") - with open(filepath, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if not line: - if current and times: - sums[current] = sum(times) - times = [] - continue - m = header_pat.match(line) - if m: - current = m.group(1) - times = [] - else: - try: - times.append(float(line)) - except ValueError: - pass - if current and times: - sums[current] = sum(times) - # 合并后两个 bwd kernels - bwd_sum = sums.get('ln_bwd_tuned_kernel', 0) + sums.get('ln_bwd_finalize', 0) - # 返回只有两项 - return { - 'ln_fwd_': sums.get('ln_fwd_', float('inf')), - 'ln_bwd_total': bwd_sum - } - -def find_minimums(dirpath): - """遍历目录文件,返回 dict: key -> (min_sum, filepath)""" - results = {} - for name in os.listdir(dirpath): - fp = os.path.join(dirpath, name) - if not os.path.isfile(fp): - continue - file_sums = parse_file(fp) - for key, val in file_sums.items(): - if key not in results or val < results[key][0]: - results[key] = (val, fp) - return results - -def main(): - if len(sys.argv) != 2: - print(f"Usage: {sys.argv[0]} ") - sys.exit(1) - d = sys.argv[1] - if not os.path.isdir(d): - print(f"Error: {d} is not a directory") - sys.exit(1) - mins = find_minimums(d) - if not mins: - print("No valid files found.") - return - print("最小时间和结果:") - for key in ['ln_fwd_', 'ln_bwd_total']: - val, fp = mins.get(key, (None, None)) - if val is None: - print(f"- {key}: 无数据") - else: - print(f"- {key}: {val:.2f} 文件: {fp}") - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/tuning_tools/launcher_ge.py b/tuning_tools/launcher_ge.py deleted file mode 100644 index 6a0f00357..000000000 --- a/tuning_tools/launcher_ge.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -""" -脚本:针对指定 HIDDEN_SIZE/WTYPE/ITYPE/OTYPE/CTYPE,在 ln_fwd_cuda_kernel.cu 中批量替换 REGISTER_NORM_LAUNCHER 宏的 -CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG 四个参数组合。 -只替换匹配该前缀的行,保留其他注册宏不变。 -""" -import re,os -import subprocess - -# 需要替换的源文件路径 -SOURCE_FILE = '/home/TransformerEngine/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu' -RESULTS_DIR = '/home/tuned_fwd/768/f16f16' -# 隐藏大小列表 -hidden_sizes = [768] -# 构造前缀模板,format 时填入 hidden_size -PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, {hs}, fp16, fp16, fp16, fp32," -# PREFIX_TMPL = "REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, {hs}, fp16, fp16, fp16, fp32," - -# # 要测试的参数组合 -# ctas_per_row_list = [ 2] -# warps_m_list = [1] -# warps_n_list = [8] -# bytes_per_ldg_list= [4,8,16,32] - 
-ctas_per_row_list = [1] -warps_m_list = [2,1] -warps_n_list = [2,4,8] -bytes_per_ldg_list= [8,16] -# 批量替换 -for hs in hidden_sizes: - # 每个 hidden_size 生成对应前缀 - prefix = PREFIX_TMPL.format(hs=hs) - for ctas in ctas_per_row_list: - for wm in warps_m_list: - for wn in warps_n_list: - for bpl in bytes_per_ldg_list: - if wm * wn < 2: - continue - lhs = hs // (bpl // 2) - rhs = ctas * wn * 32 * (lhs // (ctas * wn * 32)) - # rhs = 1 * wn * 32 * (lhs // (1 * wn * 32)) - if lhs != rhs: - continue - # if not (ctas == 1 or wm == 1): - # continue - # 构造新的完整宏调用行 - new_line = f"{prefix} {ctas}, {wm}, {wn}, {bpl});"#bwd - # 读取源文件 - with open(SOURCE_FILE, 'r', encoding='utf-8') as f: - lines = f.readlines() - # 写回时替换匹配前缀的行 - with open(SOURCE_FILE, 'w', encoding='utf-8') as f: - for line in lines: - if line.strip().startswith(prefix): - f.write(new_line + '\n') - else: - f.write(line) - print(f"Updated {SOURCE_FILE} for hidden_size={hs} with: WARPS_M={wm}, WARPS_N={wn}, BYTES_PER_LDG={bpl}") - - result=subprocess.run(['bash', './abs_do_fwd.sh']) - if result.returncode != 0: - print(f"Warning: abs_do.sh failed with exit code {result.returncode}") - - -proc = subprocess.run( - ['python3', 'find_fast.py', RESULTS_DIR], - stdout=subprocess.PIPE, - text=True, - check=True -) - -best_fp = None -for line in proc.stdout.splitlines(): - if line.startswith('- ln_fwd_'): - # 解析 “文件: /path/to/2048_12288_1_1_8_32” - parts = line.split('文件:') - if len(parts) == 2: - best_fp = parts[1].strip() - break - -if not best_fp: - print("Error: 没有找到最佳 ln_fwd_ 结果,退出。") - sys.exit(1) - -best_name = os.path.basename(best_fp) # e.g. "2048_12288_1_1_8_32" -print("Best ln_fwd file:", best_name) - -# —— 3. 从文件名拆出参数,并在 .cu 中替换宏行 —— # -tokens = best_name.split('_') -if len(tokens) != 6: - print("Error: 无法解析文件名参数:", best_name) - sys.exit(1) - -hs2, n2, ctas2, wm2, wn2, bpl2 = tokens -prefix = PREFIX_TMPL.format(hs=hs2) -new_line = f"{prefix} {ctas2}, {wm2}, {wn2}, {bpl2});" - -# 读源文件、替换所有匹配 prefix 的行 -with open(SOURCE_FILE, 'r', encoding='utf-8') as f: - lines = f.readlines() -with open(SOURCE_FILE, 'w', encoding='utf-8') as f: - for line in lines: - if line.strip().startswith(prefix): - f.write(new_line + '\n') - else: - f.write(line) - -print("已将所有前缀行替换为最佳组合:") -print(" ", new_line) \ No newline at end of file
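
The zero-count guard added to warp_chan_upd_dynamic_ge and the clamp applied to the variance before rsqrtf in the forward kernel both concern the same reduction: every warp partition carries a (mean, m2, count) triple, and partitions are folded together with Chan et al.'s pairwise update. Below is a minimal plain-Python sketch of that merge, included only as a reference for the behaviour the patches change. The names merge_stats and rsigma, the sample row, and the stated motivation for the guard (skipping an empty shuffle partner so a zero combined count never enters the division, and so stale values from inactive lanes are never folded in) are illustrative assumptions based on the diff, not code or commentary from the patches themselves.

import math
import statistics

def merge_stats(a, b):
    """Combine two (mean, m2, count) partials with Chan et al.'s pairwise update."""
    m_a, m2_a, n_a = a
    m_b, m2_b, n_b = b
    if n_b == 0:                  # empty partner: fold in nothing (mirrors the n_b == 0 guard)
        return a
    n_ab = n_a + n_b
    delta = m_a - m_b
    m_ab = (n_a * m_a + n_b * m_b) / n_ab
    m2_ab = m2_a + m2_b + delta * delta * n_a * n_b / n_ab
    return (m_ab, m2_ab, n_ab)

def rsigma(stat, epsilon=1e-5):
    """Inverse standard deviation with the non-negative variance clamp."""
    _, m2, n = stat
    var = max(m2 / n, 0.0)        # rounding can leave m2 / n marginally below zero
    return 1.0 / math.sqrt(var + epsilon)

# Four partitions covering a 6-element row; the third partition holds no elements.
partitions = [(1.0, 0.5, 2), (2.0, 0.0, 1), (0.0, 0.0, 0), (4.0, 2.0, 3)]
total = (0.0, 0.0, 0)
for p in partitions:
    total = merge_stats(total, p)

row = [0.5, 1.5, 2.0, 3.0, 4.0, 5.0]  # the same elements, flattened into one row
mean, m2, n = total
assert math.isclose(mean, statistics.fmean(row))
assert math.isclose(m2 / n, statistics.pvariance(row))
print(mean, m2 / n, rsigma(total))

Running the sketch reproduces the row mean and population variance computed in a single pass, which is the property the warp reduction relies on regardless of how a row is split across CTAS_PER_ROW and WARPS_N; the clamp in rsigma only matters when accumulated rounding error pushes the merged variance fractionally below zero, where an unclamped reciprocal square root could otherwise go non-finite.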