
Commit 2a4b818

try changes from #3036
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 8cf03e9 commit 2a4b818

11 files changed (+23, -3 lines)

benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .benchmark_testing import do_bench, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401
+from .benchmark_testing import do_bench, make_do_bench_for_autotune, assert_close, perf_report, Benchmark, USE_IPEX_OPTION, BENCHMARKING_METHOD # type: ignore # noqa: F401
 
 if USE_IPEX_OPTION or BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     from triton.runtime import driver

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 8 additions & 0 deletions
@@ -237,6 +237,14 @@ def extract_kernels(funcs):
     raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
 
 
+def make_do_bench_for_autotune():
+
+    def autotuner_do_bench(*args, **kwargs):
+        return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)
+
+    return autotuner_do_bench
+
+
 def assert_close(x, y, atol=None, rtol=None, err_msg=""):
     import numpy as np
     import torch

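The helper added above is a small factory: it returns a wrapper around the module's existing `do_bench` with the warmup and repeat counts pinned to 10, so each configuration an autotuner tries is timed cheaply rather than with the full benchmarking defaults. A self-contained sketch of the pattern (the stand-in `do_bench` and its default counts below are illustrative assumptions; the real timer lives in `benchmark_testing.py`):

```python
import time


def do_bench(fn, n_warmup=25, n_repeat=100):
    """Stand-in timer: run `fn` untimed for warmup, then return mean ms per call."""
    for _ in range(n_warmup):
        fn()
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn()
    return (time.perf_counter() - start) * 1e3 / n_repeat


def make_do_bench_for_autotune():
    # Bind cheaper settings so each autotuner config trial stays fast;
    # all other positional/keyword arguments pass through unchanged.
    def autotuner_do_bench(*args, **kwargs):
        return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)

    return autotuner_do_bench


if __name__ == "__main__":
    bench = make_do_bench_for_autotune()
    print(f"{bench(lambda: sum(range(10_000))):.4f} ms")
```

Returning a fresh callable from a factory (rather than exposing the wrapper directly) keeps the call sites in the files below uniform: each one passes `benchmark_suit.make_do_bench_for_autotune()` at decoration time.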
benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
     for w in [8, 16, 32] \
 ]
 
-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
 tune_attn_fwd = tuner(_attn_fwd)

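Every remaining file wires the factory into its autotuner the same way: an extra `do_bench=` keyword next to `key=`, either on the `@triton.autotune` decorator or, as in the flash-attention file above, on a standalone `triton.autotune(...)` call. A hedged sketch of the decorator form on a toy kernel (the kernel, config values, and the `import triton_kernels_benchmark as benchmark_suit` alias are assumptions, and the `do_bench` keyword presumes an autotuner build that accepts it, which the diffs in this commit rely on):

```python
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit  # assumed alias for the package patched above


@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 128}, num_warps=4),
        triton.Config({'BLOCK': 256}, num_warps=8),
    ],
    key=['n_elements'],
    # each candidate config is timed with the cheaper 10-warmup/10-repeat runner
    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
```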
benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ def naive_softmax(x):
         triton.Config({"threads_per_warp": 16}, num_warps=4),
     ],
     key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@
         num_stages=s, num_warps=32) for s in [2, 3]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -116,6 +117,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=s, num_warps=4) for s in [2]
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -109,6 +110,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 2 additions & 0 deletions
@@ -54,6 +54,7 @@ def gelu(x):
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -122,6 +123,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers(
@@ -107,6 +108,7 @@ def matmul_kernel_with_block_pointers(
         num_stages=2, num_warps=4),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@
         num_stages=4, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def _kernel(A, B, C, #

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,7 @@ def mac_loop(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def first_wave(
@@ -143,6 +144,7 @@ def first_wave(
         num_stages=2, num_warps=32),
     ],
     key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
 )
 @triton.jit
 def full_tiles(
