Merged
Changes from all commits
22 commits
b1d2a0b
Increase 'warmup' and 'rep' for FA benchmark
anmyachev Sep 16, 2024
339b709
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 16, 2024
5ebbd01
Use 150ms
anmyachev Sep 16, 2024
b1cc599
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 17, 2024
0ad146f
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 17, 2024
bbf0557
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 19, 2024
81fec9a
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 23, 2024
42e653a
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 23, 2024
8f81c13
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 23, 2024
5d08d3a
Merge branch 'main' into amyachev/bench-time
anmyachev Sep 29, 2024
b2d3398
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 29, 2024
fe806b1
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
bf49b0d
fix after merge
anmyachev Sep 30, 2024
7493632
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
524f81d
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
4d40864
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
b0d91ce
Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm…
anmyachev Sep 30, 2024
6809b9a
Merge remote-tracking branch 'origin' into amyachev/bench-time
anmyachev Oct 14, 2024
e1c4f9f
Change do_bench* signatures
anmyachev Oct 14, 2024
a1fd0f9
cleanup
anmyachev Oct 14, 2024
f16b149
fixes
anmyachev Oct 14, 2024
565d87c
fix
anmyachev Oct 14, 2024
89 changes: 43 additions & 46 deletions benchmarks/triton_kernels_benchmark/benchmark_testing.py
@@ -36,18 +36,18 @@ def _summarize_statistics(times, quantiles, return_mode):
return getattr(torch, return_mode)(times).item()


def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
def do_bench_ipex(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
sync_submitting=True, kernel_name=None): # pylint: disable=unused-argument
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.

:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param n_warmup: Number of repetitions for warmup
:type n_warmup: int
:param n_repeat: Number of repetitions to collect measurements
:type n_repeat: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
@@ -69,20 +69,6 @@ def do_bench_ipex(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, ret
cache_size = 256 * 1024 * 1024
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

# Estimate the runtime of the function
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
Comment on lines -83 to -85
anmyachev (Contributor, Author) commented:
There is no point in deriving the number of iterations from the estimated time of a single iteration, since the caller now requests the required number of iterations directly.
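
For illustration, a minimal sketch of the derivation being removed, under an assumed per-iteration estimate of 2.0 ms (an illustrative value, not taken from this PR):

warmup, rep, estimate_ms = 25, 100, 2.0        # old-style time budgets (ms) and probe estimate
n_warmup = max(1, int(warmup / estimate_ms))   # 12 warm-up iterations from a 25 ms budget
n_repeat = max(1, int(rep / estimate_ms))      # 50 measured iterations from a 100 ms budget
# With the new signatures the caller passes n_warmup and n_repeat directly,
# so the probe runs and this arithmetic are no longer needed.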

# Warm-up
for _ in range(n_warmup):
fn()
@@ -121,18 +107,18 @@ def extract_kernels(funcs):
return _summarize_statistics(times, quantiles, return_mode)


def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", device="xpu",
kernel_name=None): # pylint: disable=unused-argument
def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None, return_mode="mean",
device="xpu", kernel_name=None): # pylint: disable=unused-argument
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.

:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param n_warmup: Number of repetitions for warmup
:type n_warmup: int
:param n_repeat: Number of repetitions to collect measurements
:type n_repeat: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
@@ -142,24 +128,49 @@ def do_bench_elapsed_time(fn, warmup=25, rep=100, grad_to_none=None, quantiles=N
import torch
from triton.testing import do_bench as triton_do_bench

times = triton_do_bench(fn, warmup=warmup, rep=rep, grad_to_none=grad_to_none, return_mode="all",
# We maintain a buffer of 256 MB that we clear
# before each kernel call to make sure that the L2
# doesn't contain any input data before the run
cache_size = 256 * 1024 * 1024
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

# Estimate the runtime of the function
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# The cache is also maintained in `triton_do_bench` function,
# there is no need to duplicate the amount of memory used.
del cache

# compute warmup and repeat times
warmup_time = n_warmup * estimate_ms
rep_time = n_repeat * estimate_ms
Comment on lines +152 to +154
anmyachev (Contributor, Author) commented:
Here I translate the iteration counts into the time-based parameters that the upstream triton_do_bench understands.
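
A hedged sketch of that conversion, again using an illustrative estimate_ms of 2.0 ms: the upstream triton.testing.do_bench still expects time budgets in milliseconds, so the requested counts are scaled back into times before the call.

n_warmup, n_repeat, estimate_ms = 10, 10, 2.0   # counts requested by the caller, probe estimate
warmup_time = n_warmup * estimate_ms            # 20.0 ms warm-up budget
rep_time = n_repeat * estimate_ms               # 20.0 ms measurement budget
# triton_do_bench(fn, warmup=warmup_time, rep=rep_time, ...) then runs
# approximately the requested number of iterations.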


times = triton_do_bench(fn, warmup=warmup_time, rep=rep_time, grad_to_none=grad_to_none, return_mode="all",
device_type=device)
times = torch.tensor(times, dtype=torch.float)
return _summarize_statistics(times, quantiles, return_mode)


def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean",
device="xpu", sync_submitting=True, kernel_name=None):
def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
return_mode="mean", device="xpu", sync_submitting=True, kernel_name=None):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.

:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param n_warmup: Number of repetitions for warmup
:type n_warmup: int
:param n_repeat: Number of repetitions to collect measurements
:type n_repeat: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
@@ -179,20 +190,6 @@ def do_bench_upstream_pytorch_profiler(fn, warmup=25, rep=100, grad_to_none=None
cache_size = 256 * 1024 * 1024
cache = torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

# Estimate the runtime of the function
start_event = torch.xpu.Event(enable_timing=True)
end_event = torch.xpu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
cache.zero_()
fn()
end_event.record()
synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
Comment on lines -193 to -195
anmyachev (Contributor, Author) commented:
Same as above: there is no point in deriving the number of iterations from an estimated per-iteration time when the caller requests the required number of iterations directly.
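
As a usage sketch of the updated interface (the unpacking and quantiles mirror the benchmark scripts below; my_kernel_fn and the kernel name are placeholders, and benchmark_suit.do_bench is assumed to resolve to one of the functions above):

_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
    my_kernel_fn, n_warmup=10, n_repeat=10,
    quantiles=[0.5, 0.0, 1.0], kernel_name='my_kernel')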

# Warm-up
for _ in range(n_warmup):
fn()
@@ -242,7 +242,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
if provider == 'onednn':
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
CAUSAL, scale=sm_scale), warmup=10, rep=10,
CAUSAL, scale=sm_scale), n_warmup=10, n_repeat=10,
quantiles=quantiles)

elif provider == 'triton':
@@ -256,7 +256,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='_attn_fwd')

elif provider == 'xetla':
@@ -272,7 +272,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='gpu::xetla::fmha::FmhaForwardKernel<')

else:
10 changes: 5 additions & 5 deletions benchmarks/triton_kernels_benchmark/fused_softmax.py
@@ -125,18 +125,18 @@ def benchmark(M, N, provider):
quantiles = [0.5, 0.0, 1.0]
if provider == "torch-native":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles,
warmup=10, rep=10)
n_warmup=10, n_repeat=10)
if provider == "triton":
out = torch.empty_like(x, device="xpu")
triton_fn = lambda: softmax(x, out)
torch_fn = lambda: torch.softmax(x, axis=-1)
benchmark_suit.assert_close(triton_fn(), torch_fn(), err_msg="triton to torch")
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, warmup=10, rep=10,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
kernel_name="softmax_kernel")

elif provider == "torch-jit":
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles, warmup=10,
rep=10)
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: naive_softmax(x), quantiles=quantiles,
n_warmup=10, n_repeat=10)

elif provider == "xetla":
name = f"softmax_shape_{M}_{N}"
@@ -154,7 +154,7 @@ def benchmark(M, N, provider):
"softmax_shape_4096_16384": "mat1_4096x16k_bf16_cfg0",
"softmax_shape_4096_32768": "mat1_4096x32k_bf16_cfg0",
}
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, warmup=10, rep=10,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, quantiles=quantiles, n_warmup=10, n_repeat=10,
kernel_name=kernels_name[name])

else:
9 changes: 5 additions & 4 deletions benchmarks/triton_kernels_benchmark/gemm_benchmark.py
@@ -283,7 +283,7 @@ def benchmark(B, M, N, K, provider):
if BENCHMARKING_METHOD == 'PYTORCH_LEGACY_PROFILER_USING_IPEX':
# Legacy profiler shows ~6000TFLOPS GeoMean for onednn measurements, so use more reliable method
do_bench = do_bench_elapsed_time
_, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), warmup=10, rep=10,
_, min_ms, max_ms, mean_ms, cv = do_bench(lambda: torch.matmul(torch_a, torch_b), n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name='gemm_kernel')
elif provider == 'triton':
assert len(a.shape) == len(b.shape), 'Incompatible sizes'
@@ -296,7 +296,8 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.matmul(torch_a, torch_b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles,
kernel_name='matmul_kernel_with_block_pointers')
elif provider == 'xetla':
if B == 1:
@@ -340,8 +341,8 @@ def benchmark(B, M, N, K, provider):
}

# benchmark_suit.assert_close(xetla_fn(), torch_fn(), atol=1e-4, rtol=1.0, err_msg='xetla to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernels_name[name])
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernels_name[name])
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -275,8 +275,8 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.matmul(a, b).to(torch.float32) + d
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernel_name)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -277,8 +277,8 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.nn.functional.gelu(torch.matmul(a, b).to(torch.float32))
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernel_name)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

@@ -265,8 +265,8 @@ def benchmark(B, M, N, K, provider):
torch_fn = lambda: torch.matmul(torch.exp(a), b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
kernel_name=kernel_name)
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles, kernel_name=kernel_name)
else:
raise NotImplementedError(f'Unsupported provider {provider}')

4 changes: 2 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py
@@ -148,15 +148,15 @@ def benchmark(M, N, K, provider):
quantiles = [0.5, 0.0, 1.0]

if provider == 'onednn':
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
quantiles=quantiles)
elif provider == 'triton':
c = torch.empty((M, N), device='xpu', dtype=torch.float32)
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
rtol = 1e-2 if a.dtype == torch.bfloat16 else 1e-3
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
kernel_name='_kernel')
else:
raise NotImplementedError(f'Unsupported provider {provider}')
5 changes: 3 additions & 2 deletions benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py
@@ -271,14 +271,15 @@ def benchmark(M, N, K, provider):
quantiles = [0.5, 0.0, 1.0]

if provider == 'onednn':
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), warmup=10, rep=10,
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(lambda: torch.matmul(a, b), n_warmup=10, n_repeat=10,
quantiles=quantiles)
elif provider == 'triton':
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
triton_fn = lambda: matmul(a, b, c)
torch_fn = lambda: torch.matmul(a, b).to(torch.float32)
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=1e-2, err_msg='triton to torch')
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
_, min_ms, max_ms, mean_ms, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10,
quantiles=quantiles,
kernel_name=['first_wave', 'full_tiles'])
else:
raise NotImplementedError(f'Unsupported provider {provider}')