diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml
index a28fe59d6a..9cdf69736a 100644
--- a/.github/workflows/triton-benchmarks.yml
+++ b/.github/workflows/triton-benchmarks.yml
@@ -334,6 +334,18 @@ jobs:
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-triton-report.csv --benchmark flex-attn-causal-batch16 --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-torch-report.csv --benchmark flex-attn-causal-batch16 --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG

+      - name: Run Triton FlexAttention Causal Mask bwd kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_bwd_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_bwd_benchmark_causal_mask.py') }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          FA_KERNEL_MODE='bwd' \
+          python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS
+
+          source ../../scripts/capture-hw-details.sh
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-bwd-triton-report.csv --benchmark flex-attn-causal-bwd --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-bwd-torch-report.csv --benchmark flex-attn-causal-bwd --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
+
       - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
         run: |
diff --git a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
index e08a0135bd..3b4871d073 100644
--- a/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+++ b/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
@@ -86,53 +86,58 @@ def causal_mask(_, __, q_idx, kv_idx):
 @benchmark_suite.perf_report(
     benchmark_suite.Benchmark(
         x_names=['Z', 'H_q', 'H_kv', 'N_CTX_q', 'N_CTX_kv', 'D_HEAD_qk', 'D_HEAD_v', 'MODE'],
-        x_vals=
-        # Multi-head attention. H_q equals H_kv
-        # Prefill shapes of Phi3-mini-4k-instruct
-        [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Qwen3-4B
-        [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of DeepSeek-v3
-        [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Phi3-mini-4k-instruct
-        [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
-
-        ## Multi-query attention. H_kv equals 1.
-        # Append shapes of Deepseek-v3
-        [[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode] for z in batch_sizes] +
-
-        # Grouped-query attention. H_q / H_kv > 1
-        # Prefill shapes of Llama-3.1-8B
-        [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
-        [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Llama-3.1-8B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Append shapes of Qwen3-4B
-        [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-
-        # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
-        # Decode shapes of Llama-3.1-8B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of meta-llama-Llama-3.2-3B
-        [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Phi3-mini-4k-instruct
-        [
-            # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-            # ValueError: Shape element 2 must be a power of 2
-            # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
-        ] +
-        # Decode shapes of Qwen3-4B
-        [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Deepseek-R1-Distill-Qwen-14B
-        [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
-        # Decode shapes of Deepseek-v3
-        [
-            # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+        x_vals=[
+            x_val for x_val in
+            # Multi-head attention. H_q equals H_kv
+            # Prefill shapes of Phi3-mini-4k-instruct
+            [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of Qwen3-4B
+            [[z, 32, 32, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of DeepSeek-v3
+            [[z, 128, 128, 1024, 1024, 192, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Phi3-mini-4k-instruct
+            [[z, 32, 32, 512, 1024 + 128 + 512, 96, 96, fa_kernel_mode] for z in batch_sizes] +
+
+            # Multi-query attention. H_kv equals 1.
+            # Append shapes of Deepseek-v3
+            ([[z, 128, 1, 512, 1024 + 128 + 512, 576, 512, fa_kernel_mode]
+              for z in batch_sizes] if fa_kernel_mode != 'bwd' else []) +
+
+            # Grouped-query attention. H_q / H_kv > 1
+            # Prefill shapes of Llama-3.1-8B
+            [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Prefill shapes of Deepseek-R1-Distill-Qwen-14B
+            [[z, 40, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Llama-3.1-8B
+            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Append shapes of Qwen3-4B
+            [[z, 32, 8, 512, 1024 + 128 + 512, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+
+            # FlexDecoding configuration. N_CTX_q equals 1. N_CTX_kv >= 1k
+            # Decode shapes of Llama-3.1-8B
+            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of meta-llama-Llama-3.2-3B
+            [[z, 24, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Phi3-mini-4k-instruct
+            [
+                # acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+                # ValueError: Shape element 2 must be a power of 2
+                # [z, 32, 32, 1, 1024 + 64, 96, 96, fa_kernel_mode] for z in batch_sizes
+            ] +
+            # Decode shapes of Qwen3-4B
+            [[z, 32, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Deepseek-R1-Distill-Qwen-14B
+            [[z, 40, 8, 1, 1024 + 64, 128, 128, fa_kernel_mode] for z in batch_sizes] +
+            # Decode shapes of Deepseek-v3
+            [
+                # [z, 128, 1, 1, 1024 + 64, 576, 512, fa_kernel_mode] for z in batch_sizes
+            ]
+            # FIXME: Reenable when PyTorch fixes config in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/template_heuristics/triton.py#L1509
+            if x_val[-1] == 'fwd' or x_val[5] != 128
         ],
         line_arg='provider',
         line_vals=['triton', 'torch'],
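Note on the new x_vals filter: the comprehension keeps every shape when MODE is 'fwd' and, per the FIXME, drops shapes whose D_HEAD_qk (x_val[5] in the x_names ordering) is 128 when the benchmark runs in bwd mode; the MQA Deepseek-v3 append shape is likewise excluded in bwd mode via the fa_kernel_mode != 'bwd' conditional. A minimal standalone sketch of the same filtering pattern follows; the batch_sizes and fa_kernel_mode values here are illustrative assumptions, not taken from the benchmark.

# Illustrative sketch of the shape-filtering pattern introduced in this diff.
# batch_sizes and fa_kernel_mode are redefined here so the snippet runs on its own;
# in the benchmark, fa_kernel_mode comes from the FA_KERNEL_MODE setting used in CI.
batch_sizes = [1, 16]     # assumed values for this example only
fa_kernel_mode = 'bwd'

x_vals = [
    x_val for x_val in
    # D_HEAD_qk = 96: kept in both fwd and bwd mode.
    [[z, 32, 32, 1024, 1024, 96, 96, fa_kernel_mode] for z in batch_sizes] +
    # D_HEAD_qk = 128: kept in fwd mode, dropped in bwd mode by the filter below.
    [[z, 32, 8, 1024, 1024, 128, 128, fa_kernel_mode] for z in batch_sizes]
    # x_val[-1] is MODE and x_val[5] is D_HEAD_qk, matching the x_names ordering.
    if x_val[-1] == 'fwd' or x_val[5] != 128
]

print(x_vals)  # with fa_kernel_mode = 'bwd', only the D_HEAD_qk = 96 shapes remain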