
Commit 11cab64

[FlexAttention] Add torch as reference (#4251)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent cd6042b commit 11cab64

2 files changed: +10 -5 lines changed

.github/workflows/triton-benchmarks.yml (1 addition, 0 deletions)

@@ -278,6 +278,7 @@ jobs:
           source ../../scripts/capture-hw-details.sh
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flex-attn-causal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-torch-report.csv --benchmark flex-attn-causal --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG

       - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
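With this change the workflow builds two reports from the same flexAttnCausal-performance.csv: one tagged `--compiler triton` reading the `Triton-TFlops`/`Triton-GB/s` columns, and a new one tagged `--compiler torch` reading `Torch-TFlops`/`Torch-GB/s`. As a hypothetical sketch (only the column names are taken from the flags above; the exact CSV layout is an assumption, not confirmed by this commit), the two compilers' metrics could then be compared like this:

```python
# Hypothetical sketch: compare the Triton and Torch metric columns in the
# benchmark CSV. Column names come from the build_report.py flags above;
# the CSV layout itself is an assumption, not confirmed by this commit.
import csv

with open('flexAttnCausal-performance.csv') as f:
    for row in csv.DictReader(f):
        speedup = float(row['Triton-TFlops']) / float(row['Torch-TFlops'])
        print(row['Z'], row['H_q'], row['N_CTX_q'], f'{speedup:.2f}x')
```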

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py (9 additions, 5 deletions)
@@ -87,8 +87,8 @@ def causal_mask(_, __, q_idx, kv_idx):
         # Decode shapes of Deepseek-v3 (Rope)
         [],
         line_arg='provider',
-        line_vals=['triton'],
-        line_names=['Triton'],
+        line_vals=['triton', 'torch'],
+        line_names=['Triton', 'Torch'],
         styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],
         ylabel=['GB/s', 'TFlops'],
         plot_name='flexAttnCausal-performance',
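Adding 'torch' to `line_vals` is what routes a second run through the decorated `benchmark` function: each entry of `line_vals` is passed as the `provider` argument (the `triton.testing.perf_report` convention, which `benchmark_suit` is assumed to follow). A minimal sketch of that dispatch, with invented shapes:

```python
# Sketch only: the harness calls benchmark(...) once per line_vals entry,
# binding the entry to `provider`. Shapes below are invented for illustration.
for provider in ['triton', 'torch']:  # mirrors line_vals above
    benchmark(Z=1, H_q=8, H_kv=8, N_CTX_q=1024, N_CTX_kv=1024,
              D_HEAD_qk=64, D_HEAD_v=64, MODE='fwd', provider=provider)
```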
@@ -105,12 +105,16 @@ def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provid
     sm_scale = 1.3

     quantiles = [0.5, 0.0, 1.0]
-    if provider == 'triton':
+    block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
+    torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
+
+    if provider == 'torch':
+        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(torch_fn, n_warmup=10, n_repeat=10, quantiles=quantiles)
+
+    elif provider == 'triton':
         kernel_options = {'num_stages': 2, 'num_warps': 16 if D_HEAD_qk == 128 else 8, 'BLOCKS_ARE_CONTIGUOUS': True}
-        block_mask = create_block_mask_cached(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
         triton_fn = lambda: compiled_flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=(
             not H_q == H_kv), kernel_options=kernel_options)
-        torch_fn = lambda: flex_attention(q, k, v, block_mask=block_mask, scale=sm_scale, enable_gqa=not H_q == H_kv)
         if MODE == 'bwd':
             triton_o = triton_fn()
             triton_do = torch.randn_like(triton_o)
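The restructuring above hoists `block_mask` and the eager `torch_fn` out of the Triton-only branch so both providers share one cached block mask, then times the uncompiled `flex_attention` call when `provider == 'torch'`. Note that `not H_q == H_kv` is just `H_q != H_kv`, i.e. grouped-query attention is enabled whenever the query and key/value head counts differ. For context, here is a minimal standalone sketch of what that eager torch reference computes; the `flex_attention`/`create_block_mask` calls are the public PyTorch API (2.5+), while the mask body, shapes, and dtype are assumptions, since the diff shows only the `causal_mask` signature:

```python
# Standalone sketch of the eager "torch" reference path (assumptions marked).
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(_, __, q_idx, kv_idx):
    # Assumed body: the diff shows only this signature; q_idx >= kv_idx is the
    # standard causal mask_mod.
    return q_idx >= kv_idx

# Invented shapes/dtype for illustration; the benchmark sweeps these.
Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD = 1, 8, 8, 1024, 1024, 64
q = torch.randn(Z, H_q, N_CTX_q, D_HEAD, device='xpu', dtype=torch.float16)
k = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD, device='xpu', dtype=torch.float16)
v = torch.randn(Z, H_kv, N_CTX_kv, D_HEAD, device='xpu', dtype=torch.float16)

block_mask = create_block_mask(causal_mask, 1, 1, N_CTX_q, N_CTX_kv, device='xpu')
out = flex_attention(q, k, v, block_mask=block_mask, scale=1.3,
                     enable_gqa=not H_q == H_kv)  # i.e. H_q != H_kv
```

The compiled path (`compiled_flex_attention`, presumably a `torch.compile`-wrapped `flex_attention`) is then benchmarked against this eager reference under the same block mask.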
