 import os
 from torch.nn.attention.flex_attention import (
     create_block_mask,
-    create_mask,
     flex_attention,
 )
 
 import torch
-import torch.nn.functional as F
 import torch._inductor
 import torch._inductor.lowering
 import torch._inductor.kernel
@@ -74,6 +72,7 @@ def causal_mask(_, __, q_idx, kv_idx):
 throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
 batch_size = int(os.getenv('BATCH_SIZE', '1'))
 batch_sizes = [16, 32, 64] if throughput_test else [batch_size]
+fa_kernel_mode = os.getenv('FA_KERNEL_MODE', 'fwd')
 
 
 # Kernel profiling for Backward mode is not working as expected:
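For reference, the environment variables above select what gets benchmarked, and they are read at import time, so they have to be set before the module is imported. A minimal usage sketch (the module name and the perf_report-style run() entry point are assumptions, not shown in this diff):

    import os
    os.environ['FA_KERNEL_MODE'] = 'bwd'   # benchmark the backward kernel; default is 'fwd'
    os.environ['THROUGHPUT_TEST'] = '1'    # sweep batch sizes 16/32/64 instead of a single batch size
    os.environ['BATCH_SIZE'] = '4'         # only used when THROUGHPUT_TEST is '0'
    import flex_attention_benchmark_causal_mask as bench  # hypothetical module name
    bench.benchmark.run(show_plots=False, print_data=True)  # assumed perf_report-style runner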
@@ -84,48 +83,48 @@ def causal_mask(_, __, q_idx, kv_idx):
         x_vals=
         # Multi-head attention. H_q equals H_kv
         # Prefill shapes of Phi3-mini-3.8B
-        [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
         # Prefill shapes of Deepseek-v3
-        [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Phi3-mini-3.8B
-        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
 
         # Multi-query attention. H_kv equals 1.
         # Append shapes of Deepseek-v3 (Nope)
-        [[z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes] +
+        [[z, 128, 1, 512, 1024 + 128 + 512, 64, 512, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Deepseek-v3 (Rope)
         [] +
 
         # Grouped-query attention. H_q / H_kv > 1
         # Prefill shapes of Llama-3.1-8B
-        [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Prefill shapes of Qwen2-7B
-        [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 28, 4, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Llama-3.1-8B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Qwen2-7B
-        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
 
         # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
         # Decode shapes of Llama-3.1-8B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Decode shapes of Phi3-mini-3.8B
         [
             # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
             # ValueError: Shape element 2 must be a power of 2
-            # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
         ] +
         # Decode shapes of Qwen2-7B
         [
             # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
-            # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+            # [z, 28, 4, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes
         ] +
         # Decode shapes of Deepseek-v3 (Nope)
         [
             # There is a known issue in IGC for kernels with extreme register pressure.
             # Enable this case later with new IGC.
             # RuntimeError: ZE_RESULT_ERROR_INVALID_KERNEL_NAME
-            # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+            # [z, 128, 1, 1, 1024, 64, 512, fa_kernel_mode] for z in batch_sizes
         ] +
         # Decode shapes of Deepseek-v3 (Rope)
         [],
@@ -138,52 +137,55 @@ def causal_mask(_, __, q_idx, kv_idx):
         args={},
     ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
-    assert MODE in ['fwd']
+    if MODE not in ('fwd', 'bwd'):
+        raise ValueError(f"Invalid MODE: {MODE}. Expected 'fwd' or 'bwd'.")
     dtype = torch.float16
     q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
-    if MODE == 'bwd':
-        sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
     block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=DEVICE)
     torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
 
     if provider == 'torch':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              device=DEVICE)
+        if MODE == 'bwd':
+            min_ms = float('nan')
+            max_ms = float('nan')
+            mean = float('nan')
+            cv = float('nan')
+        else:
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles, device=DEVICE)
 
     elif provider == 'triton':
         kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True, 'USE_TMA': True}
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
         if MODE == 'bwd':
+            torch_o = torch_fn()
+            backwards_grad = torch.randn_like(torch_o)
+            torch_grads = torch.autograd.grad((torch_o, ), (q, k, v), backwards_grad, retain_graph=True)
+            eager_tensors = (torch_o, *torch_grads)
             triton_o = triton_fn()
-            triton_do = torch.randn_like(triton_o)
-            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
+            triton_grads = torch.autograd.grad((triton_o, ), (q, k, v), backwards_grad, retain_graph=True)
+            compiled_tensors = (triton_o, *triton_grads)
 
-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
+            tensor_names = ['out', 'grad_query', 'grad_key', 'grad_value']
+            for eager, compiled, name in zip(eager_tensors, compiled_tensors, tensor_names):
+                benchmark_suit.assert_close(lambda: eager, lambda: compiled, atol=1e-2, rtol=1e-3,  # pylint: disable=cell-var-from-loop
+                                            err_msg=f'Error comparing {name} between triton and torch')
+
+            triton_fn = lambda: torch.autograd.grad((triton_o, ), (q, k, v), backwards_grad, retain_graph=True)
+        else:
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
 
         # Needs more warmup on B580 for some reason
         benchmark_suit.do_prewarmup(triton_fn)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles,
-                                                              device=DEVICE)
-
-    elif provider == 'onednn':
-        # OneDNN only supports MHA.
-        if H_q == H_kv:
-            mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
-            xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-            if MODE == 'bwd':
-                xformers_o = xformers_fn()
-                xformers_do = torch.randn_like(xformers_o)
-                xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
-            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
-                                                                  quantiles=quantiles)
-        else:
-            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
+            triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles, device=DEVICE, grad_to_none=(q, k, v),
+            benchmark_label=None if MODE == 'fwd' else 'CompiledFunctionBackward')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
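The new backward path above validates the compiled flex_attention kernel by comparing the forward output and the q/k/v gradients against eager mode before timing only the backward call. A self-contained sketch of that comparison technique, outside the benchmark harness (shapes, device, and tolerances here are illustrative assumptions):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    device = 'cuda'  # the benchmark uses its own DEVICE (an XPU); 'cuda' here is only for illustration
    q, k, v = (torch.randn(1, 2, 128, 64, device=device, dtype=torch.float16, requires_grad=True)
               for _ in range(3))
    compiled = torch.compile(flex_attention)

    eager_out = flex_attention(q, k, v)
    compiled_out = compiled(q, k, v)
    grad_out = torch.randn_like(eager_out)
    # torch.autograd.grad returns the gradients instead of accumulating into .grad,
    # so the same leaf tensors can be differentiated twice without zeroing anything.
    eager_grads = torch.autograd.grad((eager_out, ), (q, k, v), grad_out, retain_graph=True)
    compiled_grads = torch.autograd.grad((compiled_out, ), (q, k, v), grad_out, retain_graph=True)

    for name, e, c in zip(('out', 'grad_query', 'grad_key', 'grad_value'),
                          (eager_out, *eager_grads), (compiled_out, *compiled_grads)):
        torch.testing.assert_close(e, c, atol=1e-2, rtol=1e-3, msg=f'{name} mismatch between compiled and eager')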
@@ -198,9 +200,9 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
     gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes
 
     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
-                                                                                                             )
+        # The tflops and gbps estimates are aligned with those in flash_attention_benchmark.
+        tflops = lambda mean: 2.5 * Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
 
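On the new bwd estimates: the 2.5 factor is the usual FlashAttention backward-FLOPs convention, since the attention backward pass performs roughly five attention-sized matmuls (score recomputation plus dV, dP, dQ, dK) versus two in the forward pass. A quick numeric sanity check for the Llama-3.1-8B prefill shape; qk_flops and pv_flops are defined earlier in the file outside this diff, so the formulas below are assumptions:

    # Llama-3.1-8B prefill row from x_vals: Z=1, H_q=32, N_CTX_q=N_CTX_kv=1024, D_HEAD_qk=D_HEAD_v=128
    Z, H_q, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v = 1, 32, 1024, 1024, 128, 128
    qk_flops = 2 * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk  # Q @ K^T (assumed definition)
    pv_flops = 2 * H_q * N_CTX_q * N_CTX_kv * D_HEAD_v   # P @ V   (assumed definition)
    bwd_flops = 2.5 * Z * (qk_flops + pv_flops)          # ~5 matmuls in bwd vs. 2 in fwd
    mean_ms = 1.0                                         # example measured kernel time
    print(bwd_flops * 1e-12 / (mean_ms * 1e-3), 'TFLOPS')  # ~42.9 TFLOPS for this example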