@@ -226,11 +226,14 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):
     elif provider == 'triton':
         triton_fn = lambda: forward(q, k, v, causal, sm_scale)
         if benchmark_suit.USE_IPEX_OPTION:
-            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
             torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
+        else:
+            # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
+            torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
+            ), attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
+        atol = 1e-1 if N_CTX == 16384 else 1e-2
+        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
                                                               fast_flush=False)
 
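For context, the hunk restructures the accuracy check so it runs on both paths: with IPEX the reference SDPA runs on the device tensors, otherwise it falls back to CPU (per the FIXME referencing issue #2042), and the absolute tolerance is loosened at the longest context. A minimal, self-contained sketch of that check pattern follows, assuming a recent torch (for the `scale=` argument of `scaled_dot_product_attention`) and substituting `torch.testing.assert_close` for the repo's `benchmark_suit.assert_close` helper; the fp16 round-trip below is only a stand-in for the Triton `forward()` kernel under test.

```python
import torch

torch.manual_seed(0)
Z, H, N_CTX, D_HEAD = 1, 2, 1024, 64
q, k, v = (torch.randn(Z, H, N_CTX, D_HEAD) for _ in range(3))
sm_scale = D_HEAD**-0.5

def sdpa(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale)

# Hypothetical stand-in for the Triton kernel: round the inputs through
# float16 first so the tolerance comparison below is non-trivial.
out = sdpa(q.half().float(), k.half().float(), v.half().float()).to(torch.float32)

# float32 CPU reference, mirroring the else branch of the hunk.
ref = sdpa(q, k, v).to(torch.float32)

# Loosen the absolute tolerance at very long contexts, as the benchmark does.
atol = 1e-1 if N_CTX == 16384 else 1e-2
torch.testing.assert_close(out, ref, atol=atol, rtol=1e-3, msg='triton to torch')
```

The CPU fallback in the else branch trades speed for a usable reference while the non-IPEX device SDPA path is unreliable; the looser `atol` at `N_CTX == 16384` presumably absorbs the accumulation error that grows with context length.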