
Commit 0ab03be

[CI] Run Flex Attention with batch size 16 (#4908)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent bd11640 commit 0ab03be

File tree

2 files changed: +13, −1 lines changed

.github/workflows/triton-benchmarks.yml

Lines changed: 11 additions & 0 deletions
@@ -312,6 +312,17 @@ jobs:
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flex-attn-causal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
           python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-torch-report.csv --benchmark flex-attn-causal --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG

+      - name: Run Triton FlexAttention (batch_size=16) Causal Mask fwd kernel benchmark
+        if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py') }}
+        run: |
+          export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+          cd benchmarks/triton_kernels_benchmark
+          BATCH_SIZE=16 python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS
+
+          source ../../scripts/capture-hw-details.sh
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-triton-report.csv --benchmark flex-attn-causal-batch16 --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+          python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch16-torch-report.csv --benchmark flex-attn-causal-batch16 --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
+
       - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
         if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
         run: |
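The new step reuses the existing causal-mask benchmark script and only overrides the batch size through the BATCH_SIZE environment variable (see the second file below). A rough local equivalent of the CI invocation, sketched in Python; the report directory and run count are placeholder values, not taken from the workflow:

import os
import subprocess

# Rough local equivalent of the new CI step: run the causal-mask FlexAttention
# benchmark with BATCH_SIZE=16. "reports" and the single run are placeholders.
env = dict(os.environ, BATCH_SIZE="16")
subprocess.run(
    [
        "python", "flex_attention_benchmark_causal_mask.py",
        "--reports", "reports", "--n_runs", "1",
    ],
    cwd="benchmarks/triton_kernels_benchmark",
    env=env,
    check=True,
)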

benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py

Lines changed: 2 additions & 1 deletion
@@ -72,7 +72,8 @@ def causal_mask(_, __, q_idx, kv_idx):
 
 
 throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
-batch_sizes = [16, 32, 64] if throughput_test else [1]
+batch_size = int(os.getenv('BATCH_SIZE', '1'))
+batch_sizes = [16, 32, 64] if throughput_test else [batch_size]
 
 
 # Kernel profiling for Backward mode is not working as expected:
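The script change that enables the new CI step is small: instead of hard-coding a batch size of 1 for the non-throughput case, the benchmark now reads BATCH_SIZE from the environment, defaulting to 1. A minimal standalone sketch of that selection logic, with the resulting behavior noted in comments:

import os

# Same selection logic as the updated flex_attention_benchmark_causal_mask.py:
# THROUGHPUT_TEST=1 sweeps several batch sizes; otherwise a single BATCH_SIZE is used.
throughput_test = os.getenv('THROUGHPUT_TEST', '0') == '1'
batch_size = int(os.getenv('BATCH_SIZE', '1'))  # default stays 1
batch_sizes = [16, 32, 64] if throughput_test else [batch_size]

print(batch_sizes)  # unset -> [1]; BATCH_SIZE=16 -> [16]; THROUGHPUT_TEST=1 -> [16, 32, 64]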
