Skip to content

Commit 4ea9537

Browse files
Add Flash Attention backward to benchmarks/triton_kernels_benchmark (#3108)
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 73b9356 commit 4ea9537

File tree

5 files changed

+669
-302
lines changed

5 files changed

+669
-302
lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,27 +222,39 @@ jobs:
222222
source ../../scripts/capture-hw-details.sh
223223
python ../../scripts/build_report.py $REPORTS/matmul-performance-postop-addmatrix-int8.csv $REPORTS/gemm-postop-addmatrix-int8-triton-report.csv --benchmark gemm-postop-addmatrix-int8 --compiler triton --param_cols "B,M,K,N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
224224
225-
- name: Run Triton FA kernel benchmark
226-
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py') }}
225+
- name: Run Triton FA fwd kernel benchmark
226+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_benchmark.py') }}
227227
run: |
228228
cd benchmarks/triton_kernels_benchmark
229-
python flash_attention_fwd_benchmark.py --reports $REPORTS
229+
python flash_attention_benchmark.py --reports $REPORTS
230230
231231
source ../../scripts/capture-hw-details.sh
232-
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
233-
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
232+
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
233+
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
234234
235-
- name: Run Triton FA kernel benchmark - advanced path
236-
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_fwd_benchmark.py_advanced') }}
235+
- name: Run Triton FA fwd kernel benchmark - advanced path
236+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_benchmark.py_advanced') }}
237237
run: |
238238
cd benchmarks/triton_kernels_benchmark
239239
TRITON_INTEL_ADVANCED_PATH=1 \
240240
IGC_VISAOptions=" -enableBCR" \
241-
python flash_attention_fwd_benchmark.py --reports $REPORTS
241+
python flash_attention_benchmark.py --reports $REPORTS
242242
243243
TAG="${TAG}-adv"
244244
source ../../scripts/capture-hw-details.sh
245-
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-advanced-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
245+
python ../../scripts/build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-advanced-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
246+
247+
- name: Run Triton FA bwd kernel benchmark
248+
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_benchmark.py') }}
249+
run: |
250+
cd benchmarks/triton_kernels_benchmark
251+
FA_KERNEL_MODE="bwd" \
252+
BENCHMARKING_METHOD="ELAPSED_TIME" python flash_attention_benchmark.py --reports $REPORTS
253+
mv $REPORTS/attn-performance.csv $REPORTS/attn-bwd-performance.csv
254+
255+
source ../../scripts/capture-hw-details.sh
256+
python ../../scripts/build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
257+
python ../../scripts/build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL,MODE" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
246258
247259
- name: Run Prefix Sums kernel benchmark
248260
if: ${{ steps.install.outcome == 'success' && !cancelled() && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefix_sums.py') }}

0 commit comments

Comments
 (0)