diff --git a/scripts/test-triton.sh b/scripts/test-triton.sh index 0738c6da93..852f9bccdf 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -15,6 +15,7 @@ TEST_CORE=false TEST_INTERPRETER=false TEST_TUTORIAL=false TEST_MICRO_BENCHMARKS=false +TEST_BENCHMARKS=false TEST_BENCHMARK_SOFTMAX=false TEST_BENCHMARK_GEMM=false TEST_BENCHMARK_ATTENTION=false @@ -53,6 +54,10 @@ while [ -v 1 ]; do TEST_MICRO_BENCHMARKS=true shift ;; + --benchmarks) + TEST_BENCHMARKS=true + shift + ;; --softmax) TEST_BENCHMARK_SOFTMAX=true shift @@ -116,7 +121,7 @@ while [ -v 1 ]; do done # Only run interpreter test when $TEST_INTERPRETER is true -if [ "$TEST_UNIT" = false ] && [ "$TEST_CORE" = false ] && [ "$TEST_INTERPRETER" = false ] && [ "$TEST_TUTORIAL" = false ] && [ "$TEST_MICRO_BENCHMARKS" = false ] && [ "$TEST_BENCHMARK_SOFTMAX" = false ] && [ "$TEST_BENCHMARK_GEMM" = false ] && [ "$TEST_BENCHMARK_ATTENTION" = false ] && [ "$TEST_INSTRUMENTATION" = false ] && [ "$TEST_INDUCTOR" = false ]; then +if [ "$TEST_UNIT" = false ] && [ "$TEST_CORE" = false ] && [ "$TEST_INTERPRETER" = false ] && [ "$TEST_TUTORIAL" = false ] && [ "$TEST_MICRO_BENCHMARKS" = false ] && ["$TEST_BENCHMARKS" = false] && [ "$TEST_BENCHMARK_SOFTMAX" = false ] && [ "$TEST_BENCHMARK_GEMM" = false ] && [ "$TEST_BENCHMARK_ATTENTION" = false ] && [ "$TEST_INSTRUMENTATION" = false ] && [ "$TEST_INDUCTOR" = false ]; then TEST_UNIT=true TEST_CORE=true TEST_TUTORIAL=true @@ -152,7 +157,7 @@ install_deps() { echo "**** Skipping installation of pytorch ****" else echo "**** Installing pytorch ****" - if ([ ! -v USE_IPEX ] || [ "$USE_IPEX" = 1 ]) && ([ "$TEST_BENCHMARK_SOFTMAX" = true ] || [ "$TEST_BENCHMARK_GEMM" = true ] || [ "$TEST_BENCHMARK_ATTENTION" = true ]); then + if ([ ! -v USE_IPEX ] || [ "$USE_IPEX" = 1 ]) && ([ "$TEST_BENCHMARKS" = true ] || [ "$TEST_BENCHMARK_SOFTMAX" = true ] || [ "$TEST_BENCHMARK_GEMM" = true ] || [ "$TEST_BENCHMARK_ATTENTION" = true ]); then $SCRIPTS_DIR/compile-pytorch-ipex.sh $([ $VENV = true ] && echo "--venv") else $SCRIPTS_DIR/install-pytorch.sh $([ $VENV = true ] && echo "--venv") @@ -288,6 +293,21 @@ run_benchmark_attention() { python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py } +run_benchmarks() { + cd $TRITON_PROJ/benchmarks + python setup.py install + for file in $TRITON_PROJ/benchmarks/triton_kernels_benchmark/*.py; do + benchmark=$(basename -- "$file" .py) + if [[ $benchmark = @("__init__"|"benchmark_driver"|"benchmark_testing") ]]; then + continue + fi + echo + echo "****** Running ${benchmark} ******" + echo + python $file + done +} + run_instrumentation_tests() { # FIXME: the "instrumentation" test suite currently contains only one test, when all tests # are skipped pytest reports an error. If the only test is the skip list, then we shouldn't @@ -339,6 +359,9 @@ test_triton() { if [ "$TEST_MICRO_BENCHMARKS" = true ]; then run_microbench_tests fi + if [ "$TEST_BENCHMARKS" = true ]; then + run_benchmarks + fi if [ "$TEST_BENCHMARK_SOFTMAX" = true ]; then run_benchmark_softmax fi