1717 --benchmarks
1818 --softmax
1919 --gemm
20- --attention
20+ --flash-attention
21+ --flex-attention
2122 --instrumentation
2223 --inductor
2324 --sglang
@@ -55,7 +56,8 @@ TEST_MICRO_BENCHMARKS=false
5556TEST_BENCHMARKS=false
5657TEST_BENCHMARK_SOFTMAX=false
5758TEST_BENCHMARK_GEMM=false
58- TEST_BENCHMARK_ATTENTION=false
59+ TEST_BENCHMARK_FLASH_ATTENTION=false
60+ TEST_BENCHMARK_FLEX_ATTENTION=false
5961TEST_INSTRUMENTATION=false
6062TEST_INDUCTOR=false
6163TEST_SGLANG=false
@@ -128,8 +130,13 @@ while (( $# != 0 )); do
128130 TEST_DEFAULT=false
129131 shift
130132 ;;
131- --attention)
132- TEST_BENCHMARK_ATTENTION=true
133+ --flash-attention)
134+ TEST_BENCHMARK_FLASH_ATTENTION=true
135+ TEST_DEFAULT=false
136+ shift
137+ ;;
138+ --flex-attention)
139+ TEST_BENCHMARK_FLEX_ATTENTION=true
133140 TEST_DEFAULT=false
134141 shift
135142 ;;
@@ -410,9 +417,9 @@ run_benchmark_gemm() {
410417 python $TRITON_PROJ /benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py
411418}
412419
413- run_benchmark_attention () {
420+ run_benchmark_flash_attention () {
414421 echo " ****************************************************"
415- echo " ***** Running ATTENTION *****"
422+ echo " ***** Running FlashAttention *****"
416423 echo " ****************************************************"
417424 cd $TRITON_PROJ /benchmarks
418425 pip install .
@@ -433,6 +440,17 @@ run_benchmark_attention() {
433440 python $TRITON_PROJ /benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
434441}
435442
443+ run_benchmark_flex_attention () {
444+ echo " ****************************************************"
445+ echo " ***** Running FlexAttention *****"
446+ echo " ****************************************************"
447+ cd $TRITON_PROJ /benchmarks
448+ pip install .
449+
450+ echo " FlexAttention - causal mask:"
451+ python $TRITON_PROJ /benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
452+ }
453+
436454run_benchmarks () {
437455 cd $TRITON_PROJ /benchmarks
438456 pip install .
@@ -538,8 +556,11 @@ test_triton() {
538556 if [ " $TEST_BENCHMARK_GEMM " = true ]; then
539557 run_benchmark_gemm
540558 fi
541- if [ " $TEST_BENCHMARK_ATTENTION " = true ]; then
542- run_benchmark_attention
559+ if [ " $TEST_BENCHMARK_FLASH_ATTENTION " = true ]; then
560+ run_benchmark_flash_attention
561+ fi
562+ if [ " $TEST_BENCHMARK_FLEX_ATTENTION " = true ]; then
563+ run_benchmark_flex_attention
543564 fi
544565 if [ " $TEST_INSTRUMENTATION " == true ]; then
545566 run_instrumentation_tests
0 commit comments