Skip to content

Commit 5710fd1

Browse files
committed
remove 'kernel_name'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent e82f873 commit 5710fd1

9 files changed: +15 −16 lines

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ def extract_kernels(funcs):
160160
raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")
161161

162162

163-
def make_do_bench_for_autotune(kernel_name: str):
163+
def make_do_bench_for_autotune():
164164

165165
def autotuner_do_bench(*args, **kwargs):
166-
return do_bench(*args, n_warmup=10, n_repeat=10, kernel_name=kernel_name, **kwargs)
166+
return do_bench(*args, n_warmup=10, n_repeat=10, **kwargs)
167167

168168
return autotuner_do_bench
169169

benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
161161
for w in [8, 16, 32] \
162162
]
163163

164-
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'],
165-
do_bench=benchmark_suit.make_do_bench_for_autotune('_attn_fwd'))
164+
tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
166165
tune_attn_fwd = tuner(_attn_fwd)
167166

168167

benchmarks/triton_kernels_benchmark/fused_softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def naive_softmax(x):
5050
triton.Config({"threads_per_warp": 16}, num_warps=4),
5151
],
5252
key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
53-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name="softmax_kernel"),
53+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
5454
)
5555
@triton.jit
5656
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
num_stages=s, num_warps=32) for s in [2, 3]
4040
],
4141
key=['M', 'N', 'K'],
42-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
42+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
4343
)
4444
@triton.jit
4545
def matmul_kernel_with_block_pointers(
@@ -113,7 +113,7 @@ def matmul_kernel_with_block_pointers(
113113
num_stages=s, num_warps=4) for s in [2]
114114
],
115115
key=['M', 'N', 'K'],
116-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
116+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
117117
)
118118
@triton.jit
119119
def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
num_stages=2, num_warps=32),
3333
],
3434
key=['M', 'N', 'K'],
35-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
35+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
3636
)
3737
@triton.jit
3838
def matmul_kernel_with_block_pointers(
@@ -107,7 +107,7 @@ def matmul_kernel_with_block_pointers(
107107
num_stages=2, num_warps=4),
108108
],
109109
key=['M', 'N', 'K'],
110-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
110+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
111111
)
112112
@triton.jit
113113
def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_postop_gelu_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def gelu(x):
5151
num_stages=2, num_warps=32),
5252
],
5353
key=['M', 'N', 'K'],
54-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
54+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
5555
)
5656
@triton.jit
5757
def matmul_kernel_with_block_pointers(
@@ -120,7 +120,7 @@ def matmul_kernel_with_block_pointers(
120120
num_stages=2, num_warps=4),
121121
],
122122
key=['M', 'N', 'K'],
123-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
123+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
124124
)
125125
@triton.jit
126126
def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_preop_exp_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
num_stages=2, num_warps=32),
3434
],
3535
key=['M', 'N', 'K'],
36-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers'),
36+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
3737
)
3838
@triton.jit
3939
def matmul_kernel_with_block_pointers(
@@ -105,7 +105,7 @@ def matmul_kernel_with_block_pointers(
105105
num_stages=2, num_warps=4),
106106
],
107107
key=['M', 'N', 'K'],
108-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='matmul_kernel_with_block_pointers_batched'),
108+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
109109
)
110110
@triton.jit
111111
def matmul_kernel_with_block_pointers_batched(

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
num_stages=4, num_warps=32),
1313
],
1414
key=['M', 'N', 'K'],
15-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='_kernel'),
15+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
1616
)
1717
@triton.jit
1818
def _kernel(A, B, C, #

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def mac_loop(
104104
num_stages=2, num_warps=32),
105105
],
106106
key=['M', 'N', 'K'],
107-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='first_wave'),
107+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
108108
)
109109
@triton.jit
110110
def first_wave(
@@ -141,7 +141,7 @@ def first_wave(
141141
num_stages=2, num_warps=32),
142142
],
143143
key=['M', 'N', 'K'],
144-
do_bench=benchmark_suit.make_do_bench_for_autotune(kernel_name='full_tiles'),
144+
do_bench=benchmark_suit.make_do_bench_for_autotune(),
145145
)
146146
@triton.jit
147147
def full_tiles(

0 commit comments

Comments (0)