1616from monarch ._src .actor .actor_mesh import ActorMesh
1717from monarch ._src .actor .shape import Extent
1818
19- from monarch .actor import Actor , endpoint , HostMesh , ProcMesh , this_host
19+ from monarch .actor import (
20+ Actor ,
21+ endpoint ,
22+ get_or_spawn_controller ,
23+ HostMesh ,
24+ ProcMesh ,
25+ this_host ,
26+ )
2027
2128from monarch .tools import commands
2229
@@ -95,7 +102,7 @@ def release_gpus(self, gpu_ids: list[str]) -> None:
95102 self .available_gpus .add (int (gpu_id ))
96103
97104
98- class Provisioner :
105+ class Provisioner ( Actor ) :
99106 """A global resource provisioner."""
100107
101108 def __init__ (self , cfg : ProvisionerConfig | None = None ):
@@ -138,11 +145,13 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
138145 self ._registered_actors : list ["ForgeActor" ] = []
139146 self ._registered_services : list ["ServiceInterface" ] = []
140147
148+ @endpoint
141149 async def initialize (self ):
142150 """Call this after creating the instance"""
143151 if self .launcher is not None :
144152 await self .launcher .initialize ()
145153
154+ @endpoint
146155 async def create_host_mesh (self , name : str , num_hosts : int ) -> HostMesh :
147156 """Creates a remote server and a HostMesh on it."""
148157 # no need to lock here because this is already locked behind `get_proc_mesh`
@@ -172,6 +181,7 @@ async def create_host_mesh(self, name: str, num_hosts: int) -> HostMesh:
172181 )
173182 return host_mesh , server_name
174183
184+ @endpoint
175185 def get_host_mesh (self , name : str ) -> HostMesh :
176186 """Returns the host mesh given its associated name.
177187
@@ -181,6 +191,7 @@ def get_host_mesh(self, name: str) -> HostMesh:
181191 """
182192 return self ._host_mesh_map [name ]
183193
194+ @endpoint
184195 async def get_proc_mesh (
185196 self ,
186197 num_procs : int ,
@@ -225,7 +236,7 @@ async def get_proc_mesh(
225236 created_hosts = len (self ._server_names )
226237 mesh_name = f"alloc_{ created_hosts } "
227238 if host_mesh is None :
228- host_mesh , server_name = await self .create_host_mesh (
239+ host_mesh , server_name = await self .create_host_mesh . call_one (
229240 name = mesh_name ,
230241 num_hosts = num_hosts ,
231242 )
@@ -318,13 +329,15 @@ def bootstrap(env: dict[str, str]):
318329 _ = await get_or_create_metric_logger (procs , process_name = mesh_name )
319330 return procs
320331
332+ @endpoint
321333 async def host_mesh_from_proc (self , proc_mesh : ProcMesh ):
322334 if proc_mesh not in self ._proc_host_map :
323335 raise ValueError (
324336 "The proc mesh was not allocated with an associated hostmesh."
325337 )
326338 return self ._proc_host_map [proc_mesh ]
327339
340+ @endpoint
328341 async def stop_proc_mesh (self , proc_mesh : ProcMesh ):
329342 """Stops a proc mesh."""
330343 if proc_mesh not in self ._proc_host_map :
@@ -352,6 +365,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
352365 commands .kill (server_name )
353366 del self ._proc_host_map [proc_mesh ]
354367
368+ @endpoint
355369 def register_service (self , service : "ServiceInterface" ) -> None :
356370 """Registers a service allocation for cleanup."""
357371 # Import ServiceInterface here instead of at top-level to avoid circular import
@@ -364,6 +378,7 @@ def register_service(self, service: "ServiceInterface") -> None:
364378
365379 self ._registered_services .append (service )
366380
381+ @endpoint
367382 def register_actor (self , actor : "ForgeActor" ) -> None :
368383 """Registers a single actor allocation for cleanup."""
369384
@@ -372,13 +387,15 @@ def register_actor(self, actor: "ForgeActor") -> None:
372387
373388 self ._registered_actors .append (actor )
374389
390+ @endpoint
375391 async def shutdown_all_allocations (self ):
376392 """Gracefully shut down all tracked actors and services."""
393+ global _global_registered_services
377394 logger .info (
378- f"Shutting down { len (self . _registered_services )} service(s) and { len (self ._registered_actors )} actor(s)..."
395+ f"Shutting down { len (_global_registered_services )} service(s) and { len (self ._registered_actors )} actor(s)..."
379396 )
380397 # --- ServiceInterface ---
381- for service in reversed (self . _registered_services ):
398+ for service in reversed (_global_registered_services ):
382399 try :
383400 await service .shutdown ()
384401
@@ -398,29 +415,30 @@ async def shutdown_all_allocations(self):
398415 self ._registered_actors .clear ()
399416 self ._registered_services .clear ()
400417
418+ @endpoint
401419 async def shutdown (self ):
402420 """Tears down all remaining remote allocations."""
403- await self .shutdown_all_allocations ()
421+ await self .shutdown_all_allocations . call_one ()
404422 async with self ._lock :
405423 for server_name in self ._server_names :
406424 commands .kill (server_name )
407425
408426
409- _provisioner : Provisioner | None = None
410-
427+ _global_provisioner : Provisioner | None = None
428+ _global_registered_services : list [ "ServiceInterface" ] = []
411429
412- async def init_provisioner (cfg : ProvisionerConfig | None = None ):
413- global _provisioner
414- if not _provisioner :
415- _provisioner = Provisioner (cfg )
416- await _provisioner .initialize ()
417- return _provisioner
418430
419-
420- async def _get_provisioner ():
421- if not _provisioner :
422- await init_provisioner ()
423- return _provisioner
431+ async def get_or_create_provisioner (
432+ cfg : ProvisionerConfig | None = None ,
433+ ) -> Provisioner :
434+ """Gets or spawns the global Provisioner controller actor."""
435+ global _global_provisioner
436+ if _global_provisioner is None :
437+ _global_provisioner = await get_or_spawn_controller (
438+ "provisioner_controller" , Provisioner , cfg
439+ )
440+ await _global_provisioner .initialize .call_one ()
441+ return _global_provisioner
424442
425443
426444async def get_proc_mesh (
@@ -445,8 +463,8 @@ async def get_proc_mesh(
445463 A proc mesh.
446464
447465 """
448- provisioner = await _get_provisioner ()
449- return await provisioner .get_proc_mesh (
466+ provisioner = await get_or_create_provisioner ()
467+ return await provisioner .get_proc_mesh . call_one (
450468 num_procs = process_config .procs ,
451469 with_gpus = process_config .with_gpus ,
452470 num_hosts = process_config .hosts ,
@@ -465,25 +483,27 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh):
465483 API.
466484
467485 """
468- provisioner = await _get_provisioner ()
469- return await provisioner .host_mesh_from_proc (proc_mesh )
486+ provisioner = await get_or_create_provisioner ()
487+ return await provisioner .host_mesh_from_proc . call_one (proc_mesh )
470488
471489
472490async def register_service (service : "ServiceInterface" ) -> None :
473491 """Registers a service allocation with the global provisioner."""
474- provisioner = await _get_provisioner ()
475- provisioner .register_service (service )
492+
493+ # TODO: This is a temporary hack. Change this back once Services are actors
494+ global _global_registered_services
495+ _global_registered_services .append (service )
476496
477497
478498async def register_actor (actor : "ForgeActor" ) -> None :
479499 """Registers an actor allocation with the global provisioner."""
480- provisioner = await _get_provisioner ()
481- provisioner .register_actor (actor )
500+ provisioner = await get_or_create_provisioner ()
501+ provisioner .register_actor . call_one (actor )
482502
483503
484504async def stop_proc_mesh (proc_mesh : ProcMesh ):
485- provisioner = await _get_provisioner ()
486- return await provisioner .stop_proc_mesh (proc_mesh = proc_mesh )
505+ provisioner = await get_or_create_provisioner ()
506+ return await provisioner .stop_proc_mesh . call_one (proc_mesh = proc_mesh )
487507
488508
489509async def shutdown_metric_logger ():
@@ -504,8 +524,8 @@ async def shutdown():
504524
505525 logger .info ("Shutting down provisioner.." )
506526
507- provisioner = await _get_provisioner ()
508- result = await provisioner .shutdown ()
527+ provisioner = await get_or_create_provisioner ()
528+ result = await provisioner .shutdown . call_one ()
509529
510530 logger .info ("Shutdown completed successfully" )
511531 return result
0 commit comments