 import triton_kernels_benchmark as benchmark_suite
 from triton_kernels_benchmark import xetla_kernel
 from triton_kernels_benchmark import cutlass_kernel
-import numpy as np


 # pylint: disable=unused-argument
@@ -529,25 +528,10 @@ def backward(ctx, do):
 attention = _attention.apply


-def check_close(f_val, f_ref, atol, rtol):
-    x = f_val()
-    y = f_ref()
-    x = x.cpu().detach().numpy()
-    y = y.cpu().detach().numpy()
-    close = np.isclose(x, y, atol=atol, rtol=rtol)
-    num_close = np.count_nonzero(close)
-    num_not_close = close.size - num_close
-    num_perc = num_not_close / close.size * 100
-    if num_not_close != 0:
-        print(f'Warning: {num_not_close}, out of {close.size} elements do not match ({num_perc:.2f}%) in XeTLA impl')
-
-
 def get_benchmark(
     providers_filter: Optional[list[str]] = None,
     fa_kernel_mode='fwd',
     attn_fwd=_attn_fwd_with_block_pointers,
-    xetla_assert_result=False,
-    xetla_warn_mismatch=False,
 ):
     """
     Returns a Mark object containing a Benchmark object constructed at runtime and parameterized by the provided option values.
@@ -647,33 +631,6 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
         )

     elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = f'flash_attn_causal_{CAUSAL}'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-
-            def xetla_fwd_fn():
-                func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-                return out
-
-            xetla_fn = xetla_fwd_fn
-
-            def check_xetla_fwd_result():
-                if xetla_assert_result:
-                    benchmark_suite.assert_close(xetla_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='xetla to torch')
-                elif xetla_warn_mismatch:
-                    check_close(xetla_fn, torch_fn, atol, 1e-3)
-
-            check_xetla_fwd_result()
-
         if MODE == 'bwd':
             module_name = f'flash_attn_bwd_causal_{CAUSAL}'.lower()
             func = getattr(xetla_kernel, module_name)
@@ -701,18 +658,20 @@ def xetla_bwd_fn():
                      bias_strideN, bias_strideF, attn_mask_padding)
                 return out

-            xetla_fn = xetla_bwd_fn
+            _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
+                xetla_bwd_fn,
+                n_warmup=10,
+                n_repeat=10,
+                quantiles=quantiles,
+            )

-        _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-            xetla_fn,
-            n_warmup=10,
-            n_repeat=10,
-            quantiles=quantiles,
-        )
+        else:
+            min_ms = float('nan')
+            max_ms = float('nan')
+            mean = float('nan')
+            cv = float('nan')

     elif provider == 'cutlass':
-        cutlass_fn = None
-
         if MODE == 'fwd':
             name = 'attention'
             func = getattr(cutlass_kernel, name)
@@ -723,17 +682,15 @@ def cutlass_fwd_fn():
                 return out

             benchmark_suite.assert_close(cutlass_fwd_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='cutlass to torch')
-            cutlass_fn = cutlass_fwd_fn

             _, min_ms, max_ms, mean, cv = benchmark_suite.do_bench(
-                cutlass_fn,
+                cutlass_fwd_fn,
                 n_warmup=10,
                 n_repeat=10,
                 quantiles=quantiles,
             )

         else:
-            cutlass_fn = None
             min_ms = float('nan')
             max_ms = float('nan')
             mean = float('nan')
@@ -755,9 +712,5 @@ def cutlass_fwd_fn():
 if __name__ == '__main__':
-    _benchmark = get_benchmark(
-        fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'),
-        xetla_assert_result=(os.getenv('XETLA_ASSERT_RESULT', '0') == '1'),
-        xetla_warn_mismatch=(os.getenv('XETLA_WARN_MISMATCH', '0') == '1'),
-    )
+    _benchmark = get_benchmark(fa_kernel_mode=os.getenv('FA_KERNEL_MODE', 'fwd'), )
     _benchmark.run(show_plots=False, print_data=True)