 import os
 from torch.nn.attention.flex_attention import (
     create_block_mask,
+    create_mask,
     flex_attention,
 )
 
 import torch
 import torch.nn.functional as F
+
 import triton_kernels_benchmark as benchmark_suit
-from triton_kernels_benchmark import xetla_kernel
 
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
 
 # Compile the flex_attention function
-flex_attention = torch.compile(flex_attention, dynamic=False)
+compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
 
 
 @lru_cache
@@ -27,112 +28,127 @@ def causal_mask(_, __, q_idx, kv_idx):
     return q_idx >= kv_idx
 
 
+throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
+batch_sizes = [16, 32, 64] if throughput_test else [1]
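+# With THROUGHPUT_TEST=1 every shape below is swept over batch sizes 16/32/64; otherwise a single batch (Z=1) is used.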
+
+
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-               + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-               + [[z, h, 1024, dhead, True, mode]
-                  for z in [1, 2, 4, 8, 16, 32, 64]
-                  for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-                  for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        # Prefill shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-v3
+        [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+
+        # Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072.
+            # [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Append shapes of Deepseek-v3 (Rope)
+        [] +
+
+        # Grouped-query attention. H_q / H_kv > 1
+        # Prefill shapes of Llama-3.1-8B
+        [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Qwen2-7B
+        [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Llama-3.1-8B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Qwen2-7B
+        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        # Decode shapes of Llama-3.1-8B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-3.8B
+        [
+            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+            # ValueError: Shape element 2 must be a power of 2
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Qwen2-7B
+        [
+            # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
+            # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072.
+            # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Rope)
+        [],
         line_arg='provider',
-        line_vals=['triton', 'xetla'],
-        line_names=['Triton', 'XeTLA'],
+        line_vals=['triton'],
+        line_names=['Triton'],
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],
         plot_name='flexAttnCausal-performance',
         args={},
     ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
-    assert MODE in ['fwd', 'bwd']
-    assert CAUSAL
+def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    assert MODE in ['fwd']
     dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
     if MODE == 'bwd':
         sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
-        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                           )
+        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
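+        # enable_gqa lets flex_attention broadcast the H_kv key/value heads over the H_q query heads when they differ;
+        # the uncompiled flex_attention call below is kept as the eager reference for assert_close.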
+        triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
+            not H_q == H_kv), kernel_options=kernel_options)
+        torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-            torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-        if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
-    elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = 'flash_attn_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        if MODE == 'bwd':
-            module_name = 'flash_attn_bwd_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+    elif provider == 'onednn':
+        # OneDNN only supports MHA.
+        if H_q == H_kv:
+            mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
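+            # Dense boolean mask built from the same causal_mask mod, so eager SDPA applies identical masking.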
+            xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+            if MODE == 'bwd':
+                xformers_o = xformers_fn()
+                xformers_do = torch.randn_like(xformers_o)
+                xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+        else:
+            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
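+    # Forward-pass performance model: FLOPs for the two batched matmuls (Q@K^T and P@V),
+    # bytes for reading Q, K and V once as float16.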
+    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # mul + add
+    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2  # mul + add
+    tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+
+    q_elems = H_q * N_CTX_q * D_HEAD_qk
+    k_elems = H_kv * N_CTX_kv * D_HEAD_qk
+    v_elems = H_kv * N_CTX_kv * D_HEAD_v
+    gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes
 
     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
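+        # Assumption: the 2.5x factor follows the usual flash-attention estimate of backward cost (recompute plus dQ/dK/dV).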
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
+                                                                                                              )
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
 