Skip to content

Commit 5d0d7a8

Browse files
allenwang28 and Allen Wang authored
Modifies service.py to use get_proc_mesh, and unifies ProcessConfig with ServiceConfig (#63)
* [service] Removes autoscaling! * Changes the recoverable proc mesh to use get_proc_mesh --------- Co-authored-by: Allen Wang <[email protected]>
1 parent 94dfaa5 commit 5d0d7a8

File tree

3 files changed

+48
-22
lines changed

3 files changed

+48
-22
lines changed

src/forge/controller/recoverable_mesh.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,12 @@
4444

4545
from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice
4646
from monarch._src.actor.actor_mesh import Actor
47-
from monarch._src.actor.proc_mesh import proc_mesh, ProcMesh
47+
from monarch._src.actor.proc_mesh import ProcMesh
4848
from monarch._src.actor.shape import MeshTrait
4949

50+
from forge.controller.proc_mesh import get_proc_mesh
51+
from forge.types import ProcessConfig
52+
5053
T = TypeVar("T", bound=Actor)
5154
logger: logging.Logger = logging.getLogger(__name__)
5255
logger.setLevel(logging.INFO)
@@ -82,18 +85,19 @@ class RecoverableProcMesh(MeshTrait):
8285
services that need high availability.
8386
8487
Args:
85-
num_gpus: Number of GPUs to allocate for the process mesh
88+
proc_config: ProcessConfig containing mesh configuration including num_procs
8689
8790
Attributes:
88-
num_gpus: Number of GPUs allocated to this mesh
91+
num_procs: Number of processes allocated to this mesh
8992
state: Current state of the mesh (HEALTHY, RECOVERING, UNHEALTHY, STOPPED)
9093
healthy: True if the mesh is operational and ready for requests
9194
failed: True if the mesh has failed and needs recovery
9295
9396
Example:
9497
Basic usage with automatic recovery:
9598
96-
>>> mesh = RecoverableProcMesh(num_gpus=2)
99+
>>> proc_config = ProcessConfig(num_procs=2, scheduler="local")
100+
>>> mesh = RecoverableProcMesh(proc_config)
97101
>>>
98102
>>> async def setup_actor(proc_mesh):
99103
... actor = await proc_mesh.spawn("MyActor", MyActorClass)
@@ -104,7 +108,8 @@ class RecoverableProcMesh(MeshTrait):
104108
105109
Context manager for automatic cleanup:
106110
107-
>>> async with RecoverableProcMesh(num_gpus=1) as mesh:
111+
>>> proc_config = ProcessConfig(num_procs=1)
112+
>>> async with RecoverableProcMesh(proc_config) as mesh:
108113
... await mesh.spawn(setup_actor)
109114
... # Use mesh for operations
110115
... # Mesh automatically stopped and cleaned up on exit
@@ -121,9 +126,10 @@ class RecoverableProcMesh(MeshTrait):
121126

122127
def __init__(
123128
self,
124-
num_procs: int,
129+
proc_config: ProcessConfig,
125130
) -> None:
126-
self.num_procs = num_procs
131+
self._proc_config: ProcessConfig = proc_config
132+
self.num_procs = proc_config.num_procs
127133
self._proc_mesh: Optional[ProcMesh] = None
128134
self._recovery_task: Optional[asyncio.Task[None]] = None
129135
self.state: MeshState = MeshState.UNHEALTHY
@@ -185,7 +191,7 @@ async def _recover(
185191
logger.warning(f"Error stopping old ProcMesh: {e}")
186192

187193
try:
188-
self._proc_mesh = await proc_mesh(gpus=self.num_procs)
194+
self._proc_mesh = await get_proc_mesh(process_config=self._proc_config)
189195
if self._proc_mesh is not None:
190196
await hook(self._proc_mesh)
191197
self.state = MeshState.HEALTHY

src/forge/controller/service.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from monarch.actor import ActorError, ProcMesh
4848

4949
from forge.controller import RecoverableProcMesh
50+
from forge.types import ServiceConfig
5051

5152
logger = logging.getLogger(__name__)
5253
logger.setLevel(logging.DEBUG)
@@ -187,17 +188,6 @@ def get_sessions_per_replica(self) -> float:
187188
return self.total_sessions / self.healthy_replicas
188189

189190

190-
@dataclass
191-
class ServiceConfig:
192-
procs_per_replica: int
193-
num_replicas: int
194-
health_poll_rate: float = 0.2
195-
replica_max_concurrent_requests: int = 10
196-
return_first_rank_result: bool = (
197-
True # Auto-unwrap ValueMesh to first rank's result
198-
)
199-
200-
201191
@dataclass
202192
class Replica:
203193
proc_mesh: RecoverableProcMesh
@@ -335,9 +325,7 @@ async def __initialize__(self):
335325
replicas = []
336326
num_replicas = self._cfg.num_replicas
337327
for i in range(num_replicas):
338-
mesh = RecoverableProcMesh(
339-
self._cfg.procs_per_replica,
340-
)
328+
mesh = RecoverableProcMesh(proc_config=self._cfg.to_process_config())
341329
replica = Replica(
342330
proc_mesh=mesh,
343331
actor=None,

src/forge/types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,35 @@ class ProcessConfig:
9898
oncall: str = "torchtune"
9999
identity: str = "pytorch_distributed"
100100
image: str = "forge_workspace:latest"
101+
102+
103+
@dataclass
104+
class ServiceConfig:
105+
"""A service config."""
106+
107+
procs_per_replica: int
108+
num_replicas: int
109+
num_hosts: int = 1
110+
scheduler: Literal["mast", "local"] = "local"
111+
oncall: str = "torchtune"
112+
identity: str = "pytorch_distributed"
113+
image: str = "forge_workspace:latest"
114+
# ServiceConfig-specific fields
115+
health_poll_rate: float = 0.2
116+
replica_max_concurrent_requests: int = 10
117+
return_first_rank_result: bool = (
118+
True # Whether or not to auto-unwrap ValueMesh to first rank's result
119+
)
120+
121+
def to_process_config(self) -> ProcessConfig:
122+
"""Extract ProcessConfig from this ServiceConfig.
123+
Maps procs_per_replica to num_procs for ProcessConfig.
124+
"""
125+
return ProcessConfig(
126+
scheduler=self.scheduler,
127+
num_procs=self.procs_per_replica,
128+
num_hosts=self.num_hosts,
129+
oncall=self.oncall,
130+
identity=self.identity,
131+
image=self.image,
132+
)

0 commit comments

Comments
 (0)