
Commit c874647

Maxime France-Pillois and pbchekin authored
[FlexAttention] Add initial benchmarks (#3578)
Add benchmarks to evaluate FlexAttention kernel performance. Add these benchmarks to the CI workflow (this requires installing a specific PyTorch version with XPU FlexAttention support enabled). --------- Co-authored-by: Pavel Chekin <[email protected]>
1 parent 506db50 commit c874647
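
For orientation (not part of the commit itself), the new benchmarks exercise PyTorch's FlexAttention API on the XPU device. Below is a minimal sketch of a causal FlexAttention call, assuming a PyTorch build with XPU FlexAttention support and using arbitrary example shapes:

# Illustrative sketch only, not part of this commit. Mirrors the API the new
# benchmarks call; assumes a PyTorch build with FlexAttention enabled on 'xpu'.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # Each query may attend only to keys at or before its own position.
    return q_idx >= kv_idx

q = torch.randn(1, 8, 1024, 64, device='xpu', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=1024, KV_LEN=1024, device='xpu')
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
out = compiled_flex_attention(q, k, v, block_mask=block_mask, scale=0.125)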

File tree

7 files changed: +327 −16 lines


.github/actions/load/action.yml

Lines changed: 2 additions & 2 deletions
@@ -37,8 +37,8 @@ runs:
           ITEM_PATH="${{ inputs.root }}/${{ inputs.key }}"
           echo "dest=$ITEM_PATH" >> $GITHUB_OUTPUT
           if [[ -d ${{ inputs.path }} ]]; then
-            echo "Directory ${{ inputs.path }} exists and will not be restored from cache"
-            exit 1
+            echo "Directory ${{ inputs.path }} already exists and will be removed"
+            rm -rf ${{ inputs.path }}
           fi
           if [[ ${{ inputs.enabled == 'true' }} && -d $ITEM_PATH ]]; then
             echo "Cache hit for ${{ inputs.key }}"

.github/actions/setup-pytorch/action.yml

Lines changed: 9 additions & 3 deletions
@@ -45,8 +45,14 @@ runs:
       if: inputs.ref != ''
       shell: bash
       run: |
-        echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
-        echo "PYTORCH_COMMIT_ID=${{ steps.commit-id.outputs.commit_id }}" | tee -a "$GITHUB_ENV"
+        if [[ "${{ inputs.repository }}" = "liangan1/pytorch" ]]; then
+          PYTORCH_COMMIT_ID="$(<.github/pins/pytorchFlexAttention.txt)"
+          echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
+          echo "PYTORCH_COMMIT_ID=$PYTORCH_COMMIT_ID" | tee -a "$GITHUB_ENV"
+        else
+          echo "PYTORCH_REPO=${{ inputs.repository }}" | tee -a "$GITHUB_ENV"
+          echo "PYTORCH_COMMIT_ID=${{ steps.commit-id.outputs.commit_id }}" | tee -a "$GITHUB_ENV"
+        fi

   - name: Identify Python version
     shell: bash
@@ -99,7 +105,7 @@ runs:
         path: pytorch

   - name: Apply additional PR patches
-    if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.repository == 'pytorch/pytorch' && inputs.mode == 'source' }}
+    if: ${{ steps.pytorch-cache.outputs.status == 'miss' && inputs.mode == 'source' && (inputs.repository == 'pytorch/pytorch' || inputs.repository == 'liangan1/pytorch') }}
     shell: bash
     run: |
       cd pytorch
.github/pins/pytorchFlexAttention.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+bbc1fc47e716e7e6d195f8a84de7f7f286836028

.github/workflows/triton-benchmarks.yml

Lines changed: 28 additions & 0 deletions
@@ -282,6 +282,34 @@ jobs:
           cd benchmarks/micro_benchmarks
           python run_benchmarks.py --reports $REPORTS

+      # Install Pytorch with FlexAttention XPU support enabled
+      - name: Setup PyTorch
+        uses: ./.github/actions/setup-pytorch
+        with:
+          repository: liangan1/pytorch
+          ref: liangan1/flex_attention
+
+      - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py') }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          python flex_attention_benchmark_causal_mask.py --reports $REPORTS
+
+          source ../../scripts/capture-hw-details.sh
+          python ../../scripts/build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python ../../scripts/build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-xetla-report.csv --benchmark flexAttnCausal --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+
+      - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
+        run: |
+          cd benchmarks/triton_kernels_benchmark
+          python flex_attention_benchmark_custom_masks.py --reports $REPORTS
+
+          source ../../scripts/capture-hw-details.sh
+          python ../../scripts/build_report.py $REPORTS/flexAttnMasks-performance.csv $REPORTS/flexAttnMasks-triton-report.csv --benchmark flexAttnMasks --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,MASK" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG --mask
+          python ../../scripts/build_report.py $REPORTS/flexAttnMasks-performance.csv $REPORTS/flexAttnMasks-onednn-report.csv --benchmark flexAttnMasks --compiler onednn --param_cols "Z,H,N_CTX,D_HEAD,MASK" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG --mask
+
+
       - name: Upload benchmark reports
         if: ${{ steps.install.outcome == 'success' && !cancelled() }}
         uses: actions/upload-artifact@v4
benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+# This benchmark requires a Pytorch version with FlexAttention support for XPU available
+from functools import lru_cache
+import os
+from torch.nn.attention.flex_attention import (
+    create_block_mask,
+    flex_attention,
+)
+
+import torch
+import torch.nn.functional as F
+import triton_kernels_benchmark as benchmark_suit
+from triton_kernels_benchmark import xetla_kernel
+
+# Compile the flex_attention function
+flex_attention = torch.compile(flex_attention, dynamic=False)
+
+
+@lru_cache
+def create_block_mask_cached(score_mod, B, H, M, N, device='xpu'):
+    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
+    return block_mask
+
+
+def causal_mask(_, __, q_idx, kv_idx):
+    return q_idx >= kv_idx
+
+
+# Kernel profiling for Backward mode is not working as expected:
+# For details: https://github.com/pytorch/pytorch/issues/144778
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'CAUSAL', 'MODE'],
+        x_vals=[[z, h, 16384 // z, dhead, causal, mode]
+                for z in [1, 2, 4, 8, 16, 32]
+                for (h, dhead) in [(16, 128), (32, 64)]
+                for causal in [True]
+                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+        + [[4, 48, 1024, 64, True, mode] for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+        + [[z, h, 1024, dhead, True, mode]
+           for z in [1, 2, 4, 8, 16, 32, 64]
+           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
+           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        line_arg='provider',
+        line_vals=['triton', 'xetla'],
+        line_names=['Triton', 'XeTLA'],
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel=['GB/s', 'TFlops'],
+        plot_name='flexAttnCausal-performance',
+        args={},
+    ))
+def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, MODE, provider):
+    assert MODE in ['fwd', 'bwd']
+    assert CAUSAL
+    dtype = torch.float16
+    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    sm_scale = 0.125
+    if MODE == 'bwd':
+        sm_scale = 1.3
+
+    quantiles = [0.5, 0.0, 1.0]
+    if provider == 'triton':
+        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX, N_CTX, device=q.device)
+        triton_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale)
+        if MODE == 'bwd':
+            triton_o = triton_fn()
+            triton_do = torch.randn_like(triton_o)
+            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
+        torch_fn = lambda: F.scaled_dot_product_attention(q.cpu(), k.cpu(), v.cpu(), is_causal=True, scale=sm_scale).to(
+            torch.float32)
+        if MODE == 'bwd':
+            torch_o = torch_fn()
+            torch_do = torch.randn_like(torch_o)
+            torch_fn = lambda: torch_o.backward(torch_do, retain_graph=True)
+        if MODE == 'fwd':
+            atol = 1e-1 if N_CTX == 16384 else 1e-2
+            benchmark_suit.assert_close(triton_fn, torch_fn, atol=atol, rtol=1e-3, err_msg='triton to torch')
+        else:
+            benchmark_suit.assert_close(lambda: triton_o, lambda: torch_o, atol=1e-2, rtol=0, err_msg='triton to torch')
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+
+    elif provider == 'xetla':
+        xetla_fn = None
+        if MODE == 'fwd':
+            module_name = 'flash_attn_causal_True'.lower()
+            func = getattr(xetla_kernel, module_name)
+            out = torch.empty_like(q, device='xpu', dtype=dtype)
+            size_score = Z * H * N_CTX * N_CTX
+            size_attn_mask = Z * N_CTX * N_CTX
+            dropout_mask = torch.empty((size_score, ), device='xpu', dtype=torch.uint8)
+            bias = torch.empty((size_attn_mask, ), device='xpu', dtype=dtype)
+            size_ml = Z * H * N_CTX
+            m = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+            l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)
+            xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
+        if MODE == 'bwd':
+            module_name = 'flash_attn_bwd_causal_True'.lower()
+            func = getattr(xetla_kernel, module_name)
+            grad_out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+            bias = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+            dropout = torch.empty_like(q, device='xpu', dtype=torch.uint8)
+            out = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+            log_sumexp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+            workspace = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+            grad_q_tmp = torch.zeros(q.size(), device='xpu', dtype=dtype, requires_grad=True)
+            alpha = sm_scale
+            dropout_prob = 0
+            grad_query = torch.empty_like(q, device='xpu', dtype=dtype, requires_grad=True)
+            grad_key = torch.empty_like(k, device='xpu', dtype=dtype, requires_grad=True)
+            grad_value = torch.empty_like(v, device='xpu', dtype=dtype, requires_grad=True)
+            grad_bias = torch.empty_like(bias, device='xpu', dtype=dtype, requires_grad=True)
+            bias_strideB = -1
+            bias_strideN = -1
+            bias_strideF = -1
+            attn_mask_padding = 0
+
+            xetla_fn = lambda: func(grad_out, q, k, v, bias, dropout, out, log_sumexp, workspace, grad_q_tmp, alpha,
+                                    dropout_prob, grad_query, grad_key, grad_value, grad_bias, Z, H, D_HEAD, N_CTX,
+                                    N_CTX, bias_strideB, bias_strideN, bias_strideF, attn_mask_padding)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f'Unsupported provider {provider}')
+
+    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
+    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+
+    if MODE == 'bwd':
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+
+    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == '__main__':
+    benchmark.run(show_plots=False, print_data=True)
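
A note on the metrics returned above: the tflops lambda counts the two attention matmuls (QK^T and PV) at two operations per multiply-accumulate over Z * H * N_CTX * N_CTX * D_HEAD, the gbps lambda appears to count four N_CTX x D_HEAD float16 tensors per (Z, H) pair, and backward mode applies an extra 2.5x factor. A worked example with a made-up latency (the 1.0 ms below is hypothetical, not a measured value):

# Worked example of the reporting math used above (hypothetical latency).
Z, H, N_CTX, D_HEAD = 4, 48, 1024, 64
mean_ms = 1.0  # assumed mean latency from do_bench, in milliseconds

# 2 matmuls * 2 ops per multiply-accumulate * Z * H * N_CTX^2 * D_HEAD
tflops = 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * 1e-12 / (mean_ms * 1e-3)
gbps = Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * 1e-9 / (mean_ms * 1e-3)
print(f'{tflops:.1f} TFLOPS, {gbps:.1f} GB/s')  # -> 51.5 TFLOPS, 100.7 GB/s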
benchmarks/triton_kernels_benchmark/flex_attention_benchmark_custom_masks.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+# This benchmark requires a Pytorch version with FlexAttention support for XPU available
+from functools import lru_cache
+import os
+from torch.nn.attention.flex_attention import (
+    create_block_mask,
+    create_mask,
+    flex_attention,
+)
+
+import torch
+import torch.nn.functional as F
+
+import triton_kernels_benchmark as benchmark_suit
+
+# Compile the flex_attention function
+flex_attention = torch.compile(flex_attention, dynamic=False)
+
+
+@lru_cache
+def create_block_mask_cached(score_mod, B, H, M, N, device='xpu'):
+    block_mask = create_block_mask(score_mod, B, H, M, N, device=device)
+    return block_mask
+
+
+# Default values for NATTEN mask:
+# Consider a 2D image of size (G_H x G_W) flattened into a sequence of tokens.
+# Queries attend to keys in a fixed kernel area (K_H x K_W)
+G_H = 128
+G_W = 128
+K_H = 13
+K_W = 13
+
+
+def get_x_y(idx):
+    return idx // G_W, idx % G_W
+
+
+def natten_mask(_, __, q_idx, kv_idx):
+    q_x, q_y = get_x_y(q_idx)
+    kv_x, kv_y = get_x_y(kv_idx)
+    # kernel nominally attempts to center itself on the query, but kernel center
+    # is clamped to a fixed distance (kernel half-length) from the canvas edge
+    kernel_x = q_x.clamp(K_W // 2, (G_W - 1) - K_W // 2)
+    kernel_y = q_y.clamp(K_H // 2, (G_H - 1) - K_H // 2)
+    hori_mask = (kernel_x - kv_x).abs() <= K_W // 2
+    vert_mask = (kernel_y - kv_y).abs() <= K_H // 2
+    return hori_mask & vert_mask
+
+
+def alibi_functional(score, _, h, q_idx, kv_idx):
+    scale = torch.exp2(-((h + 1) * 8.0 / G_H))
+    bias = (kv_idx - q_idx) * scale
+    return score + bias
+
+
+# Kernel profiling for Backward mode is not working as expected:
+# For details: https://github.com/pytorch/pytorch/issues/144778
+@benchmark_suit.perf_report(
+    benchmark_suit.Benchmark(
+        x_names=['Z', 'H', 'N_CTX', 'D_HEAD', 'MASK', 'MODE'],
+        x_vals=[[z, h, 16384 // z, dhead, mask, mode]
+                for z in [4, 8, 16, 32]
+                for (h, dhead) in [(16, 128), (32, 64)]
+                for mask in ['NATTEN', 'Alibi']
+                for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+        + [[4, 48, 1024, 64, mask, mode]
+           for mask in ['NATTEN', 'Alibi']
+           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]] #
+        + [[z, h, 1024, dhead, mask, mode]
+           for z in [1, 2, 4, 8, 16, 32, 64]
+           for (h, dhead) in [(8, 128), (32, 96), (4, 128)]
+           for mask in ['NATTEN', 'Alibi']
+           for mode in [os.getenv('FA_KERNEL_MODE', 'fwd')]],
+        line_arg='provider',
+        line_vals=['triton', 'onednn'],
+        line_names=['Triton', 'OneDNN'],
+        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
+        ylabel=['GB/s', 'TFlops'],
+        plot_name='flexAttnMasks-performance',
+        args={},
+    ))
+def benchmark(Z, H, N_CTX, D_HEAD, MASK, MODE, provider):
+    assert MODE in ['fwd', 'bwd']
+    assert MASK in ['NATTEN', 'Alibi']
+    dtype = torch.float16
+    q = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    k = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype, requires_grad=True)
+
+    mask_mod = None
+    score_mod = None
+    if MASK == 'NATTEN':
+        mask_mod = natten_mask
+    elif MASK == 'Alibi':
+        score_mod = alibi_functional
+
+    if mask_mod is not None:
+        block_mask = create_block_mask_cached(mask_mod, 1, 1, N_CTX, N_CTX, device=q.device)
+    else:
+        block_mask = None
+    sdpa_mask_fn = mask_mod if mask_mod is not None else score_mod
+    mask = create_mask(sdpa_mask_fn, 1, 1, N_CTX, N_CTX, device=q.device)
+
+    quantiles = [0.5, 0.0, 1.0]
+    if provider == 'triton':
+        triton_fn = lambda: flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
+        if MODE == 'bwd':
+            triton_o = triton_fn()
+            triton_do = torch.randn_like(triton_o)
+            triton_fn = lambda: triton_o.backward(triton_do, retain_graph=True)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, n_warmup=5, n_repeat=5, quantiles=quantiles)
+        # Values checking cannot be implemented for these case as :
+        # "The operator 'aten::_scaled_dot_product_flash_attention_for_cpu' is not currently implemented for the XPU device"
+
+    elif provider == 'onednn':
+        xformers_fn = lambda: F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        if MODE == 'bwd':
+            xformers_o = xformers_fn()
+            xformers_do = torch.randn_like(xformers_o)
+            xformers_fn = lambda: xformers_o.backward(xformers_do, retain_graph=True)
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xformers_fn, n_warmup=10, n_repeat=10,
+                                                              quantiles=quantiles)
+
+    else:
+        raise NotImplementedError(f'Unsupported provider {provider}')
+
+    tflops = lambda mean: 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
+    gbps = lambda mean: Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+
+    if MODE == 'bwd':
+        tflops = lambda mean: 2.5 * 2 * 2 * Z * H * N_CTX * N_CTX * D_HEAD * (1e-12) / (mean * 1e-3)
+        gbps = lambda mean: 2.5 * Z * H * (N_CTX * D_HEAD + N_CTX * D_HEAD) * 2 * 2 * (1e-9) / (mean * 1e-3)
+
+    return (gbps(mean), gbps(max_ms), gbps(min_ms)), (tflops(mean), tflops(max_ms), tflops(min_ms)), cv
+
+
+if __name__ == '__main__':
+    benchmark.run(show_plots=False, print_data=True)
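
For orientation, the natten_mask predicate above treats each token index as a pixel on a G_H x G_W canvas and keeps only keys inside a K_H x K_W window whose center is clamped away from the canvas edge. The sketch below is illustrative only (not part of the commit): it evaluates the same predicate on CPU for a tiny 4x4 canvas with a 3x3 window and materializes it as a dense boolean matrix, so the sliding-window pattern can be inspected directly.

# Illustrative sketch: dense view of a NATTEN-style mask on a toy canvas.
# Not part of the committed benchmark; runs on CPU since it only evaluates the predicate.
import torch

G_H = G_W = 4   # canvas height / width (toy values; the benchmark uses 128)
K_H = K_W = 3   # attention window size (toy values; the benchmark uses 13)

def get_x_y(idx):
    return idx // G_W, idx % G_W

def natten_mask(_, __, q_idx, kv_idx):
    q_x, q_y = get_x_y(q_idx)
    kv_x, kv_y = get_x_y(kv_idx)
    # Window is centered on the query, with the center clamped so the window
    # never extends past the canvas edge.
    kernel_x = q_x.clamp(K_W // 2, (G_W - 1) - K_W // 2)
    kernel_y = q_y.clamp(K_H // 2, (G_H - 1) - K_H // 2)
    return ((kernel_x - kv_x).abs() <= K_W // 2) & ((kernel_y - kv_y).abs() <= K_H // 2)

n = G_H * G_W
q_idx = torch.arange(n).view(n, 1)
kv_idx = torch.arange(n).view(1, n)
dense = natten_mask(None, None, q_idx, kv_idx)  # (16, 16) boolean attention mask
print(dense.int())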
