
Commit 7597ef1

Add FlexAttention to test script (#4950)
Signed-off-by: Ettore Tiotto <[email protected]>
Parent: 84fd610

1 file changed: scripts/test-triton.sh (+29, -8)
@@ -17,7 +17,8 @@ TEST:
     --benchmarks
     --softmax
     --gemm
-    --attention
+    --flash-attention
+    --flex-attention
     --instrumentation
     --inductor
     --sglang
@@ -55,7 +56,8 @@ TEST_MICRO_BENCHMARKS=false
 TEST_BENCHMARKS=false
 TEST_BENCHMARK_SOFTMAX=false
 TEST_BENCHMARK_GEMM=false
-TEST_BENCHMARK_ATTENTION=false
+TEST_BENCHMARK_FLASH_ATTENTION=false
+TEST_BENCHMARK_FLEX_ATTENTION=false
 TEST_INSTRUMENTATION=false
 TEST_INDUCTOR=false
 TEST_SGLANG=false
@@ -128,8 +130,13 @@ while (( $# != 0 )); do
       TEST_DEFAULT=false
       shift
       ;;
-    --attention)
-      TEST_BENCHMARK_ATTENTION=true
+    --flash-attention)
+      TEST_BENCHMARK_FLASH_ATTENTION=true
+      TEST_DEFAULT=false
+      shift
+      ;;
+    --flex-attention)
+      TEST_BENCHMARK_FLEX_ATTENTION=true
       TEST_DEFAULT=false
       shift
       ;;
@@ -410,9 +417,9 @@ run_benchmark_gemm() {
   python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py
 }
 
-run_benchmark_attention() {
+run_benchmark_flash_attention() {
   echo "****************************************************"
-  echo "***** Running ATTENTION *****"
+  echo "***** Running FlashAttention *****"
   echo "****************************************************"
   cd $TRITON_PROJ/benchmarks
   pip install .
@@ -433,6 +440,17 @@ run_benchmark_attention() {
   python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py
 }
 
+run_benchmark_flex_attention() {
+  echo "****************************************************"
+  echo "***** Running FlexAttention *****"
+  echo "****************************************************"
+  cd $TRITON_PROJ/benchmarks
+  pip install .
+
+  echo "FlexAttention - causal mask:"
+  python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flex_attention_benchmark_causal_mask.py
+}
+
 run_benchmarks() {
   cd $TRITON_PROJ/benchmarks
   pip install .
@@ -538,8 +556,11 @@ test_triton() {
   if [ "$TEST_BENCHMARK_GEMM" = true ]; then
     run_benchmark_gemm
   fi
-  if [ "$TEST_BENCHMARK_ATTENTION" = true ]; then
-    run_benchmark_attention
+  if [ "$TEST_BENCHMARK_FLASH_ATTENTION" = true ]; then
+    run_benchmark_flash_attention
+  fi
+  if [ "$TEST_BENCHMARK_FLEX_ATTENTION" = true ]; then
+    run_benchmark_flex_attention
   fi
   if [ "$TEST_INSTRUMENTATION" == true ]; then
     run_instrumentation_tests
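With this change applied, the renamed and new options can be exercised as follows (a minimal usage sketch inferred from the diff; it assumes the script is invoked from the repository root and that the environment the script expects, e.g. TRITON_PROJ, is already set up):

    # FlashAttention benchmark only (this option was previously named --attention):
    scripts/test-triton.sh --flash-attention

    # FlexAttention benchmark only (runs the causal-mask variant):
    scripts/test-triton.sh --flex-attention

    # Options can be combined; passing either flag sets TEST_DEFAULT=false,
    # so only the requested benchmarks are run:
    scripts/test-triton.sh --flash-attention --flex-attention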
