12 changes: 12 additions & 0 deletions .github/workflows/triton-benchmarks.yml
@@ -334,6 +334,18 @@ jobs:
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-triton-report.csv --benchmark flex-attn-causal-batch16 --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-torch-report.csv --benchmark flex-attn-causal-batch16 --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG

- name: Run Triton FlexAttention Causal Mask bwd kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_bwd_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_bwd_benchmark_causal_mask.py') }}
run: |
export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
cd benchmarks/triton_kernels_benchmark
FA_KERNEL_MODE='bwd' \
python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS

source ../../scripts/capture-hw-details.sh
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-bwd-triton-report.csv --benchmark flex-attn-causal-bwd --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-bwd-torch-report.csv --benchmark flex-attn-causal-bwd --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
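
The new step above reuses flex_attention_benchmark_causal_mask.py and only switches it to backward mode by exporting FA_KERNEL_MODE='bwd' for that single invocation, then builds separate Triton and Torch reports from the same flexAttnCausal-performance.csv. As a hedged illustration (the exact lookup inside the script is not shown in this diff and is an assumption), the mode is presumably read from the environment along these lines:

import os

# Assumed: the benchmark script derives fa_kernel_mode from the env var set in the workflow,
# falling back to forward mode when FA_KERNEL_MODE is not exported.
fa_kernel_mode = os.getenv('FA_KERNEL_MODE', 'fwd')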

- name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
run: |
@@ -86,53 +86,58 @@ def causal_mask(_, __, q_idx, kv_idx):
@benchmark_suite.perf_report(
benchmark_suite.Benchmark(
x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
x_vals=
# Multi-head attention. H_q equals H_kv
# Prefill shapes of Phi3-mini-4k-instruct
[[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of Qwen3-4B
[[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of DeepSeek-v3
[[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of Phi3-mini-4k-instruct
[[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +

## Multi-query attention. H_kv equals 1.
# Append shapes of Deepseek-v3
[[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode] for z in batch_sizes] +

# Grouped-query attention. H_q / H_kv > 1
# Prefill shapes of Llama-3.1-8B
[[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of meta-llama-Llama-3.2-3B
[[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of Deepseek-R1-Distill-Qwen-14B
[[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of Llama-3.1-8B
[[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of meta-llama-Llama-3.2-3B
[[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of Qwen3-4B
[[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +

# FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
# Decode shapes of Llama-3.1-8B
[[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of meta-llama-Llama-3.2-3B
[[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of Phi3-mini-4k-instruct
[
# acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
# ValueError: Shape element 2 must be a power of 2
# [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
] +
# Decode shapes of Qwen3-4B
[[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of Deepseek-R1-Distill-Qwen-14B
[[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of Deepseek-v3
[
# [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
x_vals=[
x_val for x_val in
# Multi-head attention. H_q equals H_kv
# Prefill shapes of Phi3-mini-4k-instruct
[[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of Qwen3-4B
[[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of DeepSeek-v3
[[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of Phi3-mini-4k-instruct
[[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +

# Multi-query attention. H_kv equals 1.
# Append shapes of Deepseek-v3
([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +

# Grouped-query attention. H_q / H_kv > 1
# Prefill shapes of Llama-3.1-8B
[[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of meta-llama-Llama-3.2-3B
[[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Prefill shapes of Deepseek-R1-Distill-Qwen-14B
[[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of Llama-3.1-8B
[[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of meta-llama-Llama-3.2-3B
[[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Append shapes of Qwen3-4B
[[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +

# FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
# Decode shapes of Llama-3.1-8B
[[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of meta-llama-Llama-3.2-3B
[[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of Phi3-mini-4k-instruct
[
# acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
# ValueError: Shape element 2 must be a power of 2
# [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
] +
# Decode shapes of Qwen3-4B
[[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of Deepseek-R1-Distill-Qwen-14B
[[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
# Decode shapes of Deepseek-v3
[
# [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
]
# FIXME: Reenable when PyTorch fixes config in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/triton.py#L1509
if x_val[-1] == 'fwd' or x_val[5] != 128
],
line_arg='provider',
line_vals=['triton', 'torch'],
@@ -143,6 +148,7 @@ def causal_mask(_, __, q_idx, kv_idx):
args={},
))
def benchmark(Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE, provider):
print(f'Running benchmark for Z={Z}, H_q={H_q}, H_kv={H_kv}, N_CTX_q={N_CTX_q}, N_CTX_kv={N_CTX_kv}, D_HEAD_qk={D_HEAD_qk}, D_HEAD_v={D_HEAD_v}, MODE={MODE}, provider={provider}')
# Maximum across torch=200, triton=600
do_bench = benchmark_suite.get_do_bench(n_warmup=600, n_repeat=10, quantiles=[0.5, 0.0, 1.0])
if MODE not in ('fwd', 'bwd'):
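
For context on the x_vals change above: the list comprehension now filters the generated configs with the guard x_val[-1] == 'fwd' or x_val[5] != 128, where index 5 is D_HEAD_qk and the last element is MODE (per x_names). In other words, forward configs are always kept, while backward configs are skipped whenever D_HEAD_qk is 128, pending the PyTorch fix referenced in the FIXME. A minimal standalone sketch of that predicate, using made-up sample values:

# x_names order: Z, H_q, H_kv, N_CTX_q, N_CTX_kv, D_HEAD_qk, D_HEAD_v, MODE
def keep(x_val):
    # Same guard as the comprehension in the diff above.
    return x_val[-1] == 'fwd' or x_val[5] != 128

assert keep([4, 32, 8, 1024, 1024, 128, 128, 'fwd'])      # fwd is always kept
assert not keep([4, 32, 8, 1024, 1024, 128, 128, 'bwd'])  # bwd with D_HEAD_qk == 128 is dropped
assert keep([4, 32, 32, 1024, 1024, 96, 96, 'bwd'])       # bwd with other head sizes is kept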