Skip to content

Commit f697dde

Browse files
[0.6.0-UT] Transferring Multi-GPU tests (#527)
1 parent 41a32e3 commit f697dde

File tree

4 files changed

+651
-307
lines changed

4 files changed

+651
-307
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2025 The JAX Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Configuration file containing the list of multi-GPU test files.
18+
This list is shared between run_single_gpu.py and run_multi_gpu.sh
19+
to ensure consistency and avoid duplication.
20+
"""
21+
22+
# Multi-GPU test files that should be excluded from single GPU runs
23+
# but included in multi-GPU test runs
24+
MULTI_GPU_TESTS = {
25+
"tests/multiprocess_gpu_test.py",
26+
"tests/debug_info_test.py",
27+
"tests/checkify_test.py",
28+
"tests/mosaic/gpu_test.py",
29+
"tests/random_test.py",
30+
"tests/jax_jit_test.py",
31+
"tests/mesh_utils_test.py",
32+
"tests/pjit_test.py",
33+
"tests/linalg_sharding_test.py",
34+
"tests/multi_device_test.py",
35+
"tests/distributed_test.py",
36+
"tests/shard_alike_test.py",
37+
"tests/api_test.py",
38+
"tests/ragged_collective_test.py",
39+
"tests/batching_test.py",
40+
"tests/scaled_matmul_stablehlo_test.py",
41+
"tests/export_harnesses_multi_platform_test.py",
42+
"tests/pickle_test.py",
43+
"tests/roofline_test.py",
44+
"tests/profiler_test.py",
45+
"tests/error_check_test.py",
46+
"tests/debug_nans_test.py",
47+
"tests/shard_map_test.py",
48+
"tests/colocated_python_test.py",
49+
"tests/cudnn_fusion_test.py",
50+
"tests/compilation_cache_test.py",
51+
"tests/export_back_compat_test.py",
52+
"tests/pgle_test.py",
53+
"tests/ffi_test.py",
54+
"tests/lax_control_flow_test.py",
55+
"tests/fused_attention_stablehlo_test.py",
56+
"tests/layout_test.py",
57+
"tests/pmap_test.py",
58+
"tests/aot_test.py",
59+
"tests/mock_gpu_topology_test.py",
60+
"tests/ann_test.py",
61+
"tests/debugging_primitives_test.py",
62+
"tests/array_test.py",
63+
"tests/export_test.py",
64+
"tests/memories_test.py",
65+
"tests/debugger_test.py",
66+
"tests/python_callback_test.py",
67+
}
68+
69+
if __name__ == "__main__":
70+
import sys
71+
72+
if len(sys.argv) > 1 and sys.argv[1] == "--list":
73+
# Print all multi-GPU tests, one per line (for bash scripts)
74+
for test in MULTI_GPU_TESTS:
75+
print(test)
76+
else:
77+
print("Multi-GPU tests configuration")
78+
print(f"Total tests: {len(MULTI_GPU_TESTS)}")
79+
print("\nTo get the list for bash scripts, run:")
80+
print("python3 multi_gpu_tests_config.py --list")
81+

build/rocm/run_multi_gpu.sh

100755100644
Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ detect_amd_gpus() {
2828
echo "Error: lspci command not found. Aborting."
2929
exit 1
3030
fi
31-
# Count AMD/ATI GPU controllers.
31+
# Count AMD GPUs.
3232
local count
33-
count=$(lspci | grep -c 'controller.*AMD/ATI')
33+
count=$(rocm-smi | grep -E '^Device' -A 1000 | awk '$1 ~ /^[0-9]+$/ {count++} END {print count}')
3434
echo "$count"
3535
}
3636

@@ -52,19 +52,39 @@ run_tests() {
5252
# Create the log directory if it doesn't exist.
5353
mkdir -p "$LOG_DIR"
5454

55-
python3 -m pytest \
56-
--html="${LOG_DIR}/multi_gpu_pmap_test_log.html" \
57-
--json-report \
58-
--json-report-file="${LOG_DIR}/multi_gpu_pmap_test_log.json" \
59-
--reruns 3 \
60-
tests/pmap_test.py
61-
62-
python3 -m pytest \
63-
--html="${LOG_DIR}/multi_gpu_multi_device_test_log.html" \
64-
--json-report \
65-
--json-report-file="${LOG_DIR}/multi_gpu_multi_device_test_log.json" \
66-
--reruns 3 \
67-
tests/multi_device_test.py
55+
# Multi-GPU test files - load from shared configuration
56+
echo "Loading multi-GPU tests from configuration..."
57+
58+
# Set the path to the configuration file
59+
CONFIG_PATH="build/rocm/multi_gpu_tests_config.py"
60+
61+
# Check if python3 and the config file are available
62+
if ! python3 -c "import sys; sys.path.insert(0, 'build/rocm'); import multi_gpu_tests_config" 2>/dev/null; then
63+
echo "Error: multi_gpu_tests_config.py not found in build/rocm/ or not importable. Aborting."
64+
exit 1
65+
fi
66+
67+
# Load the multi-GPU tests from the configuration file
68+
mapfile -t MULTI_GPU_TESTS < <(python3 "$CONFIG_PATH" --list)
69+
70+
# Run each multi-GPU test
71+
for test_file in "${MULTI_GPU_TESTS[@]}"; do
72+
test_name=$(basename "$test_file" .py)
73+
echo "Running multi-GPU test: $test_file"
74+
75+
# Define file paths for abort detection (files created by conftest.py)
76+
json_log_file="${LOG_DIR}/multi_gpu_${test_name}_log.json"
77+
html_log_file="${LOG_DIR}/multi_gpu_${test_name}_log.html"
78+
79+
# Run the test
80+
python3 -m pytest \
81+
--html="$html_log_file" \
82+
--json-report \
83+
--json-report-file="$json_log_file" \
84+
--reruns 3 \
85+
"$test_file"
86+
87+
done
6888

6989
# Merge individual HTML reports into one.
7090
python3 -m pytest_html_merger \

0 commit comments

Comments
 (0)