Commit 24d985b

Use cpu version of torch sdpa until xpu version is fixed (#2300)
CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/10958477930 (passed)

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 74dc4a2 commit 24d985b

File tree

1 file changed: +6 -3 lines


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 6 additions & 3 deletions
@@ -226,11 +226,14 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
     elif provider == 'triton':
         triton_fn = lambda: forward(q, k, v, causal, sm_scale)
         if benchmark_suit.USE_IPEX_OPTION:
-            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
             torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+        else:
+            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
+            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
+            ), attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
+        atol = 1e-1 if N_CTX == 16384 else 1e-2
+        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
                                                               fast_flush=False)
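For context, a minimal standalone sketch of the CPU-reference pattern this commit switches to: compute the reference with torch's scaled_dot_product_attention on CPU, upcast to float32, and compare with the benchmark's tolerances. This is an illustration with plain torch only, not the repo's code: benchmark_suit.assert_close and the Triton forward() kernel are not available here, so torch.testing.assert_close and a CPU stand-in for the kernel output are assumptions, and the shapes, dtype, and sm_scale value are illustrative.

import torch

# Illustrative sizes; the benchmark sweeps these (atol loosens at N_CTX == 16384).
Z, H, N_CTX, D_HEAD = 1, 2, 1024, 64
sm_scale = D_HEAD ** -0.5  # common softmax scaling; an assumption, not taken from the diff

q = torch.randn(Z, H, N_CTX, D_HEAD, dtype=torch.float32)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Stand-in for the Triton kernel output (forward(q, k, v, causal, sm_scale) in the repo).
out = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale)

# Reference computed on CPU and upcast to float32 -- the pattern the patched else-branch uses.
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
    q.cpu(), k.cpu(), v.cpu(), attn_mask=None, dropout_p=0.0, is_causal=False,
    scale=sm_scale).to(torch.float32)

# Same tolerances as the benchmark; torch.testing.assert_close stands in for
# benchmark_suit.assert_close.
atol = 1e-1 if N_CTX == 16384 else 1e-2
torch.testing.assert_close(out.cpu().to(torch.float32), torch_fn(), atol=atol, rtol=1e-3)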
