diff --git a/benchmarks/triton_kernels_benchmark/__init__.py b/benchmarks/triton_kernels_benchmark/__init__.py
index a340915d9c..a2cddd2ab0 100644
--- a/benchmarks/triton_kernels_benchmark/__init__.py
+++ b/benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,12 +1,13 @@
 import os
 
-from .benchmark_testing import assert_close, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
+from .benchmark_testing import assert_close, make_do_bench_for_autotune, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
 
 if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     os.environ["INJECT_PYTORCH"] = "True"
 
 __all__ = [
     "assert_close",
+    "make_do_bench_for_autotune",
     "do_bench",
     "perf_report",
     "Benchmark",
diff --git a/benchmarks/triton_kernels_benchmark/benchmark_testing.py b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
index 109f3f4f1b..066d6f5791 100644
--- a/benchmarks/triton_kernels_benchmark/benchmark_testing.py
+++ b/benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -1,6 +1,7 @@
 import argparse
 import itertools
 import os
+import time
 
 from triton.testing import assert_close as triton_assert_close, Benchmark
 
@@ -172,6 +173,39 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
+def make_do_bench_for_autotune():
+    import triton
+
+    def autotuner_do_bench(fn, *args, **kwargs):
+        di = triton.runtime.driver.active.get_device_interface()
+
+        fn()
+        di.synchronize()
+
+        cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+
+        count = 5
+        start = time.time_ns() / 1_000_000
+        for _ in range(count):
+            triton.runtime.driver.active.clear_cache(cache)
+            fn()
+            di.synchronize()
+        end = time.time_ns() / 1_000_000
+        estimate_ms = (end - start) / count
+
+        # defaults for `do_bench` in ms
+        warmup_time = 25
+        rep_time = 100
+
+        # compute n_warmup and n_repeat times
+        n_warmup = max(1, int(warmup_time / estimate_ms))
+        n_repeat = max(1, int(rep_time / estimate_ms))
+
+        return do_bench(fn, *args, n_warmup=n_warmup, n_repeat=n_repeat, **kwargs)
+
+    return autotuner_do_bench
+
+
 def assert_close(x_fn, y_fn, atol=None, rtol=None, err_msg=""):
     if VERIFY:
         triton_assert_close(x_fn(), y_fn(), atol, rtol, err_msg)
diff --git a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
index ef40c3b507..d6c5973431 100644
--- a/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
@@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
     for w in [8, 16, 32] \
 ]
 
-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
 tune_attn_fwd = tuner(_attn_fwd)
 
 
diff --git a/benchmarks/triton_kernels_benchmark/fused_softmax.py b/benchmarks/triton_kernels_benchmark/fused_softmax.py
index 799bdb1b53..dfcc2c0681 100644
--- a/benchmarks/triton_kernels_benchmark/fused_softmax.py
+++ b/benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -50,6 +50,7 @@ def naive_softmax(x):
         triton.Config({"threads_per_warp": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
diff --git a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
index 97dc321067..0b7747ec9d 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -39,6 +39,7 @@
             num_stages=s, num_warps=32) for s in [2, 3]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -112,6 +113,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
index f4964e6d47..8a55f47603 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
@@ -52,6 +52,7 @@ def suffix():
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -127,6 +128,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
index 65593f731e..f811ab4ddb 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py
@@ -51,6 +51,7 @@ def gelu(x):
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -119,6 +120,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
index 9b34558593..3425c1385b 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py
@@ -33,6 +33,7 @@
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -104,6 +105,7 @@ def matmul_kernel_with_block_pointers(
             num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(
diff --git a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
index 904d426556..d7883e636b 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -12,6 +12,7 @@
             num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def _kernel(A, B, C,  #
diff --git a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
index 29bd68698e..f8fa2d0b54 100644
--- a/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
+++ b/benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -104,6 +104,7 @@ def mac_loop(
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def first_wave(
@@ -140,6 +141,7 @@ def first_wave(
             num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def full_tiles(
diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py
index fbc9ac4464..971e0b7614 100644
--- a/python/triton/runtime/autotuner.py
+++ b/python/triton/runtime/autotuner.py
@@ -361,7 +361,7 @@ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
 
     def decorator(fn):
         return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                          post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                         use_cuda_graph=use_cuda_graph)
+                         use_cuda_graph=use_cuda_graph, do_bench=do_bench)
 
     return decorator