Commit f69ba3e

[ENH] Assign benchmark parameters values at runtime instead of import … (#3934)
This is the second PR in a series preparing for the introduction of an entry point for the `triton_kernels_benchmarks` package.
1 parent: c4a6228

File tree: 5 files changed, +398 -279 lines changed
benchmarks/triton_kernels_benchmark/__init__.py

Lines changed: 9 additions & 8 deletions

@@ -1,14 +1,15 @@
 import os
 
-from .benchmark_testing import assert_close, do_bench, perf_report, Benchmark, BENCHMARKING_METHOD
+from .benchmark_testing import (
+    assert_close,
+    do_bench,
+    perf_report,
+    Benchmark,
+    BENCHMARKING_METHOD,
+    filter_providers,
+)
 
 if BENCHMARKING_METHOD == "UPSTREAM_PYTORCH_PROFILER":
     os.environ["INJECT_PYTORCH"] = "True"
 
-__all__ = [
-    "assert_close",
-    "do_bench",
-    "perf_report",
-    "Benchmark",
-    "BENCHMARKING_METHOD",
-]
+__all__ = ["assert_close", "do_bench", "perf_report", "Benchmark", "BENCHMARKING_METHOD", "filter_providers"]

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 14 additions & 0 deletions

@@ -3,6 +3,7 @@
 import itertools
 import os
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 from torch.profiler import profile, ProfilerActivity, record_function
@@ -183,6 +184,19 @@ def assert_close(x_fn, y_fn, atol=None, rtol=None, err_msg=""):
     triton_assert_close(x_fn(), y_fn(), atol, rtol, err_msg)
 
 
+def filter_providers(
+    supported_providers: dict[str, str],
+    providers_filter: Optional[list[str]],
+) -> dict[str, str]:
+    providers = {}
+    if providers_filter:
+        for provider_key, provider_label in supported_providers.items():
+            if provider_key in providers_filter:
+                providers[provider_key] = provider_label
+        return providers
+    return supported_providers
+
+
 def perf_report(benchmarks):
     """
     Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value.
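To make the semantics of the new helper concrete, here is a small hypothetical usage sketch; the provider dictionary is illustrative and not taken from the repository:

    supported = {'triton': 'Triton', 'xetla': 'XeTLA', 'onednn': 'OneDNN'}

    filter_providers(supported, ['xetla'])    # {'xetla': 'XeTLA'}
    filter_providers(supported, ['cutlass'])  # {} - unknown keys are silently dropped
    filter_providers(supported, None)         # full dict: an empty or missing filter keeps every provider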

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 173 additions & 125 deletions

@@ -1,12 +1,13 @@
 import os
 import contextlib
+from typing import Optional
 
 import torch
 from torch.profiler import record_function
 import triton
 import triton.language as tl
 
-import triton_kernels_benchmark as benchmark_suit
+import triton_kernels_benchmark as benchmark_suite
 from triton_kernels_benchmark import xetla_kernel
 import numpy as np
 
@@ -482,7 +483,7 @@ def backward(ctx, do):
         # https://github.com/pytorch/pytorch/issues/144778 has more details.
         with record_function(
                 '__profile_kernel_of_func_bwd_fa'
-        ) if benchmark_suit.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext():
+        ) if benchmark_suite.BENCHMARKING_METHOD == 'UPSTREAM_PYTORCH_PROFILER' else contextlib.nullcontext():
             q, k, v, o, M = ctx.saved_tensors
             assert do.is_contiguous()
             assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
@@ -540,135 +541,182 @@ def check_close(f_val, f_ref, atol, rtol):
         print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')


-@benchmark_suit.perf_report(
-    benchmark_suit.Benchmark(
-        # argument names to use as an x-axis for the plot
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [False, True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
-        + [[4, 48, 1024, 64, causal, mode]
-           for causal in [False, True]
-           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
-        line_arg='provider',
-        # argument name whose value corresponds to a different line in the plot
-        # possible values for `line_arg``
-        line_vals=['triton', 'xetla'],
-        # label name for the lines
-        line_names=['Triton', 'XeTLA'],
-        # line styles
-        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
-        ylabel=['GB/s', 'TFlops'], # label name for the y-axis
-        plot_name='attn-performance',
-        # name for the plot. Used also as a file name for saving the plot.
-        args={},
-    ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
-    assert MODE in ['fwd', 'bwd']
-    dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    sm_scale = 0.125
-    if MODE == 'bwd':
-        sm_scale = 1.3
-    quantiles = [0.5, 0.0, 1.0]
-    atol = 1e-1 if N_CTX == 16384 else 1e-2
-    # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
-    torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-    ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-    if MODE == 'bwd':
-        torch_o = torch_fn()
-        torch_do = torch.randn_like(torch_o)
-        torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-    if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
-
-    elif provider == 'triton':
-        triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale)
+def get_benchmark(
+    providers_filter: Optional[list[str]] = None,
+    fa_kernel_mode='fwd',
+    xetla_assert_result=False,
+    xetla_warn_mismatch=True,
+):
+    """
+    Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
+    The benchmark can then be executed by calling the :code:`.run` method on the return value.
+    """
+
+    supported_providers = {
+        'triton': 'Triton',
+        'xetla': 'XeTLA',
+    }
+    providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
+
+    @benchmark_suite.perf_report(
+        benchmark_suite.Benchmark(
+            # argument names to use as an x-axis for the plot
+            x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
+            x_vals=[[z, h, 16384 // z, dhead, causal, mode]
+                    for z in [1, 2, 4, 8, 16, 32]
+                    for (h, dhead) in [(16, 128), (32, 64)]
+                    for causal in [False, True]
+                    for mode in [fa_kernel_mode]] #
+            + [[4, 48, 1024, 64, causal, mode] for causal in [False, True] for mode in [fa_kernel_mode]],
+            line_arg='provider',
+            # argument name whose value corresponds to a different line in the plot
+            # possible values for `line_arg``
+            line_vals=list(providers.keys()),
+            # label name for the lines
+            line_names=list(providers.values()),
+            # line styles
+            styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+            ylabel=['GB/s', 'TFlops'], # label name for the y-axis
+            plot_name='attn-performance',
+            # name for the plot. Used also as a file name for saving the plot.
+            args={},
+        ))
+    # pylint: disable=too-many-branches
+    def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
+        modes = ['fwd', 'bwd']
+        if MODE not in modes:
+            raise AssertionError(f'Unknown {MODE}, supported modes are {modes}')
+        dtype = torch.float16
+        q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+        k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+        v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+        sm_scale = 0.125
         if MODE == 'bwd':
-            triton_o = triton_fn()
-            triton_do = torch.randn_like(triton_o)
-            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        if MODE == 'fwd':
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
-
-    elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-
-            def xetla_fwd_fn():
-                func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-                return out
-
-            xetla_fn = xetla_fwd_fn
-
-            def check_xetla_fwd_result():
-                if os.getenv('XETLA_ASSERT_RESULT', '0') == '1':
-                    benchmark_suit.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
-                elif os.getenv('XETLA_WARN_MISMATCH', '1') == '1':
-                    check_close(xetla_fn, torch_fn, atol, 1e-3)
-
-            check_xetla_fwd_result()
-
+            sm_scale = 1.3
+        quantiles = [0.5, 0.0, 1.0]
+        atol = 1e-1 if N_CTX == 16384 else 1e-2
+        # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
+        torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
+        ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         if MODE == 'bwd':
-            module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            def xetla_bwd_fn():
-                func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, dropout_prob,
-                     grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, N_CTX, bias_strideB,
-                     bias_strideN, bias_strideF, attn_mask_padding)
-                return out
-
-            xetla_fn = xetla_bwd_fn
-
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+            torch_o = torch_fn()
+            torch_do = torch.randn_like(torch_o)
+            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
+
+        if provider == 'onednn':
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                torch_fn,
+                n_warmup=10,
+                n_repeat=10,
+                quantiles=quantiles,
+            )

-    else:
-        raise NotImplementedError(f'Unsupported provider {provider}')
+        elif provider == 'triton':
+            triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale)
+            if MODE == 'bwd':
+                triton_o = triton_fn()
+                triton_do = torch.randn_like(triton_o)
+                triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
+            if MODE == 'fwd':
+                benchmark_suite.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
+            else:
+                benchmark_suite.assert_close(
+                    lambda: triton_o,
+                    lambda: torch_o,
+                    atol=1e-2,
+                    rtol=0,
+                    err_msg='triton to torch',
+                )
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                triton_fn,
+                n_warmup=10,
+                n_repeat=10,
+                quantiles=quantiles,
+            )
+
+        elif provider == 'xetla':
+            xetla_fn = None
+            if MODE == 'fwd':
+                module_name = f'flash_attn_causal_{CAUSAL}'.lower()
+                func = getattr(xetla_kernel, module_name)
+                out = torch.empty_like(q, device='xpu', dtype=dtype)
+                size_score = Z * H * N_CTX * N_CTX
+                size_attn_mask = Z * N_CTX * N_CTX
+                dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
+                bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
+                size_ml = Z * H * N_CTX
+                m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+                l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+
+                def xetla_fwd_fn():
+                    func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+                    return out
+
+                xetla_fn = xetla_fwd_fn
+
+                def check_xetla_fwd_result():
+                    if xetla_assert_result:
+                        benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
+                    elif xetla_warn_mismatch:
+                        check_close(xetla_fn, torch_fn, atol, 1e-3)
+
+                check_xetla_fwd_result()
+
+            if MODE == 'bwd':
+                module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
+                func = getattr(xetla_kernel, module_name)
+                grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
+                out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+                alpha = sm_scale
+                dropout_prob = 0
+                grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+                grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
+                grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
+                grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
+                bias_strideB = -1
+                bias_strideN = -1
+                bias_strideF = -1
+                attn_mask_padding = 0
+
+                def xetla_bwd_fn():
+                    func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, dropout_prob,
+                         grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, N_CTX, bias_strideB,
+                         bias_strideN, bias_strideF, attn_mask_padding)
+                    return out
+
+                xetla_fn = xetla_bwd_fn
+
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                xetla_fn,
+                n_warmup=10,
+                n_repeat=10,
+                quantiles=quantiles,
+            )
+
+        else:
+            raise NotImplementedError(f'Unsupported provider {provider}')

-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)

-    if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        if MODE == 'bwd':
+            tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
+            gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+
+        return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv

-    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+    return benchmark


 if __name__ == '__main__':
-    benchmark.run(show_plots=False, print_data=True)
+    _benchmark = get_benchmark(
+        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
+        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
+        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '1') == '1'),
+    )
+    _benchmark.run(show_plots=False, print_data=True)
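The `__main__` block above keeps the previous environment-variable behaviour but leaves `providers_filter` at its default, so every supported provider is benchmarked. A hypothetical sketch of how a future entry point could restrict the run to a single provider (not part of this commit):

    from triton_kernels_benchmark.flash_attention_benchmark import get_benchmark

    # Build the benchmark at runtime with only the Triton provider enabled,
    # then execute it the same way the __main__ block does.
    get_benchmark(providers_filter=['triton'], fa_kernel_mode='fwd').run(show_plots=False, print_data=True)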
