@@ -8,6 +8,7 @@
 
 import triton_kernels_benchmark as benchmark_suit
 from triton_kernels_benchmark import xetla_kernel
+import numpy as np
 
 
 # pylint: disable=unused-argument
@@ -526,6 +527,19 @@ def backward(ctx, do):
 attention = _attention.apply
 
 
+def check_close(f_val, f_ref, atol, rtol):
+    x = f_val()
+    y = f_ref()
+    x = x.cpu().detach().numpy()
+    y = y.cpu().detach().numpy()
+    close = np.isclose(x, y, atol=atol, rtol=rtol)
+    num_close = np.count_nonzero(close)
+    num_not_close = close.size - num_close
+    num_perc = num_not_close / close.size * 100
+    if num_not_close != 0:
+        print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')
+
+
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
         # argument names to use as an x-axis for the plot
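The new `check_close` helper only warns about out-of-tolerance elements instead of failing the run. Below is a self-contained sketch of the same warn-on-mismatch pattern; the tensors, shapes, and tolerances are made up for illustration and are not part of the commit.

# Illustration only: the warn-on-mismatch pattern used by check_close,
# with dummy tensors standing in for the kernel outputs.
import numpy as np
import torch

def ref_fn():
    # stand-in for the torch reference output
    return torch.ones(4, 4)

def val_fn():
    # stand-in for the XeTLA output, with one element deliberately off
    out = torch.ones(4, 4)
    out[0, 0] = 1.5
    return out

x = val_fn().cpu().detach().numpy()
y = ref_fn().cpu().detach().numpy()
close = np.isclose(x, y, atol=1e-2, rtol=1e-3)
num_not_close = close.size - np.count_nonzero(close)
if num_not_close != 0:
    pct = num_not_close / close.size * 100
    print(f'Warning: {num_not_close} out of {close.size} elements do not match ({pct:.2f}%)')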
@@ -561,6 +575,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
     if MODE == 'bwd':
         sm_scale = 1.3
     quantiles = [0.5, 0.0, 1.0]
+    atol = 1e-1 if N_CTX == 16384 else 1e-2
     # FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
     torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(
     ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
@@ -578,7 +593,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
         if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
             benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
         else:
             benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
@@ -597,7 +611,21 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
         size_ml = Z * H * N_CTX
         m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
         l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-        xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+
+        def xetla_fwd_fn():
+            func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+            return out
+
+        xetla_fn = xetla_fwd_fn
+
+        def check_xetla_fwd_result():
+            if os.getenv('XETLA_ASSERT_RESULT', '0') == '1':
+                benchmark_suit.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
+            elif os.getenv('XETLA_WARN_MISMATCH', '1') == '1':
+                check_close(xetla_fn, torch_fn, atol, 1e-3)
+
+        check_xetla_fwd_result()
+
         if MODE == 'bwd':
             module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
             func = getattr(xetla_kernel, module_name)
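The diff itself does not spell out how the two switches read above are meant to be driven. As a sketch (assuming the benchmark is launched from an ordinary shell or Python session), exporting `XETLA_ASSERT_RESULT=1` escalates the forward-result check from a warning to a hard failure, while `XETLA_WARN_MISMATCH=0` silences it; the defaults below mirror the `getenv()` fallbacks in the added code.

# Illustration only: expected effect of the two environment switches read by
# check_xetla_fwd_result(); values set here are just for demonstration.
import os

os.environ['XETLA_ASSERT_RESULT'] = '0'   # '1' -> assert_close raises on mismatch
os.environ['XETLA_WARN_MISMATCH'] = '1'   # '1' (default) -> check_close only warns

if os.getenv('XETLA_ASSERT_RESULT', '0') == '1':
    print('forward XeTLA results are asserted against the torch reference')
elif os.getenv('XETLA_WARN_MISMATCH', '1') == '1':
    print('forward XeTLA mismatches are reported as warnings only')
else:
    print('forward XeTLA result checking is disabled')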
@@ -619,9 +647,14 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
             bias_strideF = -1
             attn_mask_padding = 0
 
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
+            def xetla_bwd_fn():
+                func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha, dropout_prob,
+                     grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX, N_CTX, bias_strideB,
+                     bias_strideN, bias_strideF, attn_mask_padding)
+                return out
+
+            xetla_fn = xetla_bwd_fn
+
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
     else:
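The backward path gets the same treatment as the forward one: the lambda that only launched the kernel is replaced by a named wrapper that also returns `out`, so a single callable can feed both the timing loop and any later result check. A reduced sketch of that pattern with a stand-in kernel follows; the names here are illustrative and not taken from the benchmark.

# Illustration only: a wrapper that launches a kernel writing into a
# preallocated buffer and returns that buffer for later validation.
import torch

def fake_kernel(q, k, v, out):
    # stand-in for the XeTLA kernel call
    out.copy_(q + k + v)

q, k, v = (torch.randn(8, 8) for _ in range(3))
out = torch.empty_like(q)

def xetla_like_fn():
    fake_kernel(q, k, v, out)
    return out

result = xetla_like_fn()   # usable for a correctness check
# the same callable can then be handed to a timing helper such as do_bench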