3 changes: 2 additions & 1 deletion benchmarks/triton_kernels_benchmark/__init__.py
@@ -1,12 +1,13 @@
import os

-from .benchmark_testing import assert_close, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
+from .benchmark_testing import assert_close, make_do_bench_for_autotune, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD

if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
    os.environ["INJECT_PYTORCH"] = "True"

__all__ = [
    "assert_close",
+    "make_do_bench_for_autotune",
    "do_bench",
    "perf_report",
    "Benchmark",
34 changes: 34 additions & 0 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -1,6 +1,7 @@
import argparse
import itertools
import os
+import time

from triton.testing import assert_close as triton_assert_close, Benchmark

@@ -172,6 +173,39 @@ def extract_kernels(funcs):
    raise NotImplementedError(f"BENCHMARKING_METHOD: {BENCHMARKING_METHOD} isn't implemented")


+def make_do_bench_for_autotune():
+    import triton
+
+    def autotuner_do_bench(fn, *args, **kwargs):
+        di = triton.runtime.driver.active.get_device_interface()
+
+        # warm-up call, so compilation time does not skew the estimate
+        fn()
+        di.synchronize()
+
+        cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+
+        # estimate the per-call runtime with simple wall-clock timing
+        count = 5
+        start = time.time_ns() / 1_000_000
+        for _ in range(count):
+            triton.runtime.driver.active.clear_cache(cache)
+            fn()
+            di.synchronize()
+        end = time.time_ns() / 1_000_000
+        estimate_ms = (end - start) / count
+
+        # defaults for `do_bench`, in ms
+        warmup_time = 25
+        rep_time = 100
+
+        # derive the n_warmup and n_repeat iteration counts from the estimate
+        n_warmup = max(1, int(warmup_time / estimate_ms))
+        n_repeat = max(1, int(rep_time / estimate_ms))
Comment on lines +201 to +202
Contributor Author

The iteration-count determination procedure is kept as close as possible to the one used before. I believe the changes in the results can be fully attributed to the switch from the implicit `elapsed_time` device-event timing to simple wall-clock timing.
+        return do_bench(fn, *args, n_warmup=n_warmup, n_repeat=n_repeat, **kwargs)
+
+    return autotuner_do_bench


def assert_close(x_fn, y_fn, atol=None, rtol=None, err_msg=""):
    if VERIFY:
        triton_assert_close(x_fn(), y_fn(), atol, rtol, err_msg)
@@ -164,7 +164,7 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out, #
    for w in [8, 16, 32] \
]

-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'], do_bench=benchmark_suit.make_do_bench_for_autotune())
tune_attn_fwd = tuner(_attn_fwd)


1 change: 1 addition & 0 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -50,6 +50,7 @@ def naive_softmax(x):
        triton.Config({"threads_per_warp": 16}, num_warps=4),
    ],
    key=["BLOCK_SIZE_X", "BLOCK_SIZE_Y"],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE_X: tl.constexpr,
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -39,6 +39,7 @@
        num_stages=s, num_warps=32) for s in [2, 3]
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -112,6 +113,7 @@ def matmul_kernel_with_block_pointers(
        num_stages=s, num_warps=4) for s in [2]
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -52,6 +52,7 @@ def suffix():
        num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -127,6 +128,7 @@ def matmul_kernel_with_block_pointers(
        num_stages=2, num_warps=4),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -51,6 +51,7 @@ def gelu(x):
        num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -119,6 +120,7 @@ def matmul_kernel_with_block_pointers(
        num_stages=2, num_warps=4),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -33,6 +33,7 @@
        num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers(
@@ -104,6 +105,7 @@ def matmul_kernel_with_block_pointers(
        num_stages=2, num_warps=4),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def matmul_kernel_with_block_pointers_batched(
@@ -12,6 +12,7 @@
        num_stages=4, num_warps=32),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def _kernel(A, B, C, #
2 changes: 2 additions & 0 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -104,6 +104,7 @@ def mac_loop(
        num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def first_wave(
@@ -140,6 +141,7 @@ def first_wave(
        num_stages=2, num_warps=32),
    ],
    key=['M', 'N', 'K'],
+    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def full_tiles(
2 changes: 1 addition & 1 deletion python/triton/runtime/autotuner.py
@@ -361,7 +361,7 @@ def kernel(x_ptr, x_size, BLOCK_SIZE: tl.constexpr):
    def decorator(fn):
        return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, restore_value, pre_hook=pre_hook,
                         post_hook=post_hook, prune_configs_by=prune_configs_by, warmup=warmup, rep=rep,
-                        use_cuda_graph=use_cuda_graph)
+                        use_cuda_graph=use_cuda_graph, do_bench=do_bench)

    return decorator

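Taken together, these changes let a benchmark pass its own timer through `triton.autotune` to the `Autotuner`. A minimal end-to-end usage sketch follows; the kernel, grid, and config values here are hypothetical, and only the `do_bench=` keyword and the factory call come from this PR:

```python
import triton
import triton.language as tl

import triton_kernels_benchmark as benchmark_suit  # aliased as in the benchmark scripts


@triton.autotune(
    configs=[triton.Config({'BLOCK_SIZE': 256}, num_warps=4)],  # hypothetical config
    key=['n_elements'],
    do_bench=benchmark_suit.make_do_bench_for_autotune(),
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    out = tl.load(x_ptr + offs, mask=mask) + tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, out, mask=mask)
```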