|
@@ -3,7 +3,8 @@
 # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
 
 """
-Simple wrapper to run pytest tests using torchrun, which manages distributed process groups and avoids port conflicts through automatic port allocation.
+Worker script for running pytest tests under torchrun.
+This script is invoked by torchrun and runs pytest within a distributed process group.
 """
 
 import os
@@ -12,104 +13,33 @@
 # Set required environment variable for RCCL on ROCm
 os.environ.setdefault("HSA_NO_SCRATCH_RECLAIM", "1")
 
-
-def _distributed_worker_main():
-    """Main function for distributed worker that runs pytest."""
-    import torch
-    import torch.distributed as dist
-
-    # torchrun sets these environment variables automatically
-    rank = int(os.environ.get("RANK", 0))
-    world_size = int(os.environ.get("WORLD_SIZE", 1))
-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
-
-    # Set the correct GPU for this specific process
-    if torch.cuda.is_available():
-        torch.cuda.set_device(local_rank)
-
-    # Initialize distributed - torchrun already set up the environment
-    dist.init_process_group(
-        backend="nccl",
-        rank=rank,
-        world_size=world_size,
-        device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None,
-    )
-
-    try:
-        # Import and run pytest directly
-        import pytest
-
-        # Get pytest args from environment (set by launcher)
-        pytest_args_str = os.environ.get("PYTEST_ARGS", "")
-        pytest_args = pytest_args_str.split() if pytest_args_str else []
-
-        # Run pytest
-        exit_code = pytest.main(pytest_args)
-        sys.exit(exit_code)
-    finally:
-        if dist.is_initialized():
-            dist.destroy_process_group()
-
-
-def main():
-    if len(sys.argv) < 2:
-        print("Usage: python run_tests_distributed.py [--num_ranks N] [pytest_args...] <test_file>")
-        sys.exit(1)
-
-    # Check if we're being called as a torchrun worker
-    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
-        # We're running inside torchrun - execute as worker
-        _distributed_worker_main()
-        return
-
-    # We're the launcher - parse args and start torchrun
-    num_ranks = 2
-    args = sys.argv[1:]
-
-    if "--num_ranks" in args:
-        idx = args.index("--num_ranks")
-        if idx + 1 < len(args):
-            num_ranks = int(args[idx + 1])
-            # Remove --num_ranks and its value from args
-            args = args[:idx] + args[idx + 2 :]
-
-    # The test file is the first argument after --num_ranks, everything else is pytest args
-    if not args:
-        print("Error: No test file specified")
-        sys.exit(1)
-
-    test_file = args[0]
-    pytest_args = args[1:]  # Everything after the test file
-
-    print(f"Running {test_file} with {num_ranks} ranks using torchrun")
-
-    # Build pytest arguments string
-    pytest_cmd_args = [test_file] + pytest_args
-    pytest_args_str = " ".join(pytest_cmd_args)
-
-    # Set environment variable for worker to read
-    os.environ["PYTEST_ARGS"] = pytest_args_str
-
-    # Build torchrun command - it will re-invoke this script as a worker
-    import subprocess
-
-    torchrun_cmd = [
-        "torchrun",
-        f"--nproc_per_node={num_ranks}",
-        "--standalone",  # Single-node training
-        __file__,  # Re-invoke this script
-    ]
-
-    print(f"Executing: {' '.join(torchrun_cmd)}")
-
-    # Run torchrun and return its exit code
-    try:
-        result = subprocess.run(torchrun_cmd, check=False, env=os.environ.copy())
-        sys.exit(result.returncode)
-    except Exception as e:
-        print(f"Error running torchrun: {e}")
-        sys.exit(1)
-
-
-if __name__ == "__main__":
-    main()
+import torch
+import torch.distributed as dist
+
+# torchrun sets these environment variables automatically
+rank = int(os.environ.get("RANK", 0))
+world_size = int(os.environ.get("WORLD_SIZE", 1))
+local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+# Set the correct GPU for this specific process
+if torch.cuda.is_available():
+    torch.cuda.set_device(local_rank)
+
+# Initialize distributed - torchrun already set up the environment
+dist.init_process_group(
+    backend="nccl",
+    rank=rank,
+    world_size=world_size,
+    device_id=torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else None,
+)
+
+try:
+    # Import and run pytest with command-line arguments
+    import pytest
+
+    # Pass through all command-line arguments to pytest
+    exit_code = pytest.main(sys.argv[1:])
+    sys.exit(exit_code)
+finally:
+    if dist.is_initialized():
+        dist.destroy_process_group()
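
For context on how the rewritten worker is driven: pytest arguments are now taken from sys.argv rather than the PYTEST_ARGS environment variable, so a launcher only needs to pass the test file and any pytest flags to torchrun after the script path. Below is a minimal sketch of such a launcher, assuming the worker keeps the run_tests_distributed.py name from the removed usage message; the launch() helper and the example test path are illustrative, not part of the repository.

# Hypothetical launcher sketch: start torchrun with the worker script above and
# forward pytest arguments positionally, since the worker calls
# pytest.main(sys.argv[1:]).
import subprocess
import sys


def launch(test_file: str, num_ranks: int = 2, extra_pytest_args=None) -> int:
    cmd = [
        "torchrun",
        "--standalone",                    # single-node rendezvous with automatic port selection
        f"--nproc_per_node={num_ranks}",   # one worker process per rank/GPU
        "run_tests_distributed.py",        # assumed filename of the worker script
        test_file,
        *(extra_pytest_args or []),        # e.g. ["-v", "-x"]
    ]
    return subprocess.run(cmd, check=False).returncode


if __name__ == "__main__":
    sys.exit(launch(sys.argv[1], extra_pytest_args=sys.argv[2:]))

Equivalently, from a shell: torchrun --standalone --nproc_per_node=2 run_tests_distributed.py tests/test_example.py -v (the test path here is only an example).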