From c4a40f48192bf823ed205ec5b9be9e5300097b02 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Tue, 17 Dec 2024 22:24:14 +0000
Subject: [PATCH 01/12] Don't implicitly use 'elapsed_time' in autotuner

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 3 +++
 python/triton/runtime/autotuner.py | 4 ++--
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 8fb4470bc7..393c472b83 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -2,6 +2,7 @@
 import itertools
 import os
 from typing import Any, Dict, List
+import triton
 
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 
@@ -159,6 +160,8 @@ def extract_kernels(funcs):
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
+triton.testing.do_bench = do_bench
+
 
 def assert_close(x, y, atol=None, rtol=None, err_msg=""):
     import numpy as np
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index cf32451c5e..5e0023543d 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -117,8 +117,8 @@ def _post_hook(kwargs, exception):
             import triton.testing
             self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
                 kernel_call,
-                warmup=warmup if warmup is not None else 25,
-                rep=rep if rep is not None else 100,
+                warmup=warmup if warmup is not None else 10,
+                rep=rep if rep is not None else 10,
                 quantiles=quantiles,
             )
             return
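The monkey-patch above reroutes every later triton.testing.do_bench call, including the one the autotuner makes internally, to the suite's profiler-based timer. The catch, and presumably the reason the same patch drops the 25/100 defaults to 10/10, is that the two implementations read their keyword arguments differently: upstream takes time budgets, while the suite's entry point (as the next patch shows) counts iterations. A schematic contrast, with signatures paraphrased from the hunks in this series and stub bodies as placeholders:

    def do_bench(fn, warmup=25, rep=100, quantiles=None):
        """Upstream triton.testing.do_bench: 'warmup' and 'rep' are time
        budgets in milliseconds; iteration counts are derived internally
        from a timed estimate of one call to fn()."""

    def suite_do_bench(fn, n_warmup=10, n_repeat=10, quantiles=None, kernel_name=None):
        """Suite-side do_bench: 'n_warmup' and 'n_repeat' are literal
        iteration counts, so millisecond-sized values such as 25/100
        would be misread as 25 and 100 fixed runs."""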
From 022d974b57ecf8ada74d11609965bdbc3ac3cc2f Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 18 Dec 2024 11:33:28 +0000
Subject: [PATCH 02/12] pass a function for autotuner via 'do_bench' param

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/__init__.py | 2 +-
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 9 +++++++--
 .../flash_attention_fwd_benchmark.py | 3 ++-
 benchmarks/triton_kernels_benchmark/fused_softmax.py | 1 +
 benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 2 ++
 .../gemm_postop_addmatrix_benchmark.py | 2 ++
 .../gemm_postop_gelu_benchmark.py | 2 ++
 .../triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 2 ++
 .../triton_kernels_benchmark/gemm_splitk_benchmark.py | 1 +
 .../triton_kernels_benchmark/gemm_streamk_benchmark.py | 2 ++
 python/triton/runtime/autotuner.py | 2 +-
 11 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py
index 820fff61e0..1b41c91d50 100644
--- a/benchmarks/triton_kernels_benchmark/__init__.py
+++ b/benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,4 +1,4 @@
-from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD  # type: ignore # noqa: F401
+from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD  # type: ignore # noqa: F401
 
 if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     from triton.runtime import driver
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 393c472b83..0f1d359aad 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -2,7 +2,6 @@ import argparse
 import itertools
 import os
 from typing import Any, Dict, List
-import triton
 
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 
@@ -160,7 +159,13 @@ def extract_kernels(funcs):
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
-triton.testing.do_bench = do_bench
+
+def make_do_bench_for_autotune(kernel_name: str):
+
+    def autotuner_do_bench(*args, **kwargs):
+        return do_bench(*args, n_warmup=10, n_repeat=10, kernel_name=kernel_name, **kwargs)
+
+    return autotuner_do_bench
 
 
 def assert_close(x, y, atol=None, rtol=None, err_msg=""):
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index b672d70c8a..1eec44dfbc 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -161,7 +161,8 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
     for w in [8, 16, 32] \
 ]
 
-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'],
+                        do_bench=benchmark_suit.make_do_bench_for_autotune('_attn_fwd'))
 tune_attn_fwd = tuner(_attn_fwd)
 
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index 6782e92d6b..e7ca178c17 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -50,6 +50,7 @@ def naive_softmax(x):
         triton.Config({"threads_per_warp": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name="softmax_kernel"),
 )
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index 860595a06b..e36ade2ae1 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -39,6 +39,7 @@
         num_stages=s, num_warps=32) for s in [2, 3]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -112,6 +113,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 7d9d877660..53a9c46a03 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -32,6 +32,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -106,6 +107,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 0faeead793..79de601d9a 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -51,6 +51,7 @@ def gelu(x):
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -119,6 +120,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 8b13827e0c..d2b9bd79cd 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -33,6 +33,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -104,6 +105,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index c4dd86d834..7a850d22d3 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -12,6 +12,7 @@
         num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='_kernel'),
 )
 @triton.jit
 def _kernel(A, B, C,  #
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 6ef40be902..dbe1aa0086 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -104,6 +104,7 @@ def mac_loop(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='first_wave'),
 )
 @triton.jit
 def first_wave(
@@ -140,6 +141,7 @@ def first_wave(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='full_tiles'),
 )
 @triton.jit
 def full_tiles(
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index 5e0023543d..045e384027 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -361,7 +361,7 @@ def kernel(x_ptr, x_size, **META):
     def decorator(fn):
         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                         use_cuda_graph=use_cuda_graph)
+                         use_cuda_graph=use_cuda_graph, do_bench=do_bench)
 
     return decorator
 
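With 'do_bench' now a real Autotuner parameter, each benchmark hands the autotuner its own timing routine instead of relying on the global monkey-patch. A minimal usage sketch of the new parameter; the vector-add kernel and its configs below are illustrative placeholders, not code from this series:

    import triton
    import triton.language as tl
    import triton_kernels_benchmark as benchmark_suit

    @triton.autotune(
        configs=[triton.Config({'BLOCK_SIZE': 2**n}) for n in (7, 8, 9)],
        key=['n_elements'],
        # a do_bench-compatible callable; at this point in the series the
        # factory still takes the kernel name that the suite's timer expects
        do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='add_kernel'),
    )
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)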
From de18b7acdd924dd6e0b148973aeace140cb9d9e4 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 18 Dec 2024 13:49:25 +0000
Subject: [PATCH 03/12] revert warmup/rep changes

Signed-off-by: Anatoly Myachev
---
 python/triton/runtime/autotuner.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index 045e384027..83e104c795 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -117,8 +117,8 @@ def _post_hook(kwargs, exception):
             import triton.testing
             self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
                 kernel_call,
-                warmup=warmup if warmup is not None else 10,
-                rep=rep if rep is not None else 10,
+                warmup=warmup if warmup is not None else 25,
+                rep=rep if rep is not None else 100,
                 quantiles=quantiles,
             )
             return

From e82f8739f7bfe7d371f0ecfd5be268d2a5174dbd Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 18 Dec 2024 18:36:11 +0100
Subject: [PATCH 04/12] Apply suggestions from code review

Co-authored-by: Whitney Tsang
---
 benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 2 +-
 .../triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py | 2 +-
 .../triton_kernels_benchmark/gemm_postop_gelu_benchmark.py | 2 +-
 benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index e36ade2ae1..acf6fd42e3 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -113,7 +113,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index 53a9c46a03..bf95753122 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -107,7 +107,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 79de601d9a..0819365802 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -120,7 +120,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index d2b9bd79cd..fe53fab4c7 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -105,7 +105,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

From 5710fd1974cab092aa701af4242ec201d3557f31 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Sat, 21 Dec 2024 11:05:09 +0000
Subject: [PATCH 05/12] remove 'kernel_name'

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 4 ++--
 .../triton_kernels_benchmark/flash_attention_fwd_benchmark.py | 3 +--
 benchmarks/triton_kernels_benchmark/fused_softmax.py | 2 +-
 benchmarks/triton_kernels_benchmark/gemm_benchmark.py | 4 ++--
 .../gemm_postop_addmatrix_benchmark.py | 4 ++--
 .../triton_kernels_benchmark/gemm_postop_gelu_benchmark.py | 4 ++--
 .../triton_kernels_benchmark/gemm_preop_exp_benchmark.py | 4 ++--
 benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py | 2 +-
 benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py | 4 ++--
 9 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 0f1d359aad..7e955889fb 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -160,10 +160,10 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
-def make_do_bench_for_autotune(kernel_name: str):
+def make_do_bench_for_autotune():
 
     def autotuner_do_bench(*args, **kwargs):
-        return do_bench(*args, n_warmup=10, n_repeat=10, kernel_name=kernel_name, **kwargs)
+        return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)
 
     return autotuner_do_bench
 
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
index 1eec44dfbc..6a9c5d6880 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -161,8 +161,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
     for w in [8, 16, 32] \
 ]
 
-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'],
-                        do_bench=benchmark_suit.make_do_bench_for_autotune('_attn_fwd'))
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
 tune_attn_fwd = tuner(_attn_fwd)
 
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index e7ca178c17..56cd91befe 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -50,7 +50,7 @@ def naive_softmax(x):
         triton.Config({"threads_per_warp": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name="softmax_kernel"),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index acf6fd42e3..b1c460ef5d 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -39,7 +39,7 @@
         num_stages=s, num_warps=32) for s in [2, 3]
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -113,7 +113,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index bf95753122..b2d1e78361 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -32,7 +32,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -107,7 +107,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 0819365802..430a7be241 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -51,7 +51,7 @@ def gelu(x):
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -120,7 +120,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index fe53fab4c7..8f55986f17 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -33,7 +33,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -105,7 +105,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 7a850d22d3..0953294822 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -12,7 +12,7 @@
         num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='_kernel'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def _kernel(A, B, C,  #
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index dbe1aa0086..c26bc94a3d 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -104,7 +104,7 @@ def mac_loop(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='first_wave'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def first_wave(
@@ -141,7 +141,7 @@ def first_wave(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
-    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='full_tiles'),
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def full_tiles(

From 25ec4a6e3850f55b80d0b3cdb5ae97b6f7b62264 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Tue, 14 Jan 2025 13:22:22 +0100
Subject: [PATCH 06/12] calculate the number of iterations based on time

Signed-off-by: Anatoly Myachev
---
 .../benchmark_testing.py | 22 +++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 7e955889fb..c54e37095b 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -1,6 +1,7 @@
 import argparse
 import itertools
 import os
+import time
 from typing import Any, Dict, List
 
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
@@ -161,9 +162,26 @@ def extract_kernels(funcs):
 
 
 def make_do_bench_for_autotune():
+    import triton
 
-    def autotuner_do_bench(*args, **kwargs):
-        return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)
+    def autotuner_do_bench(fn, *args, **kwargs):
+        di = triton.runtime.driver.active.get_device_interface()
+
+        start = time.time_ns() / 1_000_000
+        fn()
+        di.synchronize()
+        end = time.time_ns() / 1_000_000
+        estimate_ms = end - start
+
+        # defaults for `do_bench` in ms
+        warmup_time = 25
+        rep_time = 100
+
+        # compute n_warmup and n_repeat times
+        n_warmup = int(warmup_time // estimate_ms + 1)
+        n_repeat = int(rep_time // estimate_ms + 1)
+
+        return do_bench(fn, *args, n_warmup=n_warmup, n_repeat=n_repeat, **kwargs)
 
     return autotuner_do_bench
 
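The factory now sizes the measurement loops the same way upstream do_bench does: time one call, then derive iteration counts from the default 25 ms warm-up and 100 ms measurement budgets. A worked example with an assumed single-call estimate (numbers illustrative only):

    estimate_ms = 0.8                                # assumed: one fn() + synchronize took 0.8 ms
    warmup_time, rep_time = 25, 100                  # do_bench defaults, in milliseconds

    n_warmup = int(warmup_time // estimate_ms + 1)   # int(31.0 + 1) -> 32 warm-up runs
    n_repeat = int(rep_time // estimate_ms + 1)      # int(125.0 + 1) -> 126 timed runs

    # a kernel slower than the whole budget still gets exactly one run:
    assert int(rep_time // 300.0 + 1) == 1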
From eb10a921f4c4e67ae6aa871a82645f6d86296c36 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Tue, 14 Jan 2025 15:20:50 +0100
Subject: [PATCH 07/12] TRITON_PRINT_AUTOTUNING

Signed-off-by: Anatoly Myachev
---
 .github/workflows/triton-benchmarks.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index 1c9dbc2b3b..22841ac711 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -45,6 +45,7 @@ permissions: read-all
 
 env:
   PYTHON_VERSION: "3.10"
+  TRITON_PRINT_AUTOTUNING: "1"
   BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER' }}
   TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
 

From f044ec37b9bcf6d28be63632529884d4030f8213 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Tue, 14 Jan 2025 17:59:32 +0100
Subject: [PATCH 08/12] more runs for 'estimate_ms' calculation

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index c54e37095b..e7c0a945db 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -167,11 +167,13 @@ def make_do_bench_for_autotune():
     def autotuner_do_bench(fn, *args, **kwargs):
         di = triton.runtime.driver.active.get_device_interface()
 
+        count = 3
         start = time.time_ns() / 1_000_000
-        fn()
-        di.synchronize()
+        for _ in range(count):
+            fn()
+            di.synchronize()
         end = time.time_ns() / 1_000_000
-        estimate_ms = end - start
+        estimate_ms = (end - start) / count
 
         # defaults for `do_bench` in ms
         warmup_time = 25

From ddc6595223a1e9e7d22f81148c8a9d0535abbb0b Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 15 Jan 2025 18:13:22 +0100
Subject: [PATCH 09/12] align the procedure for getting estimate_ms with the
 procedure in do_bench

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 8cdf8a2e4e..38d363787d 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -168,10 +168,12 @@ def make_do_bench_for_autotune():
 
     def autotuner_do_bench(fn, *args, **kwargs):
         di = triton.runtime.driver.active.get_device_interface()
+        cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
-        count = 3
+        count = 5
         start = time.time_ns() / 1_000_000
         for _ in range(count):
+            triton.runtime.driver.active.clear_cache(cache)
             fn()
             di.synchronize()
         end = time.time_ns() / 1_000_000

From f6a05b75b71d29e1f7af652197a924ccacc3e72e Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 15 Jan 2025 19:22:31 +0000
Subject: [PATCH 10/12] one extra call before

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 38d363787d..7b25665141 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -168,6 +168,10 @@ def make_do_bench_for_autotune():
 
     def autotuner_do_bench(fn, *args, **kwargs):
         di = triton.runtime.driver.active.get_device_interface()
+
+        fn()
+        di.synchronize()
+
         cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
         count = 5
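The extra untimed fn() call added above matters because the first invocation of a freshly compiled config pays one-off JIT compilation and warm-up costs; folding that into estimate_ms would shrink the derived iteration counts drastically. An illustration with hypothetical per-call timings:

    # hypothetical timings in ms; the first call includes compilation
    timings_ms = [187.0, 0.8, 0.8, 0.8, 0.8, 0.8]

    biased = sum(timings_ms[:5]) / 5       # 38.04 ms -> n_repeat of ~3 instead of ~125
    unbiased = sum(timings_ms[1:6]) / 5    # 0.80 ms once the first call is discarded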
From 44f3d02218567a444a6514efcee6b689d5e583c2 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Wed, 15 Jan 2025 19:24:23 +0000
Subject: [PATCH 11/12] align

Signed-off-by: Anatoly Myachev
---
 benchmarks/triton_kernels_benchmark/benchmark_testing.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 7b25665141..7160cf454a 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -188,8 +188,8 @@ def autotuner_do_bench(fn, *args, **kwargs):
         rep_time = 100
 
         # compute n_warmup and n_repeat times
-        n_warmup = int(warmup_time // estimate_ms + 1)
-        n_repeat = int(rep_time // estimate_ms + 1)
+        n_warmup = max(1, int(warmup_time / estimate_ms))
+        n_repeat = max(1, int(rep_time / estimate_ms))
 
         return do_bench(fn, *args, n_warmup=n_warmup, n_repeat=n_repeat, **kwargs)
 

From da4d163a0ea73fc20fca244ca0d8033df5f99886 Mon Sep 17 00:00:00 2001
From: Anatoly Myachev
Date: Thu, 16 Jan 2025 15:24:15 +0100
Subject: [PATCH 12/12] Update .github/workflows/triton-benchmarks.yml

---
 .github/workflows/triton-benchmarks.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index e30d12780f..0cfc9b1a2a 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -49,7 +49,6 @@ permissions: read-all
 
 env:
   PYTHON_VERSION: "3.10"
-  TRITON_PRINT_AUTOTUNING: "1"
   BENCHMARKING_METHOD: ${{ inputs.benchmarking_method || 'UPSTREAM_PYTORCH_PROFILER' }}
   VERIFY: ${{ (github.event_name == 'pull_request' || github.event_name == 'schedule' || inputs.verify) && '1' || '0' }}
   TAG: ${{ inputs.tag || (github.event_name == 'pull_request' && format('pr-{0}', github.event.number)) || (github.event_name == 'schedule' && 'ci') || 'test' }}
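The final rounding tweak in patch 11 is subtle: the old int(t // est + 1) floors and then unconditionally adds a run, while max(1, int(t / est)) only clamps at the bottom, so fast kernels no longer get one systematic extra iteration. A comparison on assumed estimates:

    estimate_ms = 0.8
    old_n_repeat = int(100 // estimate_ms + 1)          # 126
    new_n_repeat = max(1, int(100 / estimate_ms))       # 125

    slow_ms = 300.0                                     # slower than the 100 ms budget
    assert max(1, int(100 / slow_ms)) == 1              # still at least one measured run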