Commit 6cbdd42

Merge branch 'main' into fix/try-disabling-fp64-patch
2 parents aca5140 + 752de28 commit 6cbdd42

File tree

1 file changed (+38 -5 lines)

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 38 additions & 5 deletions
@@ -8,6 +8,7 @@
 
 import triton_kernels_benchmark as benchmark_suit
 from triton_kernels_benchmark import xetla_kernel
+import numpy as np
 
 
 # pylint: disable=unused-argument
@@ -526,6 +527,19 @@ def backward(ctx, do):
 attention = _attention.apply
 
 
+def check_close(f_val, f_ref, atol, rtol):
+    x = f_val()
+    y = f_ref()
+    x = x.cpu().detach().numpy()
+    y = y.cpu().detach().numpy()
+    close = np.isclose(x, y, atol=atol, rtol=rtol)
+    num_close = np.count_nonzero(close)
+    num_not_close = close.size - num_close
+    num_perc = num_not_close / close.size * 100
+    if num_not_close != 0:
+        print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')
+
+
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
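As an aside, the comparison the new check_close helper performs follows NumPy's tolerance rule: np.isclose marks a pair as matching when |x - y| <= atol + rtol * |y|, and the helper reports the share of elements that fail. A standalone sketch of that behaviour, using made-up toy arrays that are not part of this commit:

import numpy as np

x = np.array([1.000, 2.000, 3.000, 4.500])
y = np.array([1.001, 2.000, 3.200, 4.500])
close = np.isclose(x, y, atol=1e-2, rtol=1e-3)           # element-wise match mask
num_not_close = close.size - np.count_nonzero(close)     # count of mismatches
print(f'{num_not_close} of {close.size} elements do not match '
      f'({num_not_close / close.size * 100:.2f}%)')       # -> 1 of 4 elements do not match (25.00%)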
@@ -561,6 +575,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
     if MODE == 'bwd':
         sm_scale = 1.3
     quantiles = [0.5, 0.0, 1.0]
+    atol = 1e-1 if N_CTX == 16384 else 1e-2
     # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
     torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
     ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
@@ -578,7 +593,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
         if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
             benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
         else:
             benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
@@ -597,7 +611,21 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
         size_ml = Z * H * N_CTX
         m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-        xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+
+        def xetla_fwd_fn():
+            func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+            return out
+
+        xetla_fn = xetla_fwd_fn
+
+        def check_xetla_fwd_result():
+            if os.getenv('XETLA_ASSERT_RESULT', '0') == '1':
+                benchmark_suit.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
+            elif os.getenv('XETLA_WARN_MISMATCH', '1') == '1':
+                check_close(xetla_fn, torch_fn, atol, 1e-3)
+
+        check_xetla_fwd_result()
+
         if MODE == 'bwd':
             module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
             func = getattr(xetla_kernel, module_name)
@@ -619,9 +647,14 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             bias_strideF = -1
             attn_mask_padding = 0
 
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
+            def xetla_bwd_fn():
+                func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, dropout_prob,
+                     grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, N_CTX, bias_strideB,
+                     bias_strideN, bias_strideF, attn_mask_padding)
+                return out
+
+            xetla_fn = xetla_bwd_fn
+
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
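With this change, the XeTLA forward result check is gated by two environment variables: XETLA_ASSERT_RESULT (off by default) makes mismatches fatal via benchmark_suit.assert_close, while XETLA_WARN_MISMATCH (on by default) only prints the mismatch count and percentage via check_close. A minimal sketch of toggling them before running the benchmark; the variable names and defaults come from the diff, but how the script is launched in your setup is an assumption here:

import os

# Fail hard on any XeTLA-vs-torch mismatch (disabled by default in this commit).
os.environ['XETLA_ASSERT_RESULT'] = '1'
# Or keep the default behaviour and only print a warning for mismatching elements.
os.environ.setdefault('XETLA_WARN_MISMATCH', '1')

# ...then run benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
# however your environment normally invokes it (direct python invocation assumed).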
