@@ -256,13 +256,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
 
     elif provider == 'triton':
         triton_fn = lambda: forward(q, k, v, CAUSAL, sm_scale)
-        if benchmark_suit.USE_IPEX_OPTION:
-            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
-        else:
-            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
-            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
-            ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
+        torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
         atol = 1e-1 if N_CTX == 16384 else 1e-2
         benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,