4444
4545from monarch ._rust_bindings .monarch_hyperactor .shape import Shape , Slice
4646from monarch ._src .actor .actor_mesh import Actor
47- from monarch ._src .actor .proc_mesh import proc_mesh , ProcMesh
47+ from monarch ._src .actor .proc_mesh import ProcMesh
4848from monarch ._src .actor .shape import MeshTrait
4949
50+ from forge .controller .proc_mesh import get_proc_mesh
51+ from forge .types import ProcessConfig
52+
5053T = TypeVar ("T" , bound = Actor )
5154logger : logging .Logger = logging .getLogger (__name__ )
5255logger .setLevel (logging .INFO )
@@ -82,18 +85,19 @@ class RecoverableProcMesh(MeshTrait):
8285 services that need high availability.
8386
8487 Args:
85- num_gpus: Number of GPUs to allocate for the process mesh
88+ proc_config: ProcessConfig containing mesh configuration including num_procs
8689
8790 Attributes:
88- num_gpus : Number of GPUs allocated to this mesh
91+ num_procs : Number of processes allocated to this mesh
8992 state: Current state of the mesh (HEALTHY, RECOVERING, UNHEALTHY, STOPPED)
9093 healthy: True if the mesh is operational and ready for requests
9194 failed: True if the mesh has failed and needs recovery
9295
9396 Example:
9497 Basic usage with automatic recovery:
9598
96- >>> mesh = RecoverableProcMesh(num_gpus=2)
99+ >>> proc_config = ProcessConfig(num_procs=2, scheduler="local")
100+ >>> mesh = RecoverableProcMesh(proc_config)
97101 >>>
98102 >>> async def setup_actor(proc_mesh):
99103 ... actor = await proc_mesh.spawn("MyActor", MyActorClass)
@@ -104,7 +108,8 @@ class RecoverableProcMesh(MeshTrait):
104108
105109 Context manager for automatic cleanup:
106110
107- >>> async with RecoverableProcMesh(num_gpus=1) as mesh:
111+ >>> proc_config = ProcessConfig(num_procs=1)
112+ >>> async with RecoverableProcMesh(proc_config) as mesh:
108113 ... await mesh.spawn(setup_actor)
109114 ... # Use mesh for operations
110115 ... # Mesh automatically stopped and cleaned up on exit
@@ -121,9 +126,10 @@ class RecoverableProcMesh(MeshTrait):
121126
122127 def __init__ (
123128 self ,
124- num_procs : int ,
129+ proc_config : ProcessConfig ,
125130 ) -> None :
126- self .num_procs = num_procs
131+ self ._proc_config : ProcessConfig = proc_config
132+ self .num_procs = proc_config .num_procs
127133 self ._proc_mesh : Optional [ProcMesh ] = None
128134 self ._recovery_task : Optional [asyncio .Task [None ]] = None
129135 self .state : MeshState = MeshState .UNHEALTHY
@@ -185,7 +191,7 @@ async def _recover(
185191 logger .warning (f"Error stopping old ProcMesh: { e } " )
186192
187193 try :
188- self ._proc_mesh = await proc_mesh ( gpus = self .num_procs )
194+ self ._proc_mesh = await get_proc_mesh ( process_config = self ._proc_config )
189195 if self ._proc_mesh is not None :
190196 await hook (self ._proc_mesh )
191197 self .state = MeshState .HEALTHY
0 commit comments