@@ -561,25 +561,22 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
561561 if MODE == 'bwd' :
562562 sm_scale = 1.3
563563 quantiles = [0.5 , 0.0 , 1.0 ]
564+ # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
565+ torch_fn = lambda : torch .nn .functional .scaled_dot_product_attention (q .cpu (), k .cpu (), v .cpu (
566+ ), attn_mask = None , dropout_p = 0.0 , is_causal = CAUSAL , scale = sm_scale ).to (torch .float32 )
567+ if MODE == 'bwd' :
568+ torch_o = torch_fn ()
569+ torch_do = torch .randn_like (torch_o )
570+ torch_fn = lambda : torch_o .backward (torch_do , retain_graph = True )
564571 if provider == 'onednn' :
565- _ , min_ms , max_ms , mean , cv = benchmark_suit .do_bench (
566- lambda : torch .nn .functional .scaled_dot_product_attention (q , k , v , attn_mask = None , dropout_p = 0.0 , is_causal =
567- CAUSAL , scale = sm_scale ), n_warmup = 10 , n_repeat = 10 ,
568- quantiles = quantiles )
572+ _ , min_ms , max_ms , mean , cv = benchmark_suit .do_bench (torch_fn , n_warmup = 10 , n_repeat = 10 , quantiles = quantiles )
569573
570574 elif provider == 'triton' :
571575 triton_fn = lambda : attention (q , k , v , CAUSAL , sm_scale )
572576 if MODE == 'bwd' :
573577 triton_o = triton_fn ()
574578 triton_do = torch .randn_like (triton_o )
575579 triton_fn = lambda : triton_o .backward (triton_do , retain_graph = True )
576- # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
577- torch_fn = lambda : torch .nn .functional .scaled_dot_product_attention (q .cpu (), k .cpu (), v .cpu (
578- ), attn_mask = None , dropout_p = 0.0 , is_causal = CAUSAL , scale = sm_scale ).to (torch .float32 )
579- if MODE == 'bwd' :
580- torch_o = torch_fn ()
581- torch_do = torch .randn_like (torch_o )
582- torch_fn = lambda : torch_o .backward (torch_do , retain_graph = True )
583580 if MODE == 'fwd' :
584581 atol = 1e-1 if N_CTX == 16384 else 1e-2
585582 benchmark_suit .assert_close (triton_fn , torch_fn , atol = atol , rtol = 1e-3 , err_msg = 'triton to torch' )
0 commit comments