Commit 94b3473

Add more flex attention cases to benchmark. (#3928)
Add the flex attention shapes used by real models to the benchmark so their performance can be tracked. Four cases are commented out for now, for these reasons:

1. There is not enough shared local memory for the Triton kernel:
   - Append shapes of Deepseek-v3 (Nope)
   - Decode shapes of Deepseek-v3 (Nope)
2. Flex Attention does not support shapes of this kind (Error: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2):
   - Decode shapes of Qwen2-7B
3. Triton kernel block shapes must be a power of 2:
   - Decode shapes of Phi3-mini-3.8B

We will investigate the first issue on the Triton side later.

Signed-off-by: Lu,Chengjun <[email protected]>
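For context, here is a minimal standalone sketch (not part of this commit) of what one of the newly added grouped-query cases exercises: a causal-masked flex_attention call where H_q differs from H_kv, so enable_gqa must be set. The shape is the Llama-3.1-8B prefill case from the new x_vals; the device and dtype choices below are assumptions, since the benchmark itself runs on 'xpu' with float16.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(_, __, q_idx, kv_idx):
    return q_idx >= kv_idx


# Llama-3.1-8B prefill shape from the benchmark: Z=1, H_q=32, H_kv=8,
# N_CTX_q = N_CTX_kv = 1024, D_HEAD_qk = D_HEAD_v = 128.
Z, H_q, H_kv, N_CTX, D_HEAD = 1, 32, 8, 1024, 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # assumption; the benchmark targets 'xpu'
dtype = torch.float16 if device == 'cuda' else torch.float32

q = torch.randn(Z, H_q, N_CTX, D_HEAD, device=device, dtype=dtype)
k = torch.randn(Z, H_kv, N_CTX, D_HEAD, device=device, dtype=dtype)
v = torch.randn(Z, H_kv, N_CTX, D_HEAD, device=device, dtype=dtype)

# Causal block mask over the full query/key range, as in the benchmark.
block_mask = create_block_mask(causal_mask, 1, 1, N_CTX, N_CTX, device=device)
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

# enable_gqa is required because H_q != H_kv (grouped-query attention).
out = compiled_flex_attention(q, k, v, block_mask=block_mask, scale=0.125, enable_gqa=(H_q != H_kv))
print(out.shape)  # torch.Size([1, 32, 1024, 128])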
1 parent 5b8ee6d commit 94b3473

2 files changed: 96 additions, 81 deletions


.github/workflows/triton-benchmarks.yml

Lines changed: 1 addition & 2 deletions
@@ -281,8 +281,7 @@ jobs:
           python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS
 
           source ../../scripts/capture-hw-details.sh
-          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-xetla-report.csv --benchmark flexAttnCausal --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
 
       - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 95 additions & 79 deletions
@@ -3,18 +3,19 @@
 import os
 from torch.nn.attention.flex_attention import (
     create_block_mask,
+    create_mask,
     flex_attention,
 )
 
 import torch
 import torch.nn.functional as F
+
 import triton_kernels_benchmark as benchmark_suit
-from triton_kernels_benchmark import xetla_kernel
 
 torch._dynamo.config.recompile_limit = 100  # pylint: disable=protected-access
 
 # Compile the flex_attention function
-flex_attention = torch.compile(flex_attention, dynamic=False)
+compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
 
 
 @lru_cache
@@ -27,112 +28,127 @@ def causal_mask(_, __, q_idx, kv_idx):
     return q_idx >= kv_idx
 
 
+throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
+batch_sizes = [16, 32, 64] if throughput_test else [1]
+
+
 # Kernel profiling for Backward mode is not working as expected:
 # For details: https://github.com/pytorch/pytorch/issues/144778
 @benchmark_suit.perf_report(
     benchmark_suit.Benchmark(
-        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
-        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
-                for z in [1, 2, 4, 8, 16, 32]
-                for (h, dhead) in [(16, 128), (32, 64)]
-                for causal in [True]
-                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]]  #
-        + [[z, h, 1024, dhead, True, mode]
-           for z in [1, 2, 4, 8, 16, 32, 64]
-           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
-           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
+        x_vals=
+        # Multi-head attention. H_q equals H_kv
+        # Prefill shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 1024, 1024, 96, 96, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Deepseek-v3
+        [[z, 128, 128, 1024, 1024, 192, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Phi3-mini-3.8B
+        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, 'fwd'] for z in batch_sizes] +
+
+        # Multi-query attention. H_kv equals 1.
+        # Append shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 133120, Hardware limit: 131072.
+            # [z, 128, 1, 512, 1024 + 128 + 512, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Append shapes of Deepseek-v3 (Rope)
+        [] +
+
+        # Grouped-query attention. H_q / H_kv > 1
+        # Prefill shapes of Llama-3.1-8B
+        [[z, 32, 8, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Prefill shapes of Qwen2-7B
+        [[z, 28, 4, 1024, 1024, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Llama-3.1-8B
+        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Append shapes of Qwen2-7B
+        [[z, 28, 4, 512, 1024 + 128 + 512, 128, 128, 'fwd'] for z in batch_sizes] +
+
+        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+        # Decode shapes of Llama-3.1-8B
+        [[z, 32, 8, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes] +
+        # Decode shapes of Phi3-mini-3.8B
+        [
+            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+            # ValueError: Shape element 2 must be a power of 2
+            # [z, 32, 32, 1, 1024 + 64, 96, 96, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Qwen2-7B
+        [
+            # torch._inductor.exc.InductorError: LoweringException: ValueError: Number of shared query heads sharing the same KV head must be power of 2.
+            # [z, 28, 4, 1, 1024 + 64, 128, 128, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Nope)
+        [
+            # RuntimeError: No valid triton configs. OutOfResources: out of resource: shared memory, Required: 264192, Hardware limit: 131072.
+            # [z, 128, 1, 1, 1024, 64, 512, 'fwd'] for z in batch_sizes
+        ] +
+        # Decode shapes of Deepseek-v3 (Rope)
+        [],
         line_arg='provider',
-        line_vals=['triton', 'xetla'],
-        line_names=['Triton', 'XeTLA'],
+        line_vals=['triton'],
+        line_names=['Triton'],
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],
         plot_name='flexAttnCausal-performance',
         args={},
     ))
-def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
-    assert MODE in ['fwd', 'bwd']
-    assert CAUSAL
+def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
+    assert MODE in ['fwd']
     dtype = torch.float16
-    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
-    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    q = torch.randn((Z, H_q, N_CTX_q, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    k = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_qk), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
+    v = torch.randn((Z, H_kv, N_CTX_kv, D_HEAD_v), device='xpu', dtype=dtype, requires_grad=MODE == 'bwd')
     sm_scale = 0.125
     if MODE == 'bwd':
        sm_scale = 1.3
 
     quantiles = [0.5, 0.0, 1.0]
     if provider == 'triton':
-        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
-        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, kernel_options=kernel_options
-                                           )
+        kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+        triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
+            not H_q == H_kv), kernel_options=kernel_options)
+        torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
             triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
-        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
-            torch.float32)
-        if MODE == 'bwd':
-            torch_o = torch_fn()
-            torch_do = torch.randn_like(torch_o)
-            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
-        if MODE == 'fwd':
-            atol = 1e-1 if N_CTX == 16384 else 1e-2
-            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
-        else:
-            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+
+        benchmark_suit.assert_close(triton_fn, torch_fn, atol=1e-2, rtol=1e-3, err_msg='triton to torch')
         _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
 
-    elif provider == 'xetla':
-        xetla_fn = None
-        if MODE == 'fwd':
-            module_name = 'flash_attn_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            out = torch.empty_like(q, device='xpu', dtype=dtype)
-            size_score = Z * H * N_CTX * N_CTX
-            size_attn_mask = Z * N_CTX * N_CTX
-            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
-            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
-            size_ml = Z * H * N_CTX
-            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
-            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-        if MODE == 'bwd':
-            module_name = 'flash_attn_bwd_causal_True'.lower()
-            func = getattr(xetla_kernel, module_name)
-            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
-            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
-            alpha = sm_scale
-            dropout_prob = 0
-            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
-            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
-            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
-            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
-            bias_strideB = -1
-            bias_strideN = -1
-            bias_strideF = -1
-            attn_mask_padding = 0
-
-            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
-                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
-                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
-        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+    elif provider == 'onednn':
+        # OneDNN only supports MHA.
+        if H_q == H_kv:
+            mask = create_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device=q.device)
+            xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+            if MODE == 'bwd':
+                xformers_o = xformers_fn()
+                xformers_do = torch.randn_like(xformers_o)
+                xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
+            _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+                                                                  quantiles=quantiles)
+        else:
+            _, min_ms, max_ms, mean, cv = float('nan'), float('nan'), float('nan'), float('nan'), float('nan')
 
     else:
         raise NotImplementedError(f'Unsupported provider {provider}')
 
-    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+    qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # mul + add
+    pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2  # mul + add
+    tflops = lambda mean: Z * (qk_flops + pv_flops) * (1e-12) / (mean * 1e-3)
+
+    q_elems = H_q * N_CTX_q * D_HEAD_qk
+    k_elems = H_kv * N_CTX_kv * D_HEAD_qk
+    v_elems = H_kv * N_CTX_kv * D_HEAD_v
+    gbps = lambda mean: Z * (q_elems + k_elems + v_elems) * 2 * (1e-9) / (mean * 1e-3)  # float16 2 bytes
 
     if MODE == 'bwd':
-        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
-        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H_q * (N_CTX_q * D_HEAD_qk + N_CTX_kv * D_HEAD_qk) * 2 * 2 * (1e-9) / (mean * 1e-3
+                                                                                                             )
 
     return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
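As a quick sanity check on the revised reporting formulas, the short standalone snippet below (an illustration, not part of the commit) evaluates the new tflops and gbps expressions for the Llama-3.1-8B prefill shape, assuming a hypothetical mean kernel time of 1.0 ms.

# Recompute the reporting formulas from the diff for Z=1, H_q=32, H_kv=8,
# N_CTX_q = N_CTX_kv = 1024, D_HEAD_qk = D_HEAD_v = 128.
Z, H_q, H_kv = 1, 32, 8
N_CTX_q = N_CTX_kv = 1024
D_HEAD_qk = D_HEAD_v = 128
mean_ms = 1.0  # assumed measurement, for illustration only

qk_flops = H_q * N_CTX_q * N_CTX_kv * D_HEAD_qk * 2  # Q @ K^T: mul + add
pv_flops = H_q * N_CTX_q * D_HEAD_v * N_CTX_kv * 2  # P @ V: mul + add
tflops = Z * (qk_flops + pv_flops) * 1e-12 / (mean_ms * 1e-3)

q_elems = H_q * N_CTX_q * D_HEAD_qk
k_elems = H_kv * N_CTX_kv * D_HEAD_qk
v_elems = H_kv * N_CTX_kv * D_HEAD_v
gbps = Z * (q_elems + k_elems + v_elems) * 2 * 1e-9 / (mean_ms * 1e-3)  # float16 = 2 bytes

print(f'{tflops:.2f} TFLOPS, {gbps:.2f} GB/s')  # ~17.18 TFLOPS, ~12.58 GB/s

Note that, as in the benchmark's forward path, these counts cover only the two matmuls and the Q/K/V reads; softmax work and the output write are not included.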