
Commit e93c6b8

[BENCHMARK] Add the flex attn backward case in micro-benchmark. (#5057)
Add the flex attn backward case in micro-benchmark.

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 80255de commit e93c6b8

2 files changed (+47, -43 lines)

benchmarks/triton_kernels_benchmark/benchmark_testing.py

Lines changed: 4 additions & 2 deletions
@@ -122,7 +122,7 @@ def do_bench_elapsed_time(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quan


 def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_none=None, quantiles=None,
-                                       return_mode="mean", device="xpu", sync_submitting=True):
+                                       return_mode="mean", device="xpu", sync_submitting=True, benchmark_label=None):
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -176,7 +176,9 @@ def do_bench_upstream_pytorch_profiler(fn, n_warmup=25, n_repeat=100, grad_to_no
             # Record clocks
             synchronize()

-    profiling_func_filter = filter(lambda x: x.name.startswith("__profile_kernel_of_func"), prof.events())
+    profiling_func_filter = filter(
+        lambda x: x.name.startswith("__profile_kernel_of_func" if benchmark_label is None else benchmark_label),
+        prof.events())
     functions = list(profiling_func_filter)

     def extract_kernels(funcs):
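
The new `benchmark_label` argument only changes which profiler events the helper aggregates: by default it keeps events recorded under the `__profile_kernel_of_func` region, while a backward benchmark can pass the autograd node name instead (the flex attention benchmark below passes `CompiledFunctionBackward`). A minimal sketch of that selection logic using the standard `torch.profiler` API; the helper name `select_events` and the toy workload are made up for illustration:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def select_events(events, benchmark_label=None):
    # Mirrors the filter above: default to the forward wrapper region,
    # otherwise keep events whose name starts with the given label.
    prefix = "__profile_kernel_of_func" if benchmark_label is None else benchmark_label
    return [e for e in events if e.name.startswith(prefix)]


with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("__profile_kernel_of_func"):
        torch.matmul(torch.randn(64, 64), torch.randn(64, 64))

fwd_events = select_events(prof.events())  # events under the default wrapper region
# In the real benchmark the compiled backward shows up under this autograd node
# name; in this toy example no backward ran, so the list is simply empty.
bwd_events = select_events(prof.events(), "CompiledFunctionBackward")
```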

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 43 additions & 41 deletions
@@ -3,12 +3,10 @@
 import os
 from torch.nn.attention.flex_attention import (
     create_block_mask,
-    create_mask,
     flex_attention,
 )

 import torch
-import torch.nn.functional as F
 import torch._inductor
 import torch._inductor.lowering
 import torch._inductor.kernel
@@ -74,6 +72,7 @@ def causal_mask(_, __, q_idx, kv_idx):
 throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
 batch_size = int(os.getenv('BATCH_SIZE', '1'))
 batch_sizes = [16, 32, 64] if throughput_test else [batch_size]
+fa_kernel_mode = os.getenv('FA_KERNEL_MODE', 'fwd')


 # Kernel profiling for Backward mode is not working as expected:
@@ -84,48 +83,48 @@ def causal_mask(_, __, q_idx, kv_idx):
         x_vals=
         # Multi-head attention. H_q equals H_kv
         # Prefill shapes of Phi3-mini-3.8B
-        [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
         # Prefill shapes of Deepseek-v3
-        [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Phi3-mini-3.8B
-        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +

         # Multi-query attention. H_kv equals 1.
         # Append shapes of Deepseek-v3 (Nope)
-        [[z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes] +
+        [[z, 128, 1, 512, 1024 + 128 + 512, 64, 512, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Deepseek-v3 (Rope)
         [] +

         # Grouped-query attention. H_q / H_kv > 1
         # Prefill shapes of Llama-3.1-8B
-        [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Prefill shapes of Qwen2-7B
-        [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 28, 4, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Llama-3.1-8B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Append shapes of Qwen2-7B
-        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +

         # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
         # Decode shapes of Llama-3.1-8B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
         # Decode shapes of Phi3-mini-3.8B
         [
             # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
             # ValueError: Shape element 2 must be a power of 2
-            # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
         ] +
         # Decode shapes of Qwen2-7B
         [
             # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
-            # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+            # [z, 28, 4, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes
         ] +
         # Decode shapes of Deepseek-v3 (Nope)
         [
             # There is an known issue in IGC for kernel with extreme register pressure.
             # Enable this case later with new IGC.
             # RuntimeError: ZE_RESULT_ERROR_INVALID_KERNEL_NAME
-            # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+            # [z, 128, 1, 1, 1024, 64, 512, fa_kernel_mode] for z in batch_sizes
         ] +
         # Decode shapes of Deepseek-v3 (Rope)
         [],
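
Every row of `x_vals` now carries `fa_kernel_mode` instead of a hard-coded `'fwd'`, so the same shape matrix serves both forward and backward runs, with the mode and batch sizes chosen purely through environment variables. A small standalone sketch of how that configuration resolves; the specific values set here are only for illustration:

```python
import os

# Example settings; in practice these are exported in the shell before running
# the benchmark script.
os.environ['FA_KERNEL_MODE'] = 'bwd'   # time the backward pass
os.environ['THROUGHPUT_TEST'] = '0'    # keep a single batch size
os.environ['BATCH_SIZE'] = '4'

# Same resolution logic as in the benchmark above.
throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
batch_size = int(os.getenv('BATCH_SIZE', '1'))
batch_sizes = [16, 32, 64] if throughput_test else [batch_size]
fa_kernel_mode = os.getenv('FA_KERNEL_MODE', 'fwd')

print(batch_sizes, fa_kernel_mode)  # -> [4] bwd
```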
@@ -138,52 +137,55 @@ def causal_mask(_, __, q_idx, kv_idx):
         args={},
     ))
 def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
-    assert MODE in ['fwd']
+    if MODE not in ('fwd', 'bwd'):
+        raise ValueError(f"Invalid MODE: {MODE}. Expected 'fwd' or 'bwd'.")
     dtype = torch.float16
     q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device=DEVICE, dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
-    if MODE == 'bwd':
-        sm_scale = 1.3

     quantiles = [0.5, 0.0, 1.0]
     block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=DEVICE)
     torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)

     if provider == 'torch':
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles,
-                                                              device=DEVICE)
+        if MODE == 'bwd':
+            min_ms = float('nan')
+            max_ms = float('nan')
+            mean = float('nan')
+            cv = float('nan')
+        else:
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles, device=DEVICE)

     elif provider == 'triton':
         kernel_options = {'BLOCKS_ARE_CONTIGUOUS': True, 'USE_TMA': True}
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
         if MODE == 'bwd':
+            torch_o = torch_fn()
+            backwards_grad = torch.randn_like(torch_o)
+            torch_grads = torch.autograd.grad((torch_o, ), (q, k, v), backwards_grad, retain_graph=True)
+            eager_tensors = (torch_o, *torch_grads)
             triton_o = triton_fn()
-            triton_do = torch.randn_like(triton_o)
-            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
+            triton_grads = torch.autograd.grad((triton_o, ), (q, k, v), backwards_grad, retain_graph=True)
+            compiled_tensors = (triton_o, *triton_grads)

-        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
+            tensor_names = ['out', 'grad_query', 'grad_key', 'grad_value']
+            for eager, compiled, name in zip(eager_tensors, compiled_tensors, tensor_names):
+                benchmark_suit.assert_close(lambda: eager, lambda: compiled, atol=1e-2, rtol=1e-3, # pylint: disable=cell-var-from-loop
+                                            err_msg=f'Error comparing {name} between triton and torch')
+
+            triton_fn = lambda: torch.autograd.grad((triton_o, ), (q, k, v), backwards_grad, retain_graph=True)
+        else:
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')

         # Needs more warmup on B580 for some reason
         benchmark_suit.do_prewarmup(triton_fn)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles,
-                                                              device=DEVICE)
-
-    elif provider == 'onednn':
-        # OneDNN only supports MHA.
-        if H_q == H_kv:
-            mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
-            xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
-            if MODE == 'bwd':
-                xformers_o = xformers_fn()
-                xformers_do = torch.randn_like(xformers_o)
-                xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
-            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
-                                                                  quantiles=quantiles)
-        else:
-            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
+            triton_fn, n_warmup=200, n_repeat=10, quantiles=quantiles, device=DEVICE, grad_to_none=(q, k, v),
+            benchmark_label=None if MODE == 'fwd' else 'CompiledFunctionBackward')

     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
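
For the backward case the benchmark builds the forward graph once and then times only `torch.autograd.grad` against it: `retain_graph=True` keeps the graph alive across repetitions, `grad_to_none=(q, k, v)` lets `do_bench` clear accumulated gradients between iterations, and `benchmark_label='CompiledFunctionBackward'` makes the profiler aggregate the compiled backward kernels. A minimal, self-contained sketch of that pattern; the small eager-mode shapes are assumptions for illustration only, while the benchmark itself runs the compiled kernel on `xpu`:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Tiny tensors purely for illustration.
q = torch.randn(1, 8, 128, 64, requires_grad=True)
k = torch.randn(1, 8, 128, 64, requires_grad=True)
v = torch.randn(1, 8, 128, 64, requires_grad=True)

# Build the forward graph once.
out = flex_attention(q, k, v, scale=0.125)
grad_out = torch.randn_like(out)

# The timed callable runs only the backward pass. retain_graph=True keeps the
# forward graph alive across repetitions; in the benchmark, do_bench additionally
# resets .grad between iterations via grad_to_none=(q, k, v).
bwd_fn = lambda: torch.autograd.grad((out, ), (q, k, v), grad_out, retain_graph=True)
dq, dk, dv = bwd_fn()
```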
@@ -198,9 +200,9 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
     gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes

     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
-        )
+        # The tflops and gbps are aligned to the one in flash_attention_benchmark.
+        tflops = lambda mean: 2.5 * Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)

     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv

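For the backward path the roofline lambdas reuse the forward FLOP count scaled by 2.5, matching the convention in `flash_attention_benchmark`: the backward pass redoes the two forward matmuls and additionally produces dQ, dK and dV. A rough sketch of that accounting, assuming `qk_flops` and `pv_flops` are the usual 2*M*N*K matmul counts defined earlier in the file (their definitions are not part of this hunk):

```python
# One Phi3-mini prefill shape from x_vals above, for illustration.
Z, H_q, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v = 1, 32, 1024, 1024, 96, 96

# Assumed definitions: 2 FLOPs per multiply-accumulate.
qk_flops = 2 * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk  # Q @ K^T
pv_flops = 2 * H_q * N_CTX_q * N_CTX_kv * D_HEAD_v   # P @ V

fwd_tflops = lambda mean_ms: Z * (qk_flops + pv_flops) * 1e-12 / (mean_ms * 1e-3)
bwd_tflops = lambda mean_ms: 2.5 * Z * (qk_flops + pv_flops) * 1e-12 / (mean_ms * 1e-3)

print(f'fwd: {fwd_tflops(1.0):.2f} TFLOPS, bwd: {bwd_tflops(1.0):.2f} TFLOPS at 1 ms per call')
```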