Skip to content

Commit 98be5a6

Browse files
Copilot and mawad-amd committed
Simplify distributed test runner to use torchrun directly
Refactored run_tests_distributed.py to remove recursive launcher pattern: - Removed launcher logic that parsed arguments and spawned torchrun - Script now runs directly as a torchrun worker - Takes pytest arguments directly from command line (no environment variable) Updated run_tests.sh to invoke torchrun directly: - Changed from: python run_tests_distributed.py --num_ranks N <test> <args> - Changed to: torchrun --nproc_per_node=N --standalone run_tests_distributed.py <test> <args> Benefits: - Simpler, more direct execution path - No recursive script invocation - Easier to understand and debug - Eliminates dummy arguments and environment variable passing Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
1 parent 712f47c commit 98be5a6

File tree

2 files changed

+33
-103
lines changed

2 files changed

+33
-103
lines changed

.github/scripts/run_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ EXIT_CODE=0
105105
for test_file in tests/$TEST_DIR/test_*.py; do
106106
if [ -f \"\$test_file\" ]; then
107107
echo \"Testing: \$test_file with $NUM_RANKS ranks (install: $INSTALL_METHOD)\"
108-
python tests/run_tests_distributed.py --num_ranks $NUM_RANKS \"\$test_file\" -v --tb=short --durations=10
108+
torchrun --nproc_per_node=$NUM_RANKS --standalone tests/run_tests_distributed.py \"\$test_file\" -v --tb=short --durations=10
109109
fi
110110
done
111111
" || { EXIT_CODE=$?; }

tests/run_tests_distributed.py

Lines changed: 32 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
44

55
"""
6-
Simple wrapper to run pytest tests using torchrun, which manages distributed process groups and avoids port conflicts through automatic port allocation.
6+
Worker script for running pytest tests under torchrun.
7+
This script is invoked by torchrun and runs pytest within a distributed process group.
78
"""
89

910
import os
@@ -12,104 +13,33 @@
1213
# Set required environment variable for RCCL on ROCm
1314
os.environ.setdefault("HSA_NO_SCRATCH_RECLAIM", "1")
1415

15-
16-
def _distributed_worker_main():
17-
"""Main function for distributed worker that runs pytest."""
18-
import torch
19-
import torch.distributed as dist
20-
21-
# torchrun sets these environment variables automatically
22-
rank = int(os.environ.get("RANK", 0))
23-
world_size = int(os.environ.get("WORLD_SIZE", 1))
24-
local_rank = int(os.environ.get("LOCAL_RANK", 0))
25-
26-
# Set the correct GPU for this specific process
27-
if torch.cuda.is_available():
28-
torch.cuda.set_device(local_rank)
29-
30-
# Initialize distributed - torchrun already set up the environment
31-
dist.init_process_group(
32-
backend="nccl",
33-
rank=rank,
34-
world_size=world_size,
35-
device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None,
36-
)
37-
38-
try:
39-
# Import and run pytest directly
40-
import pytest
41-
42-
# Get pytest args from environment (set by launcher)
43-
pytest_args_str = os.environ.get("PYTEST_ARGS", "")
44-
pytest_args = pytest_args_str.split() if pytest_args_str else []
45-
46-
# Run pytest
47-
exit_code = pytest.main(pytest_args)
48-
sys.exit(exit_code)
49-
finally:
50-
if dist.is_initialized():
51-
dist.destroy_process_group()
52-
53-
54-
def main():
55-
if len(sys.argv) < 2:
56-
print("Usage: python run_tests_distributed.py [--num_ranks N] [pytest_args...] <test_file>")
57-
sys.exit(1)
58-
59-
# Check if we're being called as a torchrun worker
60-
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
61-
# We're running inside torchrun - execute as worker
62-
_distributed_worker_main()
63-
return
64-
65-
# We're the launcher - parse args and start torchrun
66-
num_ranks = 2
67-
args = sys.argv[1:]
68-
69-
if "--num_ranks" in args:
70-
idx = args.index("--num_ranks")
71-
if idx + 1 < len(args):
72-
num_ranks = int(args[idx + 1])
73-
# Remove --num_ranks and its value from args
74-
args = args[:idx] + args[idx + 2 :]
75-
76-
# The test file is the first argument after --num_ranks, everything else is pytest args
77-
if not args:
78-
print("Error: No test file specified")
79-
sys.exit(1)
80-
81-
test_file = args[0]
82-
pytest_args = args[1:] # Everything after the test file
83-
84-
print(f"Running {test_file} with {num_ranks} ranks using torchrun")
85-
86-
# Build pytest arguments string
87-
pytest_cmd_args = [test_file] + pytest_args
88-
pytest_args_str = " ".join(pytest_cmd_args)
89-
90-
# Set environment variable for worker to read
91-
os.environ["PYTEST_ARGS"] = pytest_args_str
92-
93-
# Build torchrun command - it will re-invoke this script as a worker
94-
import subprocess
95-
96-
torchrun_cmd = [
97-
"torchrun",
98-
f"--nproc_per_node={num_ranks}",
99-
"--standalone", # Single-node training
100-
__file__, # Re-invoke this script
101-
]
102-
103-
print(f"Executing: {' '.join(torchrun_cmd)}")
104-
105-
# Run torchrun and return its exit code
106-
try:
107-
result = subprocess.run(torchrun_cmd, check=False, env=os.environ.copy())
108-
sys.exit(result.returncode)
109-
except Exception as e:
110-
print(f"Error running torchrun: {e}")
111-
sys.exit(1)
112-
113-
114-
if __name__ == "__main__":
115-
main()
def main() -> int:
    """Run pytest as one torchrun-managed distributed worker.

    torchrun launches this script once per rank and exports RANK,
    WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT. This worker
    joins the process group, forwards all command-line arguments
    (test file plus pytest flags) straight to pytest, and tears the
    group down afterwards.

    Returns:
        pytest's exit code (0 on success), propagated to torchrun.
    """
    import torch
    import torch.distributed as dist

    # torchrun sets these environment variables automatically; the
    # defaults allow a bare single-process invocation for debugging.
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # Bind this process to its own GPU before any collective work.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    # Initialize distributed - torchrun already set up the rendezvous
    # environment, so no explicit init_method is needed.
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None,
    )

    try:
        import pytest

        # Pass through all command-line arguments to pytest.
        return pytest.main(sys.argv[1:])
    finally:
        # Always leave the process group cleanly, even if pytest raises.
        if dist.is_initialized():
            dist.destroy_process_group()


# Guard the entry point: previously this code ran at import time, so
# importing the module would initialize NCCL and execute pytest.
if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)