Commit b0acd14

[CI] Run Flex Attention with batch size 4 (#4913)
Flex Attention with batch size 16 (Torch implementation) fails on BMG: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17038905956/job/48297736887. This PR skips running batch size 16 on BMG. To keep tracking the performance of Flex Attention with a batch size greater than 1, this PR adds a run with batch size 4, which can be removed once batch size 16 is fixed.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent b58720a commit b0acd14

2 files changed: +12 −1 lines changed


.github/workflows/triton-benchmarks-bmg.yml

Lines changed: 1 addition & 1 deletion
@@ -15,4 +15,4 @@ jobs:
     uses: ./.github/workflows/triton-benchmarks.yml
     with:
       runner_label: b580
-      skip_benchmarks: "[]"
+      skip_benchmarks: "['flex_attention_benchmark_batch16-causal_mask.py']"
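The skip_benchmarks string set here is parsed with fromJson inside the reusable workflow's per-step if: conditions (see the next file), so listing a script name disables that step on the b580 (BMG) runner. For context, below is a minimal sketch of how the receiving side is assumed to declare these inputs in triton-benchmarks.yml; the input names and the empty-list fallback match the expressions in the diff, but the exact declaration is an assumption, not the actual file contents:

# Sketch only: assumed workflow_call input declarations for triton-benchmarks.yml.
on:
  workflow_call:
    inputs:
      runner_label:
        type: string
      benchmarks:
        # Empty string means "run all benchmarks"; otherwise only the listed
        # scripts run (checked via contains(fromJson(inputs.benchmarks || '[]'), ...)).
        type: string
        default: ""
      skip_benchmarks:
        # Stringified list of benchmark scripts to skip, e.g.
        # "['flex_attention_benchmark_batch16-causal_mask.py']" as set above for BMG.
        type: string
        default: "[]"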

.github/workflows/triton-benchmarks.yml

Lines changed: 11 additions & 0 deletions
@@ -312,6 +312,17 @@ jobs:
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flex-attn-causal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-torch-report.csv --benchmark flex-attn-causal --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
 
+      - name: Run Triton FlexAttention (batch_size=4) Causal Mask fwd kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch4-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch4-causal_mask.py') }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          BATCH_SIZE=4 python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS
+
+          source ../../scripts/capture-hw-details.sh
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch4-triton-report.csv --benchmark flex-attn-causal-batch4 --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch4-torch-report.csv --benchmark flex-attn-causal-batch4 --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
+
       - name: Run Triton FlexAttention (batch_size=16) Causal Mask fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py') }}
         run: |
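The new batch-4 step mirrors the batch-16 step that follows it in the file: only the BATCH_SIZE environment variable, the benchmark name (flex-attn-causal-batch4) and the report file names differ. As a rough sketch, the same command could be run outside CI as shown below; REPORTS, N_RUNS and PTI_LIBS_DIR are normally provided by the workflow environment, so the values here are illustrative placeholders and assume an already-built Triton XPU environment:

# Illustrative local reproduction of the new step (values are placeholders).
export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH   # PTI_LIBS_DIR comes from the CI install step
cd benchmarks/triton_kernels_benchmark
REPORTS=/tmp/flexattn-reports    # placeholder; CI passes its own reports directory
N_RUNS=1                         # placeholder; CI sets the number of runs
mkdir -p "$REPORTS"
BATCH_SIZE=4 python flex_attention_benchmark_causal_mask.py --reports "$REPORTS" --n_runs "$N_RUNS"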
