
Commit fed3b87

[FA] Fix XPU out of memory on oneDNN (#3773)
Use the same workaround introduced in 24d985b.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent: 7dd7963 · commit: fed3b87

File tree: 1 file changed, +8 −11 lines


benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 8 additions & 11 deletions
@@ -561,25 +561,22 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
     if MODE == 'bwd':
         sm_scale = 1.3
     quantiles = [0.5, 0.0, 1.0]
+    # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
+    torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
+    ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
+    if MODE == 'bwd':
+        torch_o = torch_fn()
+        torch_do = torch.randn_like(torch_o)
+        torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
     if provider == 'onednn':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
-            lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
-            CAUSAL, scale=sm_scale), n_warmup=10, n_repeat=10,
-            quantiles=quantiles)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     elif provider == 'triton':
         triton_fn = lambda: attention(q, k, v, CAUSAL, sm_scale)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
-        torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-        ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
         if MODE == 'fwd':
            atol = 1e-1 if N_CTX == 16384 else 1e-2
            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
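In effect, the patch hoists the CPU-based torch reference (torch_fn) above the provider branches, so the 'onednn' provider now times the same CPU scaled_dot_product_attention call instead of running SDPA directly on the XPU tensors, which is the allocation the commit title says ran out of memory. The sketch below illustrates only that pattern in isolation; it is not the benchmark code, and the tensor shapes, the device probe, and the final print are hypothetical stand-ins (it also assumes a reasonably recent PyTorch that accepts the scale keyword):

# Sketch of the workaround pattern: evaluate the reference SDPA on CPU copies
# of the inputs, so the reference/check path never allocates the full attention
# intermediates on the device. Shapes and the device probe are hypothetical.
import torch

Z, H, N_CTX, D_HEAD = 1, 2, 1024, 64      # hypothetical sizes
sm_scale = 1.3
device = 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() else 'cpu'

q = torch.randn(Z, H, N_CTX, D_HEAD, device=device)
k = torch.randn(Z, H, N_CTX, D_HEAD, device=device)
v = torch.randn(Z, H, N_CTX, D_HEAD, device=device)

# Mirrors the q.cpu(), k.cpu(), v.cpu() call in the diff: the reference runs on CPU.
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
    q.cpu(), k.cpu(), v.cpu(), attn_mask=None, dropout_p=0.0,
    is_causal=False, scale=sm_scale).to(torch.float32)

ref = torch_fn()   # safe to call both for timing and for result checks
print(ref.shape, ref.device)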
