
Commit 4415428

Remove workaround for 'torch.nn.functional.scaled_dot_product_attention'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 81b0627 commit 4415428

File tree: 1 file changed (+2 −7 lines)


benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py

Lines changed: 2 additions & 7 deletions
@@ -256,13 +256,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
 
     elif provider == 'triton':
         triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale)
-        if benchmark_suit.USE_IPEX_OPTION:
-            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-        else:
-            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
-            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-            ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
+        torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
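For context, the check that remains compares the Triton kernel's output against torch's built-in scaled dot-product attention as the reference. Below is a minimal, self-contained sketch of that comparison pattern, not the benchmark itself: the shapes are illustrative, `forward` is a stand-in for the Triton flash-attention kernel (here it just reuses the reference so the snippet runs anywhere), and `torch.testing.assert_close` substitutes for the suite's `benchmark_suit.assert_close` helper. Note that the `scale` keyword of scaled_dot_product_attention requires PyTorch 2.1 or newer.

```python
import torch

# Illustrative shapes only; the benchmark sweeps much larger (Z, H, N_CTX, D_HEAD).
Z, H, N_CTX, D_HEAD = 1, 2, 128, 64
CAUSAL = True
sm_scale = D_HEAD ** -0.5

# float32 on CPU keeps the sketch runnable anywhere; the benchmark itself
# uses half-precision tensors on the XPU device.
q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD) for _ in range(3))

# Reference path, as retained by this commit: torch's built-in SDPA,
# upcast to float32 for the comparison (a no-op here with fp32 inputs).
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)


# Stand-in for the Triton flash-attention kernel under test; it reuses the
# reference implementation so the sketch stays self-contained.
def forward(q, k, v, causal, scale):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=causal, scale=scale)


triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale).to(torch.float32)

# Tolerance policy copied from the benchmark: looser atol at very long context.
atol = 1e-1 if N_CTX == 16384 else 1e-2
torch.testing.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3)
print('triton output matches torch reference within tolerance')
```

With the IPEX-specific branch gone, the same reference now runs directly on the benchmark tensors rather than round-tripping q, k, and v through the CPU as the old workaround did.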
