Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ TEST_CORE=false
TEST_INTERPRETER=false
TEST_TUTORIAL=false
TEST_MICRO_BENCHMARKS=false
TEST_BENCHMARKS=false
TEST_BENCHMARK_SOFTMAX=false
TEST_BENCHMARK_GEMM=false
TEST_BENCHMARK_ATTENTION=false
Expand Down Expand Up @@ -53,6 +54,10 @@ while [ -v 1 ]; do
TEST_MICRO_BENCHMARKS=true
shift
;;
--benchmarks)
TEST_BENCHMARKS=true
shift
;;
--softmax)
TEST_BENCHMARK_SOFTMAX=true
shift
Expand Down Expand Up @@ -116,7 +121,7 @@ while [ -v 1 ]; do
done

# Only run interpreter test when $TEST_INTERPRETER is true
if [ "$TEST_UNIT" = false ] && [ "$TEST_CORE" = false ] && [ "$TEST_INTERPRETER" = false ] && [ "$TEST_TUTORIAL" = false ] && [ "$TEST_MICRO_BENCHMARKS" = false ] && [ "$TEST_BENCHMARK_SOFTMAX" = false ] && [ "$TEST_BENCHMARK_GEMM" = false ] && [ "$TEST_BENCHMARK_ATTENTION" = false ] && [ "$TEST_INSTRUMENTATION" = false ] && [ "$TEST_INDUCTOR" = false ]; then
if [ "$TEST_UNIT" = false ] && [ "$TEST_CORE" = false ] && [ "$TEST_INTERPRETER" = false ] && [ "$TEST_TUTORIAL" = false ] && [ "$TEST_MICRO_BENCHMARKS" = false ] && ["$TEST_BENCHMARKS" = false] && [ "$TEST_BENCHMARK_SOFTMAX" = false ] && [ "$TEST_BENCHMARK_GEMM" = false ] && [ "$TEST_BENCHMARK_ATTENTION" = false ] && [ "$TEST_INSTRUMENTATION" = false ] && [ "$TEST_INDUCTOR" = false ]; then
TEST_UNIT=true
TEST_CORE=true
TEST_TUTORIAL=true
Expand Down Expand Up @@ -152,7 +157,7 @@ install_deps() {
echo "**** Skipping installation of pytorch ****"
else
echo "**** Installing pytorch ****"
if ([ ! -v USE_IPEX ] || [ "$USE_IPEX" = 1 ]) && ([ "$TEST_BENCHMARK_SOFTMAX" = true ] || [ "$TEST_BENCHMARK_GEMM" = true ] || [ "$TEST_BENCHMARK_ATTENTION" = true ]); then
if ([ ! -v USE_IPEX ] || [ "$USE_IPEX" = 1 ]) && ([ "$TEST_BENCHMARKS" = true ] || [ "$TEST_BENCHMARK_SOFTMAX" = true ] || [ "$TEST_BENCHMARK_GEMM" = true ] || [ "$TEST_BENCHMARK_ATTENTION" = true ]); then
$SCRIPTS_DIR/compile-pytorch-ipex.sh $([ $VENV = true ] && echo "--venv")
else
$SCRIPTS_DIR/install-pytorch.sh $([ $VENV = true ] && echo "--venv")
Expand Down Expand Up @@ -288,6 +293,21 @@ run_benchmark_attention() {
python $TRITON_PROJ/benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
}

run_benchmarks() {
cd $TRITON_PROJ/benchmarks
python setup.py install
for file in $TRITON_PROJ/benchmarks/triton_kernels_benchmark/*.py; do
benchmark=$(basename -- "$file" .py)
if [[ $benchmark = @("__init__"|"benchmark_driver"|"benchmark_testing") ]]; then
continue
fi
echo
echo "****** Running ${benchmark} ******"
echo
python $file
done
}

run_instrumentation_tests() {
# FIXME: the "instrumentation" test suite currently contains only one test, when all tests
# are skipped pytest reports an error. If the only test is the skip list, then we shouldn't
Expand Down Expand Up @@ -339,6 +359,9 @@ test_triton() {
if [ "$TEST_MICRO_BENCHMARKS" = true ]; then
run_microbench_tests
fi
if [ "$TEST_BENCHMARKS" = true ]; then
run_benchmarks
fi
if [ "$TEST_BENCHMARK_SOFTMAX" = true ]; then
run_benchmark_softmax
fi
Expand Down