Skip to content

Commit dd2cfa8

Browse files
author
Allen Wang
committed
initial commit for replica
1 parent 5d0d7a8 commit dd2cfa8

File tree

4 files changed

+1110
-2
lines changed

4 files changed

+1110
-2
lines changed

src/forge/controller/replica.py

Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
"""Replica for distributed actor service."""
7+
8+
import asyncio
9+
import logging
10+
from dataclasses import dataclass, field
11+
from enum import Enum
12+
from typing import Optional
13+
14+
from monarch.actor import Actor, ActorError
15+
16+
from forge.controller import RecoverableProcMesh
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class ReplicaState(Enum):
22+
HEALTHY = "healthy"
23+
RECOVERING = "recovering"
24+
UNHEALTHY = "unhealthy"
25+
STOPPED = "stopped"
26+
UNINITIALIZED = "uninitialized"
27+
28+
29+
@dataclass
30+
class ServiceRequest:
31+
session_id: Optional[str]
32+
function: str
33+
args: tuple
34+
kwargs: dict
35+
future: asyncio.Future
36+
37+
38+
@dataclass
39+
class Replica:
40+
proc_mesh: RecoverableProcMesh
41+
actor: Optional[Actor]
42+
idx: int
43+
request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue)
44+
active_requests: int = 0
45+
max_concurrent_requests: int = 10
46+
_running: bool = False
47+
metadata: dict = field(default_factory=dict)
48+
state: ReplicaState = ReplicaState.UNINITIALIZED
49+
return_first_rank_result: bool = False
50+
51+
async def enqueue_request(self, request: ServiceRequest):
52+
"""Enqueues a request for processing by this replica."""
53+
if self.state == ReplicaState.STOPPED:
54+
raise RuntimeError(f"Replica {self.idx} is stopped")
55+
56+
# Accept requests in all other states - let the processing loop handle the rest
57+
await self.request_queue.put(request)
58+
59+
async def _process_single_request(self, request: ServiceRequest) -> bool:
60+
"""
61+
Processes a single request and returns success status.
62+
63+
Returns:
64+
bool: True if request succeeded, False if it failed
65+
"""
66+
self.active_requests += 1
67+
68+
try:
69+
# Get the actor and endpoint
70+
actor = self.actor
71+
endpoint_func = getattr(actor, request.function)
72+
73+
# Execute the request
74+
success = True
75+
try:
76+
result = await endpoint_func.call(*request.args, **request.kwargs)
77+
# Unwrap ValueMesh if configured to return first rank result
78+
if (
79+
self.return_first_rank_result
80+
and hasattr(result, "_values")
81+
and result._values
82+
):
83+
result = result._values[0]
84+
request.future.set_result(result)
85+
except ActorError as e:
86+
logger.debug("Got failure on replica %d. Error:\n%s", self.idx, e)
87+
# Mark proc_mesh as failed and transition state
88+
self.proc_mesh.mark_failed()
89+
self.state = ReplicaState.RECOVERING
90+
# Unwrap the ActorError into its raw exception
91+
request.future.set_exception(e.exception)
92+
success = False
93+
except Exception as e:
94+
logger.debug(
95+
"Got unexpected error on replica %d. Error:\n%s", self.idx, e
96+
)
97+
# Mark proc_mesh as failed and transition state
98+
self.proc_mesh.mark_failed()
99+
self.state = ReplicaState.RECOVERING
100+
request.future.set_exception(e)
101+
success = False
102+
103+
# Mark task as done
104+
self.request_queue.task_done()
105+
return success
106+
107+
finally:
108+
self.active_requests -= 1
109+
110+
async def run(self):
111+
"""
112+
Main processing loop for the replica. This replaces _persistent_processor.
113+
114+
Continuously processes requests from the queue while the replica is healthy.
115+
Handles capacity management and graceful degradation on failures.
116+
"""
117+
self._running = True
118+
119+
try:
120+
while self.state in (ReplicaState.HEALTHY, ReplicaState.RECOVERING):
121+
try:
122+
# Wait for a request with timeout to check health periodically
123+
request = await asyncio.wait_for(
124+
self.request_queue.get(), timeout=1.0
125+
)
126+
127+
# Check if we have capacity
128+
if self.active_requests >= self.max_concurrent_requests:
129+
# Put the request back and wait
130+
await self.request_queue.put(request)
131+
await asyncio.sleep(0.1)
132+
continue
133+
134+
# Update state if proc_mesh recovered
135+
if self.state == ReplicaState.RECOVERING and self.proc_mesh.healthy:
136+
self.state = ReplicaState.HEALTHY
137+
logger.debug("Replica %d recovered to healthy state", self.idx)
138+
139+
# If we're still recovering and proc_mesh isn't healthy, reject request
140+
if (
141+
self.state == ReplicaState.RECOVERING
142+
and not self.proc_mesh.healthy
143+
):
144+
request.future.set_exception(
145+
RuntimeError(f"Replica {self.idx} is still recovering")
146+
)
147+
self.request_queue.task_done()
148+
continue
149+
150+
# Process the request
151+
asyncio.create_task(self._process_single_request(request))
152+
153+
except asyncio.TimeoutError:
154+
# No requests, check for health state changes
155+
if self.state == ReplicaState.RECOVERING and self.proc_mesh.healthy:
156+
self.state = ReplicaState.HEALTHY
157+
logger.debug("Replica %d recovered to healthy state", self.idx)
158+
elif (
159+
self.state == ReplicaState.HEALTHY
160+
and not self.proc_mesh.healthy
161+
):
162+
self.state = ReplicaState.RECOVERING
163+
logger.debug("Replica %d entered recovering state", self.idx)
164+
continue
165+
166+
except Exception as e:
167+
logger.error(
168+
"Error in replica %d processing loop: %s",
169+
self.idx,
170+
e,
171+
)
172+
self.state = ReplicaState.UNHEALTHY
173+
break
174+
175+
finally:
176+
self._running = False
177+
logger.debug("Replica %d stopped processing", self.idx)
178+
179+
@property
180+
def healthy(self) -> bool:
181+
return self.state == ReplicaState.HEALTHY
182+
183+
@property
184+
def load(self) -> int:
185+
"""Get current load (active requests + queue depth)"""
186+
return self.active_requests + self.request_queue.qsize()
187+
188+
@property
189+
def capacity_utilization(self) -> float:
190+
"""Get current capacity utilization (0.0 to 1.0)"""
191+
if self.max_concurrent_requests <= 0:
192+
return 0.0
193+
return self.active_requests / self.max_concurrent_requests
194+
195+
def can_accept_request(self) -> bool:
196+
"""Check if replica can accept a new request"""
197+
return (
198+
self.state == ReplicaState.HEALTHY
199+
and self.active_requests < self.max_concurrent_requests
200+
)
201+
202+
def __repr__(self) -> str:
203+
return (
204+
f"Replica(idx={self.idx}, state={self.state.value}, "
205+
f"active={self.active_requests}/{self.max_concurrent_requests}, "
206+
f"queue={self.request_queue.qsize()})"
207+
)
208+
209+
async def setup(self):
210+
"""
211+
Sets up the replica and transitions to healthy state.
212+
213+
This should be called after the proc_mesh has been initialized
214+
and the actor has been spawned on it.
215+
"""
216+
if self.state != ReplicaState.UNINITIALIZED:
217+
logger.warning(
218+
"Attempting to setup replica %d that's already initialized", self.idx
219+
)
220+
return
221+
222+
if self.actor is None:
223+
raise RuntimeError(f"Cannot setup replica {self.idx}: actor is None")
224+
225+
try:
226+
# Call actor setup if it exists
227+
if hasattr(self.actor, "setup"):
228+
await self.actor.setup.call()
229+
230+
# Transition to healthy state
231+
self.state = ReplicaState.HEALTHY
232+
logger.debug("Replica %d setup complete", self.idx)
233+
234+
except Exception as e:
235+
logger.error("Failed to setup replica %d: %s", self.idx, e)
236+
self.state = ReplicaState.UNHEALTHY
237+
raise
238+
239+
async def stop(self):
240+
"""
241+
Stops the replica gracefully.
242+
243+
Transitions to STOPPED state, stops the processing loop, and cleans up.
244+
Fails any remaining requests in the queue.
245+
"""
246+
logger.debug("Stopping replica %d", self.idx)
247+
248+
# Transition to stopped state to signal the run loop to exit
249+
self.state = ReplicaState.STOPPED
250+
251+
# Wait for processor to finish if it's running
252+
if self._running:
253+
# Give it a moment to finish current request and exit gracefully
254+
for _ in range(50): # Wait up to 5 seconds
255+
if not self._running:
256+
break
257+
await asyncio.sleep(0.1)
258+
259+
if self._running:
260+
logger.warning("Replica %d processor didn't stop gracefully", self.idx)
261+
262+
# Fail any remaining requests in the queue
263+
failed_requests = []
264+
while not self.request_queue.empty():
265+
try:
266+
request = self.request_queue.get_nowait()
267+
failed_requests.append(request)
268+
self.request_queue.task_done()
269+
except asyncio.QueueEmpty:
270+
break
271+
272+
# Fail all the collected requests
273+
for request in failed_requests:
274+
if not request.future.done():
275+
request.future.set_exception(
276+
RuntimeError(f"Replica {self.idx} is stopping")
277+
)
278+
279+
logger.debug(
280+
"Replica %d stopped, failed %d remaining requests",
281+
self.idx,
282+
len(failed_requests),
283+
)
284+
285+
# Stop the proc_mesh
286+
try:
287+
await self.proc_mesh.stop()
288+
except Exception as e:
289+
logger.warning("Error stopping proc_mesh for replica %d: %s", self.idx, e)

0 commit comments

Comments
 (0)