Skip to content

Commit 0d86081

Browse files
authored
Merge branch 'master' into loadams/reenable-py311-312
2 parents 28f0bbb + 4809072 commit 0d86081

File tree

2 files changed

+3
-84
lines changed

2 files changed

+3
-84
lines changed

deepspeed/launcher/multinode_runner.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -141,30 +141,12 @@ def name(self):
141141
def validate_args(self):
142142
super().validate_args()
143143

144-
# Validate and set MPI environment variables
145-
self._setup_mpi_environment()
146-
147144
#TODO: Allow for include/exclude at node-level but not gpu-level
148145
if self.args.include != "" or self.args.exclude != "":
149146
raise ValueError(f"{self.name} backend does not support worker include/exclusion")
150147
if self.args.num_nodes != -1 or self.args.num_gpus != -1:
151148
raise ValueError(f"{self.name} backend does not support limiting num nodes/gpus")
152149

153-
def _setup_mpi_environment(self):
154-
"""Sets up MPI-related environment variables or raises an error if they're missing."""
155-
156-
required_vars = ['OMPI_COMM_WORLD_LOCAL_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_COMM_WORLD_SIZE']
157-
158-
# Check if all these are present
159-
if not all(var in os.environ for var in required_vars):
160-
raise EnvironmentError("MPI environment variables are not set. "
161-
"Ensure you are running the script with an MPI-compatible launcher.")
162-
163-
# Now safe to read all
164-
os.environ['LOCAL_RANK'] = os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
165-
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
166-
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
167-
168150
def get_cmd(self, environment, active_resources):
169151
total_process_count = sum(self.resource_pool.values())
170152

tests/unit/launcher/test_multinode_runner.py

Lines changed: 3 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,6 @@ def runner_info():
1919
return env, hosts, world_info, args
2020

2121

22-
@pytest.fixture
23-
def mock_mpi_env(monkeypatch):
24-
# Provide the 3 required MPI variables:
25-
monkeypatch.setenv('OMPI_COMM_WORLD_LOCAL_RANK', '0')
26-
monkeypatch.setenv('OMPI_COMM_WORLD_RANK', '0')
27-
monkeypatch.setenv('OMPI_COMM_WORLD_SIZE', '1')
28-
29-
3022
def test_pdsh_runner(runner_info):
3123
env, resource_pool, world_info, args = runner_info
3224
runner = mnrunner.PDSHRunner(args, world_info)
@@ -35,15 +27,15 @@ def test_pdsh_runner(runner_info):
3527
assert env['PDSH_RCMD_TYPE'] == 'ssh'
3628

3729

38-
def test_openmpi_runner(runner_info, mock_mpi_env):
30+
def test_openmpi_runner(runner_info):
3931
env, resource_pool, world_info, args = runner_info
4032
runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
4133
cmd = runner.get_cmd(env, resource_pool)
4234
assert cmd[0] == 'mpirun'
4335
assert 'eth0' in cmd
4436

4537

46-
def test_btl_nic_openmpi_runner(runner_info, mock_mpi_env):
38+
def test_btl_nic_openmpi_runner(runner_info):
4739
env, resource_pool, world_info, _ = runner_info
4840
args = parse_args(['--launcher_arg', '-mca btl_tcp_if_include eth1', 'test_launcher.py'])
4941
runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
@@ -52,7 +44,7 @@ def test_btl_nic_openmpi_runner(runner_info, mock_mpi_env):
5244
assert 'eth1' in cmd
5345

5446

55-
def test_btl_nic_two_dashes_openmpi_runner(runner_info, mock_mpi_env):
47+
def test_btl_nic_two_dashes_openmpi_runner(runner_info):
5648
env, resource_pool, world_info, _ = runner_info
5749
args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
5850
runner = mnrunner.OpenMPIRunner(args, world_info, resource_pool)
@@ -61,61 +53,6 @@ def test_btl_nic_two_dashes_openmpi_runner(runner_info, mock_mpi_env):
6153
assert 'eth1' in cmd
6254

6355

64-
def test_setup_mpi_environment_success():
65-
"""Test that _setup_mpi_environment correctly sets environment variables when MPI variables exist."""
66-
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
67-
os.environ['OMPI_COMM_WORLD_RANK'] = '1'
68-
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'
69-
70-
args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
71-
72-
runner = mnrunner.OpenMPIRunner(args, None, None)
73-
# Set up the MPI environment
74-
runner._setup_mpi_environment()
75-
76-
assert os.environ['LOCAL_RANK'] == '0'
77-
assert os.environ['RANK'] == '1'
78-
assert os.environ['WORLD_SIZE'] == '2'
79-
80-
# Clean up environment
81-
del os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
82-
del os.environ['OMPI_COMM_WORLD_RANK']
83-
del os.environ['OMPI_COMM_WORLD_SIZE']
84-
del os.environ['LOCAL_RANK']
85-
del os.environ['RANK']
86-
del os.environ['WORLD_SIZE']
87-
88-
89-
def test_setup_mpi_environment_missing_variables():
90-
"""Test that _setup_mpi_environment raises an EnvironmentError when MPI variables are missing."""
91-
92-
# Clear relevant environment variables
93-
os.environ.pop('OMPI_COMM_WORLD_LOCAL_RANK', None)
94-
os.environ.pop('OMPI_COMM_WORLD_RANK', None)
95-
os.environ.pop('OMPI_COMM_WORLD_SIZE', None)
96-
97-
args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
98-
99-
with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
100-
mnrunner.OpenMPIRunner(args, None, None)
101-
102-
103-
def test_setup_mpi_environment_fail():
104-
"""Test that _setup_mpi_environment fails if only partial MPI variables are provided."""
105-
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = '0'
106-
os.environ.pop('OMPI_COMM_WORLD_RANK', None) # missing variable
107-
os.environ['OMPI_COMM_WORLD_SIZE'] = '2'
108-
109-
args = parse_args(['--launcher_arg', '--mca btl_tcp_if_include eth1', 'test_launcher.py'])
110-
111-
with pytest.raises(EnvironmentError, match="MPI environment variables are not set"):
112-
runner = mnrunner.OpenMPIRunner(args, None, None)
113-
114-
# Clean up environment
115-
del os.environ['OMPI_COMM_WORLD_LOCAL_RANK']
116-
del os.environ['OMPI_COMM_WORLD_SIZE']
117-
118-
11956
def test_mpich_runner(runner_info):
12057
env, resource_pool, world_info, args = runner_info
12158
runner = mnrunner.MPICHRunner(args, world_info, resource_pool)

0 commit comments

Comments
 (0)