Skip to content

Commit 63e7e78

Browse files
committed
provisioner as actor
1 parent d464193 commit 63e7e78

File tree

3 files changed

+54
-34
lines changed

3 files changed

+54
-34
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ async def main(cfg: DictConfig):
346346
# TODO: support multiple host meshes
347347
trainer_num_procs = cfg.actors.trainer["procs"]
348348
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
349-
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
349+
trainer_hosts = provisioner.get_host_mesh.call_one(trainer_host_mesh_name)
350350
await ts.initialize(
351351
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
352352
strategy=ts.LocalRankStrategy(),

src/forge/controller/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
# LICENSE file in the root directory of this source tree.
66
from .actor import ForgeActor
77
from .provisioner import (
8+
get_or_create_provisioner,
89
get_proc_mesh,
910
host_mesh_from_proc,
10-
init_provisioner,
1111
shutdown,
1212
stop_proc_mesh,
1313
)
@@ -16,7 +16,7 @@
1616
"ForgeActor",
1717
"get_proc_mesh",
1818
"stop_proc_mesh",
19-
"init_provisioner",
19+
"get_or_create_provisioner",
2020
"shutdown",
2121
"host_mesh_from_proc",
2222
]

src/forge/controller/provisioner.py

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
from monarch._src.actor.actor_mesh import ActorMesh
1717
from monarch._src.actor.shape import Extent
1818

19-
from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
19+
from monarch.actor import (
20+
Actor,
21+
endpoint,
22+
get_or_spawn_controller,
23+
HostMesh,
24+
ProcMesh,
25+
this_host,
26+
)
2027

2128
from monarch.tools import commands
2229

@@ -95,7 +102,7 @@ def release_gpus(self, gpu_ids: list[str]) -> None:
95102
self.available_gpus.add(int(gpu_id))
96103

97104

98-
class Provisioner:
105+
class Provisioner(Actor):
99106
"""A global resource provisioner."""
100107

101108
def __init__(self, cfg: ProvisionerConfig | None = None):
@@ -138,11 +145,13 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
138145
self._registered_actors: list["ForgeActor"] = []
139146
self._registered_services: list["ServiceInterface"] = []
140147

148+
@endpoint
141149
async def initialize(self):
142150
"""Call this after creating the instance"""
143151
if self.launcher is not None:
144152
await self.launcher.initialize()
145153

154+
@endpoint
146155
async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
147156
"""Creates a remote server and a HostMesh on it."""
148157
# no need to lock here because this is already locked behind `get_proc_mesh`
@@ -172,6 +181,7 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
172181
)
173182
return host_mesh, server_name
174183

184+
@endpoint
175185
def get_host_mesh(self, name: str) -> HostMesh:
176186
"""Returns the host mesh given its associated name.
177187
@@ -181,6 +191,7 @@ def get_host_mesh(self, name: str) -> HostMesh:
181191
"""
182192
return self._host_mesh_map[name]
183193

194+
@endpoint
184195
async def get_proc_mesh(
185196
self,
186197
num_procs: int,
@@ -225,7 +236,7 @@ async def get_proc_mesh(
225236
created_hosts = len(self._server_names)
226237
mesh_name = f"alloc_{created_hosts}"
227238
if host_mesh is None:
228-
host_mesh, server_name = await self.create_host_mesh(
239+
host_mesh, server_name = await self.create_host_mesh.call_one(
229240
name=mesh_name,
230241
num_hosts=num_hosts,
231242
)
@@ -318,13 +329,15 @@ def bootstrap(env: dict[str, str]):
318329
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
319330
return procs
320331

332+
@endpoint
321333
async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
322334
if proc_mesh not in self._proc_host_map:
323335
raise ValueError(
324336
"The proc mesh was not allocated with an associated hostmesh."
325337
)
326338
return self._proc_host_map[proc_mesh]
327339

340+
@endpoint
328341
async def stop_proc_mesh(self, proc_mesh: ProcMesh):
329342
"""Stops a proc mesh."""
330343
if proc_mesh not in self._proc_host_map:
@@ -352,6 +365,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
352365
commands.kill(server_name)
353366
del self._proc_host_map[proc_mesh]
354367

368+
@endpoint
355369
def register_service(self, service: "ServiceInterface") -> None:
356370
"""Registers a service allocation for cleanup."""
357371
# Import ServiceInterface here instead of at top-level to avoid circular import
@@ -364,6 +378,7 @@ def register_service(self, service: "ServiceInterface") -> None:
364378

365379
self._registered_services.append(service)
366380

381+
@endpoint
367382
def register_actor(self, actor: "ForgeActor") -> None:
368383
"""Registers a single actor allocation for cleanup."""
369384

@@ -372,13 +387,15 @@ def register_actor(self, actor: "ForgeActor") -> None:
372387

373388
self._registered_actors.append(actor)
374389

390+
@endpoint
375391
async def shutdown_all_allocations(self):
376392
"""Gracefully shut down all tracked actors and services."""
393+
global _global_registered_services
377394
logger.info(
378-
f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..."
395+
f"Shutting down {len(_global_registered_services)} service(s) and {len(self._registered_actors)} actor(s)..."
379396
)
380397
# --- ServiceInterface ---
381-
for service in reversed(self._registered_services):
398+
for service in reversed(_global_registered_services):
382399
try:
383400
await service.shutdown()
384401

@@ -398,29 +415,30 @@ async def shutdown_all_allocations(self):
398415
self._registered_actors.clear()
399416
self._registered_services.clear()
400417

418+
@endpoint
401419
async def shutdown(self):
402420
"""Tears down all remaining remote allocations."""
403-
await self.shutdown_all_allocations()
421+
await self.shutdown_all_allocations.call_one()
404422
async with self._lock:
405423
for server_name in self._server_names:
406424
commands.kill(server_name)
407425

408426

409-
_provisioner: Provisioner | None = None
410-
427+
_global_provisioner: Provisioner | None = None
428+
_global_registered_services: list["ServiceInterface"] = []
411429

412-
async def init_provisioner(cfg: ProvisionerConfig | None = None):
413-
global _provisioner
414-
if not _provisioner:
415-
_provisioner = Provisioner(cfg)
416-
await _provisioner.initialize()
417-
return _provisioner
418430

419-
420-
async def _get_provisioner():
421-
if not _provisioner:
422-
await init_provisioner()
423-
return _provisioner
431+
async def get_or_create_provisioner(
432+
cfg: ProvisionerConfig | None = None,
433+
) -> Provisioner:
434+
"""Gets or spawns the global Provisioner controller actor."""
435+
global _global_provisioner
436+
if _global_provisioner is None:
437+
_global_provisioner = await get_or_spawn_controller(
438+
"provisioner_controller", Provisioner, cfg
439+
)
440+
await _global_provisioner.initialize.call_one()
441+
return _global_provisioner
424442

425443

426444
async def get_proc_mesh(
@@ -445,8 +463,8 @@ async def get_proc_mesh(
445463
A proc mesh.
446464
447465
"""
448-
provisioner = await _get_provisioner()
449-
return await provisioner.get_proc_mesh(
466+
provisioner = await get_or_create_provisioner()
467+
return await provisioner.get_proc_mesh.call_one(
450468
num_procs=process_config.procs,
451469
with_gpus=process_config.with_gpus,
452470
num_hosts=process_config.hosts,
@@ -465,25 +483,27 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh):
465483
API.
466484
467485
"""
468-
provisioner = await _get_provisioner()
469-
return await provisioner.host_mesh_from_proc(proc_mesh)
486+
provisioner = await get_or_create_provisioner()
487+
return await provisioner.host_mesh_from_proc.call_one(proc_mesh)
470488

471489

472490
async def register_service(service: "ServiceInterface") -> None:
473491
"""Registers a service allocation with the global provisioner."""
474-
provisioner = await _get_provisioner()
475-
provisioner.register_service(service)
492+
493+
# TODO: This is a temporary hack. Change this back once Services are actors
494+
global _global_registered_services
495+
_global_registered_services.append(service)
476496

477497

478498
async def register_actor(actor: "ForgeActor") -> None:
479499
"""Registers an actor allocation with the global provisioner."""
480-
provisioner = await _get_provisioner()
481-
provisioner.register_actor(actor)
500+
provisioner = await get_or_create_provisioner()
501+
provisioner.register_actor.call_one(actor)
482502

483503

484504
async def stop_proc_mesh(proc_mesh: ProcMesh):
485-
provisioner = await _get_provisioner()
486-
return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh)
505+
provisioner = await get_or_create_provisioner()
506+
return await provisioner.stop_proc_mesh.call_one(proc_mesh=proc_mesh)
487507

488508

489509
async def shutdown_metric_logger():
@@ -504,8 +524,8 @@ async def shutdown():
504524

505525
logger.info("Shutting down provisioner..")
506526

507-
provisioner = await _get_provisioner()
508-
result = await provisioner.shutdown()
527+
provisioner = await get_or_create_provisioner()
528+
result = await provisioner.shutdown.call_one()
509529

510530
logger.info("Shutdown completed successfully")
511531
return result

0 commit comments

Comments
 (0)