Commit 022d974

pass a function for autotuner via 'do_bench' param
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent c4a40f4 commit 022d974
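
Every diff below follows the same pattern: the benchmarking suite gains a factory, make_do_bench_for_autotune, and each benchmark passes the callable it returns to triton.autotune via the do_bench parameter. A condensed sketch of the wiring (the benchmark_suit alias matches the usage in the diffs; the import statement, config, key, and kernel here are illustrative, not taken from the commit):

    import triton
    import triton.language as tl

    # assumed import behind the `benchmark_suit` alias used throughout the benchmarks
    import triton_kernels_benchmark as benchmark_suit

    @triton.autotune(
        configs=[triton.Config({'BLOCK_SIZE': 64}, num_warps=4)],  # illustrative config
        key=['N'],
        # do_bench receives a callable with the suite's warmup/repeat
        # counts and the kernel name already bound
        do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='example_kernel'),
    )
    @triton.jit
    def example_kernel(x_ptr, N: tl.constexpr):
        pass  # kernel body elided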

11 files changed: +23 −5 lines

benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD  # type: ignore # noqa: F401
+from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, BENCHMARKING_METHOD  # type: ignore # noqa: F401
 
 if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     from triton.runtime import driver

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 7 additions & 2 deletions

@@ -2,7 +2,6 @@
 import itertools
 import os
 from typing import Any, Dict, List
-import triton
 
 BENCHMARKING_METHOD = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
 
@@ -160,7 +159,13 @@ def extract_kernels(funcs):
 else:
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
-triton.testing.do_bench = do_bench
+
+def make_do_bench_for_autotune(kernel_name: str):
+
+    def autotuner_do_bench(*args, **kwargs):
+        return do_bench(*args, n_warmup=10, n_repeat=10, kernel_name=kernel_name, **kwargs)
+
+    return autotuner_do_bench
 
 
 def assert_close(x, y, atol=None, rtol=None, err_msg=""):
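
For reference, the returned closure just forwards to the suite's do_bench with a fixed measurement budget; a sketch of what the autotuner ends up calling (assuming do_bench takes the launch thunk as its first positional argument, as in triton.testing.do_bench):

    bench = make_do_bench_for_autotune('example_kernel')  # kernel name is illustrative
    # Each timing call the autotuner makes, bench(thunk), expands to
    #     do_bench(thunk, n_warmup=10, n_repeat=10, kernel_name='example_kernel')
    # so every candidate config is measured with the same warmup/repeat budget.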

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 2 additions & 1 deletion

@@ -161,7 +161,8 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     for w in [8, 16, 32] \
 ]
 
-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'],
+                        do_bench=benchmark_suit.make_do_bench_for_autotune('_attn_fwd'))
 tune_attn_fwd = tuner(_attn_fwd)

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ def naive_softmax(x):
         triton.Config({"threads_per_warp": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name="softmax_kernel"),
 )
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,7 @@
         num_stages=s, num_warps=32) for s in [2, 3]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -112,6 +113,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 2 additions & 0 deletions

@@ -32,6 +32,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -106,6 +107,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 2 additions & 0 deletions

@@ -51,6 +51,7 @@ def gelu(x):
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -119,6 +120,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 2 additions & 0 deletions

@@ -33,6 +33,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -104,6 +105,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
         num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='_kernel'),
 )
 @triton.jit
 def _kernel(A, B, C, #

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 0 deletions

@@ -104,6 +104,7 @@ def mac_loop(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='first_wave'),
 )
 @triton.jit
 def first_wave(
@@ -140,6 +141,7 @@ def first_wave(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='full_tiles'),
 )
 @triton.jit
 def full_tiles(
