@@ -312,6 +312,17 @@ jobs:
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flex-attn-causal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-torch-report.csv --benchmark flex-attn-causal --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG

+ - name : Run Triton FlexAttention (batch_size=4) Causal Mask fwd kernel benchmark
+   if : ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch4-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch4-causal_mask.py') }}
+   run : |
+     export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH
+     cd benchmarks/triton_kernels_benchmark
+     BATCH_SIZE=4 python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS
+
+     source ../../scripts/capture-hw-details.sh
+     python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch4-triton-report.csv --benchmark flex-attn-causal-batch4 --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+     python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-batch4-torch-report.csv --benchmark flex-attn-causal-batch4 --compiler torch --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Torch-TFlops --hbm_col "Torch-GB/s" --tag $TAG
+
- name : Run Triton FlexAttention (batch_size=16) Causal Mask fwd kernel benchmark
  if : ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py') }}
  run : |
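
On the execution side, the only difference from the existing batch_size=16 step is the `BATCH_SIZE=4` environment variable passed to `flex_attention_benchmark_causal_mask.py`; on the reporting side, the results are filed under the separate `flex-attn-causal-batch4` benchmark name and written to `flexAttnCausal-batch4-*` report files, while reading the same `$REPORTS/flexAttnCausal-performance.csv`. Below is a minimal sketch of how such an environment override could be consumed inside a benchmark script; the `BATCH_SIZE` variable name comes from the workflow, but the `get_batch_size` helper and the default of 16 are assumptions for illustration, not the actual script's code.

```python
import os

def get_batch_size(default: int = 16) -> int:
    # Hypothetical helper: read the batch size the workflow exports (e.g. BATCH_SIZE=4),
    # falling back to a default when the variable is unset.
    # The default of 16 mirrors the "batch_size=16" step name and is an assumption here.
    raw = os.getenv("BATCH_SIZE", "")
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise SystemExit(f"Invalid BATCH_SIZE value: {raw!r}")

if __name__ == "__main__":
    # A benchmark script would use this value for the batch (Z) dimension of its shape sweep.
    print(f"Running FlexAttention causal-mask benchmark with batch size {get_batch_size()}")
```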