
Commit 261aff7

more cleanup

committed
1 parent 6be0e93 commit 261aff7

5 files changed, +73 -13 lines changed

src/forge/controller/launcher.py

Lines changed: 3 additions & 1 deletion
@@ -293,7 +293,9 @@ def create_server_handle(self) -> str:


 def get_launcher(cfg: LauncherConfig | None = None) -> BaseLauncher | None:
-    if not cfg or cfg.launcher == Launcher.SLURM:
+    if not cfg:
+        return None
+    if cfg.launcher == Launcher.SLURM:
         return Slurmlauncher()
     elif cfg.launcher == Launcher.MAST:
         return Mastlauncher(cfg)
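For reference, a minimal sketch of how callers see the new dispatch: with no config, get_launcher now returns None instead of silently defaulting to SLURM. The import paths and the LauncherConfig constructor below are assumptions inferred from this diff, not verified against the repo.

    # Hypothetical usage; LauncherConfig fields and module paths are assumed.
    from forge.controller.launcher import get_launcher
    from forge.types import Launcher, LauncherConfig

    # No config now means "no launcher" rather than a SLURM default.
    assert get_launcher(None) is None

    # An explicit config still selects the matching launcher.
    launcher = get_launcher(LauncherConfig(launcher=Launcher.SLURM))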

src/forge/controller/provisioner.py

Lines changed: 51 additions & 7 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-"""Remote resource allocation and provisioning."""
+"""Resource allocation and provisioning for both local and remote."""
 import asyncio
 import functools
 import logging
@@ -160,20 +160,40 @@ async def get_proc_mesh(
         mesh_name: Optional[str] = None,
         host_mesh: HostMesh | None = None,
         env_vars: dict[str, str] | None = None,
+        addr: str | None = None,
+        port: str | None = None,
     ):
         """Gets a proc mesh.

-        num_hosts = None implies that you want a local allocation, this may change.
+        Args:
+            num_procs: The number of processes to allocate.
+            with_gpus: Whether to include GPU allocations.
+                This only adds the CUDA_VISIBLE_DEVICES environment variable.
+            num_hosts: The number of hosts to allocate.
+                If this is set, a remote allocation is created.
+                If this is None, it uses the local host.
+                This behavior may change in the future.
+            host_mesh: The host mesh to allocate the process on.
+                If None, a new host mesh will be created.
+            port: The distributed port to use.
+                If None, a port will be detected.
+            addr: The distributed address to use.
+                If None, an address will be detected.
+
+        Returns:
+            A proc mesh.

         """
         if env_vars is None:
             env_vars = {}

+        is_remote = num_hosts is not None and num_hosts > 0
+
         async with self._lock:
             server_name = None
-            if num_hosts is not None and num_hosts > 0:
-                created_hosts = len(self._server_names)
+            if is_remote:
                 if mesh_name is None:
+                    created_hosts = len(self._server_names)
                     mesh_name = f"alloc_{created_hosts}"
                 if host_mesh is None:
                     host_mesh, server_name = await self.create_host_mesh(
@@ -188,18 +208,22 @@ async def get_proc_mesh(
                 host_id = host_mesh._host_id
                 gpu_manager = self._host_gpu_map[host_id]
             else:
+                # fallback to local
                 host_mesh = this_host()
                 gpu_manager = self._host_gpu_map[self._this_host_id]
                 host_mesh._host_id = self._this_host_id

             def bootstrap(env: dict[str, str]):
+                # bootstrap is run on all processes. We use this
+                # to set environment variables like CUDA etc.
                 import os

                 for k, v in env.items():
                     os.environ[k] = v

             if with_gpus:
-                addr, port = await get_remote_info(host_mesh)
+                if not addr or not port:
+                    addr, port = await get_remote_info(host_mesh)
                 gpu_ids = gpu_manager.get_gpus(num_procs)

                 env_vars["MASTER_ADDR"] = addr
@@ -213,7 +237,9 @@ def bootstrap(env: dict[str, str]):
                 per_host={"gpus": num_procs},
                 bootstrap=functools.partial(bootstrap, env=env_vars),
             )
-            await self.launcher.remote_setup(procs)
+
+            if is_remote:
+                await self.launcher.remote_setup(procs)

             # Tag the proc mesh with additional metadata for our own cleanup later
             if with_gpus:
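The two behavioral pieces in the hunks above, shown in isolation: MASTER_ADDR/MASTER_PORT detection is skipped when the caller supplies addr and port, and the bootstrap callable copies environment variables into every spawned process. This is a standalone sketch of the pattern with a hypothetical detect_addr_port stand-in for get_remote_info, not the provisioner's actual code.

    import functools
    import os


    def detect_addr_port() -> tuple[str, str]:
        # Stand-in for get_remote_info(host_mesh) in the real code.
        return "localhost", "29500"


    def make_bootstrap(addr: str | None, port: str | None, env_vars: dict[str, str]):
        # Only detect an address/port when the caller did not pin one.
        if not addr or not port:
            addr, port = detect_addr_port()
        env_vars["MASTER_ADDR"] = addr
        env_vars["MASTER_PORT"] = port

        def bootstrap(env: dict[str, str]) -> None:
            # Runs in each spawned process: copy the env vars into os.environ.
            for k, v in env.items():
                os.environ[k] = v

        # The provisioner passes something like this to spawn_procs(bootstrap=...).
        return functools.partial(bootstrap, env=env_vars)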
@@ -284,8 +310,24 @@ async def get_proc_mesh(
     process_config: ProcessConfig,
     host_mesh: HostMesh | None = None,
     env_vars: dict[str, str] | None = None,
+    port: str | None = None,
+    addr: str | None = None,
 ) -> ProcMesh:
-    """Returns a proc mesh from the provisioner."""
+    """Returns a proc mesh from the provisioner.
+
+    Args:
+        process_config: The process config.
+        host_mesh: The host mesh to allocate the process on.
+            If None, a new host mesh will be created.
+        port: The distributed port to use.
+            If None, a port will be detected.
+        addr: The distributed address to use.
+            If None, an address will be detected.
+
+    Returns:
+        A proc mesh.
+
+    """
     provisioner = await _get_provisioner()
     return await provisioner.get_proc_mesh(
         num_procs=process_config.procs,
@@ -294,6 +336,8 @@ async def get_proc_mesh(
         mesh_name=process_config.mesh_name,
         host_mesh=host_mesh,
         env_vars=env_vars,
+        port=port,
+        addr=addr,
     )
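A hedged usage sketch of the module-level wrapper above, pinning the rendezvous address and port rather than letting them be detected. Module paths, the ProcessConfig field names, and the chosen values are assumptions based on this diff.

    import asyncio

    from forge.controller.provisioner import get_proc_mesh
    from forge.types import ProcessConfig


    async def main() -> None:
        # hosts stays at its default (None), so this allocates on the local host.
        cfg = ProcessConfig(procs=2, with_gpus=True, mesh_name="demo")
        # addr/port are optional; omitting them falls back to detection as before.
        procs = await get_proc_mesh(cfg, addr="localhost", port="29500")
        print(procs)


    if __name__ == "__main__":
        asyncio.run(main())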

src/forge/types.py

Lines changed: 16 additions & 4 deletions
@@ -95,7 +95,17 @@ class Launcher(Enum):

 @dataclass
 class ProcessConfig:
-    """A proc_mesh config for the torchx scheduler."""
+    """A configuration for allocating Monarch ProcMeshes.
+
+    Args:
+        procs (int): Number of processes to launch for each replica of the service.
+        with_gpus (bool, optional): Whether to allocate GPUs for the service processes.
+        hosts (int | None, optional): Number of hosts to allocate for each replica.
+            If this is set to None, it will use the local host.
+            If this is set to a positive integer, it will run on a remote host.
+        mesh_name (str | None, optional): Name of the mesh to use for the proc_mesh.
+
+    """

     procs: int = 1
     with_gpus: bool = False
@@ -105,13 +115,15 @@ class ProcessConfig:

 @dataclass
 class ServiceConfig:
-    """
-    A service config.
+    """The configuration for a Forge service.
+
     Args:
         procs (int): Number of processes to launch for each replica of the service.
         num_replicas (int): Number of replicas to launch for the service.
         with_gpus (bool, optional): Whether to allocate GPUs for the service processes.
         hosts (int | None, optional): Number of hosts to allocate for each replica.
+            If this is set to None, it will use the local host.
+            If this is set to a positive integer, it will run on a remote host.
         health_poll_rate (float, optional): Frequency (in seconds) to poll for health status.
         replica_max_concurrent_requests (int, optional): Maximum number of concurrent requests per replica.
         return_first_rank_result (bool, optional): Whether to auto-unwrap ValueMesh to the first rank's result.
@@ -121,14 +133,14 @@ class ServiceConfig:
     num_replicas: int
     with_gpus: bool = False
    hosts: int | None = None
-    # ServiceConfig-specific fields
     health_poll_rate: float = 0.2
     replica_max_concurrent_requests: int = 10
     return_first_rank_result: bool = True
     mesh_name: str | None = None

     def to_process_config(self) -> ProcessConfig:
         """Extract ProcessConfig from this ServiceConfig.
+
         Maps procs to procs for ProcessConfig.
         """
         return ProcessConfig(
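As a quick illustration of the documented fields, a small sketch constructing a ServiceConfig and deriving its ProcessConfig; the values are arbitrary and the import path is assumed from src/forge/types.py.

    from forge.types import ProcessConfig, ServiceConfig

    # hosts=None keeps each replica on the local host, per the docstring above.
    service_cfg = ServiceConfig(procs=2, num_replicas=4, with_gpus=False, hosts=None)

    # to_process_config() carries the proc_mesh-relevant fields over.
    proc_cfg: ProcessConfig = service_cfg.to_process_config()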

tests/unit_tests/test_provisioner.py

Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,8 @@ async def test_get_proc_mesh_respects_cuda_visible_devices(self):
             num_procs=2,
             with_gpus=True,
             num_hosts=None,
+            port="12345",
+            addr="localhost",
         )
         # Verify GPUs were allocated from available set
         remaining_available = local_gpu_manager.get_available_gpus()

tests/unit_tests/test_replay_buffer.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 class TestReplayBuffer:
     @pytest_asyncio.fixture
     async def replay_buffer(self) -> ReplayBuffer:
-        replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+        replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=False).as_actor(
             batch_size=2, max_policy_age=1
         )
         await replay_buffer.setup.call()
