diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index a191d931a..a800eb14c 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. from .actor import ForgeActor from .proc_mesh import get_proc_mesh, spawn_actors -from .recoverable_mesh import RecoverableProcMesh from .service import Service, ServiceConfig from .spawn import spawn_service @@ -16,5 +15,4 @@ "spawn_actors", "get_proc_mesh", "ForgeActor", - "RecoverableProcMesh", ] diff --git a/src/forge/controller/recoverable_mesh.py b/src/forge/controller/recoverable_mesh.py deleted file mode 100644 index d352eab17..000000000 --- a/src/forge/controller/recoverable_mesh.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Recoverable Process Mesh - -This module provides a fault-tolerant wrapper around ProcMesh that automatically -recovers from crashes and failures. The RecoverableProcMesh class maintains the -same API as ProcMesh while adding automatic recovery capabilities. - -Key Features: -- **Automatic Recovery**: Detects mesh failures and automatically respawns processes -- **State Management**: Tracks mesh health and recovery status -- **Graceful Degradation**: Handles failures without losing the entire service -- **Context Management**: Supports async context manager for resource cleanup -- **Actor Respawning**: Automatically respawns actors after mesh recovery - -Example: - Basic usage with automatic recovery: - - >>> mesh = RecoverableProcMesh(num_gpus=2) - >>> - >>> async def spawn_actor(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass, *args) - ... return actor - >>> - >>> await mesh.spawn(spawn_actor) - >>> # Mesh will automatically recover if it fails - - Context manager usage: - - >>> async with RecoverableProcMesh(num_gpus=1) as mesh: - ... await mesh.spawn(spawn_actor) - ... # Mesh automatically cleaned up on exit -""" - -import asyncio -import logging -from enum import Enum -from typing import Any, Callable, Coroutine, Optional, TypeVar - -from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice -from monarch._src.actor.actor_mesh import Actor -from monarch._src.actor.proc_mesh import ProcMesh -from monarch._src.actor.shape import MeshTrait - -from forge.controller.proc_mesh import get_proc_mesh -from forge.types import ProcessConfig - -T = TypeVar("T", bound=Actor) -logger: logging.Logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class MeshState(Enum): - """ - Enumeration of possible mesh states for tracking recovery status. - - States: - HEALTHY: Mesh is operational and ready to handle requests - RECOVERING: Mesh is in the process of recovering from a failure - UNHEALTHY: Mesh has failed and needs recovery - STOPPED: Mesh has been explicitly stopped and cannot be used - """ - - HEALTHY = 0 - RECOVERING = 1 - UNHEALTHY = 2 - STOPPED = 3 - - -class RecoverableProcMesh(MeshTrait): - """ - A fault-tolerant wrapper around ProcMesh with automatic crash recovery. - - This class provides the same API as ProcMesh while adding robust failure detection - and automatic recovery capabilities. When the underlying mesh crashes or becomes - unresponsive, it automatically creates a new mesh and respawns all actors. 
- - The RecoverableProcMesh maintains state tracking to ensure proper recovery sequencing - and prevents resource leaks during failure scenarios. It's designed for long-running - services that need high availability. - - Args: - proc_config: ProcessConfig containing mesh configuration including num_procs - - Attributes: - num_procs: Number of processes allocated to this mesh - state: Current state of the mesh (HEALTHY, RECOVERING, UNHEALTHY, STOPPED) - healthy: True if the mesh is operational and ready for requests - failed: True if the mesh has failed and needs recovery - - Example: - Basic usage with automatic recovery: - - >>> proc_config = ProcessConfig(num_procs=2, scheduler="local") - >>> mesh = RecoverableProcMesh(proc_config) - >>> - >>> async def setup_actor(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass) - ... await actor.initialize.call() - >>> - >>> await mesh.spawn(setup_actor) - >>> # If mesh fails, it will automatically recover and re-run setup_actor - - Context manager for automatic cleanup: - - >>> proc_config = ProcessConfig(num_procs=1) - >>> async with RecoverableProcMesh(proc_config) as mesh: - ... await mesh.spawn(setup_actor) - ... # Use mesh for operations - ... # Mesh automatically stopped and cleaned up on exit - - Manual state checking: - - >>> if mesh.healthy: - ... # Safe to use mesh - ... pass - >>> elif mesh.failed: - ... # Mesh needs recovery - ... await mesh.spawn(setup_actor) # Triggers recovery - """ - - def __init__( - self, - proc_config: ProcessConfig, - ) -> None: - self._proc_config: ProcessConfig = proc_config - self.num_procs = proc_config.num_procs - self._proc_mesh: Optional[ProcMesh] = None - self._recovery_task: Optional[asyncio.Task[None]] = None - self.state: MeshState = MeshState.UNHEALTHY - - async def spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - """ - Spawn actors on the mesh with automatic recovery. - - This method ensures the mesh is healthy before spawning actors. If the mesh - has failed, it automatically triggers recovery and then executes the spawn hook. - The hook function receives the underlying ProcMesh and should handle actor - creation and initialization. - - Args: - hook: Async function that receives a ProcMesh and spawns/initializes actors - - Example: - >>> async def setup_actors(proc_mesh): - ... actor = await proc_mesh.spawn("MyActor", MyActorClass) - ... 
await actor.setup.call() - >>> - >>> await mesh.spawn(setup_actors) - """ - await self._background_spawn(hook) - - def trigger_spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - self._background_spawn(hook) - - def _background_spawn( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> asyncio.Task[None]: - if self.state == MeshState.STOPPED: - logger.warning("ProcMesh was already stopped when trying to spawn") - - self.state = MeshState.RECOVERING - self._recovery_task = asyncio.create_task(self._recover(hook)) - - return self._recovery_task - - def gpus(self) -> int: - return self.num_procs - - async def _recover( - self, hook: Callable[[ProcMesh], Coroutine[Any, Any, None]] - ) -> None: - self.state = MeshState.RECOVERING - - old_proc_mesh = self._proc_mesh - self._proc_mesh = None - - if old_proc_mesh is not None: - try: - await old_proc_mesh.stop() - except Exception as e: - logger.warning(f"Error stopping old ProcMesh: {e}") - - try: - self._proc_mesh = await get_proc_mesh(process_config=self._proc_config) - if self._proc_mesh is not None: - await hook(self._proc_mesh) - self.state = MeshState.HEALTHY - - except Exception as e: - logger.exception(f"Recovery attempt failed: {e}") - self.state = MeshState.UNHEALTHY - - @property - def healthy(self) -> bool: - return self.state == MeshState.HEALTHY - - @property - def failed(self) -> bool: - return self.state == MeshState.UNHEALTHY - - async def stop(self) -> None: - """ - Stop the mesh and clean up all resources. - - Gracefully shuts down the underlying ProcMesh and marks this recoverable - mesh as stopped. Once stopped, the mesh cannot be used for further operations. - - This method is idempotent - calling it multiple times is safe. - - Example: - >>> await mesh.stop() - >>> # Mesh is now stopped and cannot be used - """ - logger.info("Stopping RecoverableProcMesh") - if self.state == MeshState.STOPPED: - logger.info("RecoverableProcMesh was already stopped") - return - try: - if self._proc_mesh is not None: - await self._proc_mesh.stop() - except RuntimeError as e: - logger.warning("RecoverableProcMesh could not be stopped: %s", e) - - self.state = MeshState.STOPPED - - async def __aenter__(self) -> "RecoverableProcMesh": - """Enter the async context manager.""" - if self.state == MeshState.STOPPED: - raise RuntimeError("RecoverableProcMesh has already been stopped") - return self - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - """Exit the async context manager.""" - # In case there are multiple nested "async with" statements, we only - # want it to close once. - if self.state != MeshState.STOPPED: - await self.stop() - - def mark_failed(self): - """ - Mark the mesh as failed, triggering recovery on next spawn. - - This method is typically called when an operation on the mesh fails - or when external monitoring detects that the mesh is unresponsive. - The next call to spawn() will trigger automatic recovery. - - Example: - >>> try: - ... # Some operation that might fail - ... await actor.some_method.call() - >>> except Exception: - ... 
mesh.mark_failed() # Mark for recovery - """ - self.state = MeshState.UNHEALTHY - - @property - def _shape(self) -> Shape: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._shape - - @property - def _ndslice(self) -> Slice: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._ndslice - - @property - def _labels(self) -> list[str]: - if self._proc_mesh is None: - raise RuntimeError("ProcMesh not initialized") - return self._proc_mesh._labels - - def _new_with_shape(self, shape: Shape) -> "RecoverableProcMesh": - raise NotImplementedError( - "RecoverableProcMesh does not support _new_with_shape" - ) diff --git a/src/forge/controller/replica.py b/src/forge/controller/replica.py new file mode 100644 index 000000000..fec84d0e3 --- /dev/null +++ b/src/forge/controller/replica.py @@ -0,0 +1,505 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +"""Replica for distributed actor service.""" + +import asyncio +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +from monarch.actor import Actor, ActorError, ProcMesh + +from forge.controller import get_proc_mesh +from forge.types import ProcessConfig + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class ReplicaState(Enum): + HEALTHY = "HEALTHY" + RECOVERING = "RECOVERING" + UNHEALTHY = "UNHEALTHY" + STOPPED = "STOPPED" + UNINITIALIZED = "UNINITIALIZED" + + +@dataclass +class ReplicaMetrics: + """Simple metrics tracking for a replica.""" + + total_requests: int = 0 + successful_requests: int = 0 + failed_requests: int = 0 + request_times: deque = field(default_factory=lambda: deque(maxlen=100)) + request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) + + def add_request_start(self, timestamp: float): + """Records when a request starts processing.""" + self.request_times.append(timestamp) + self.total_requests += 1 + + def add_request_completion(self, start_time: float, success: bool): + """Records when a request completes.""" + latency = time.time() - start_time + self.request_latencies.append(latency) + if success: + self.successful_requests += 1 + else: + self.failed_requests += 1 + + def get_request_rate(self, window_seconds: float = 60.0) -> float: + """Gets requests per second over the last window_seconds.""" + now = time.time() + cutoff = now - window_seconds + recent_requests = [t for t in self.request_times if t >= cutoff] + return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 + + def get_avg_latency(self, window_requests: int = 50) -> float: + """Gets average latency over the last N requests.""" + if not self.request_latencies: + return 0.0 + recent_latencies = list(self.request_latencies)[-window_requests:] + return sum(recent_latencies) / len(recent_latencies) + + +@dataclass +class ServiceRequest: + """Representation of a request to the service. + + A service request will typically be a call to an actor endpoint. + - The endpoint call is represented by function str/args/kwargs, + - The session_id is used for stateful routing, and + - The future is used to return the result of the call. 
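+
+    Example (an illustrative sketch, not part of the public API; "replica"
+    stands in for a healthy Replica and "value" for any actor endpoint):
+
+        >>> request = ServiceRequest(
+        ...     session_id=None,
+        ...     function="value",
+        ...     args=(),
+        ...     kwargs={},
+        ...     future=asyncio.Future(),
+        ... )
+        >>> await replica.enqueue_request(request)
+        >>> result = await request.future
+
+    In practice, Service._call builds these objects internally.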
+ + """ + + session_id: Optional[str] + function: str + args: tuple + kwargs: dict + future: asyncio.Future + + +@dataclass +class Replica: + """ + A distributed replica that serves as the fundamental unit of work within a service. + + Handles process lifecycle, async request queuing and fault recovery. + Each replica runs independently and can be deployed across multiple hosts via Monarch + + """ + + idx: int + + # Configuration for the underlying ProcMesh (scheduler, hosts, GPUs) + proc_config: ProcessConfig + + # The proc_mesh and actor_mesh that this replica is running + proc_mesh: Optional[ProcMesh] = None + actor: Optional[Actor] = None + + # Async queue for incoming requests + request_queue: asyncio.Queue[ServiceRequest] = field(default_factory=asyncio.Queue) + # Number of currently processing requests + active_requests: int = 0 + # Maximum number of simultaneous requests + max_concurrent_requests: int = 10 + # Whether the processing loop is currently running + _running: bool = False + # How often to check for new requests when idle + _run_poll_rate_s: float = 1.0 + # Current replica health state + state: ReplicaState = ReplicaState.UNINITIALIZED + # Whether to auto-unwrap ValueMesh to first rank + return_first_rank_result: bool = False + + # Recovery-related state + _recovery_task: Optional[asyncio.Task] = None + + # Run task is the replica's event loop + _run_task: Optional[asyncio.Task] = None + + # Metrics tracking + metrics: ReplicaMetrics = field(default_factory=ReplicaMetrics) + + # Initialization related functionalities + + async def init_proc_mesh(self): + """Initializes the proc_mesh using the stored proc_config.""" + # TODO - for policy replica, we would override this method to + # include multiple proc_meshes + if self.proc_mesh is not None: + logger.warning("Proc mesh already initialized for replica %d", self.idx) + return + + logger.debug("Initializing proc_mesh for replica %d", self.idx) + try: + self.proc_mesh = await get_proc_mesh(process_config=self.proc_config) + logger.debug("Proc mesh initialized successfully for replica %d", self.idx) + except Exception as e: + logger.error( + "Failed to initialize proc_mesh for replica %d: %s", self.idx, e + ) + self.state = ReplicaState.UNHEALTHY + raise + + async def spawn_actor(self, actor_def, *actor_args, **actor_kwargs): + """ + Spawn an actor on this replica's proc_mesh. + + This method handles the complete actor spawning process including + recovery if the proc_mesh has failed. + """ + # Ensure we have a healthy proc_mesh + await self._ensure_healthy_proc_mesh() + + if not self.proc_mesh: + raise RuntimeError( + f"Replica {self.idx}: proc_mesh is None after recovery attempt" + ) + + try: + # Determine actor name + if "name" in actor_kwargs: + actor_name = actor_kwargs.pop("name") + else: + actor_name = actor_def.__name__ + + # Spawn the actor + self.actor = await self.proc_mesh.spawn( + actor_name, + actor_def, + *actor_args, + **actor_kwargs, + ) + + # Call setup if it exists + await self.setup() + + logger.debug("Actor spawned successfully on replica %d", self.idx) + + except Exception as e: + logger.error("Failed to spawn actor on replica %d: %s", self.idx, e) + self.mark_failed() + raise + + async def setup(self): + """ + Sets up the replica and transitions to healthy state. + + This should be called after the proc_mesh has been initialized + and the actor has been spawned on it. 
+ """ + if self.state != ReplicaState.UNINITIALIZED: + logger.warning( + "Attempting to setup replica %d that's already initialized", self.idx + ) + return + + if self.actor is None: + raise RuntimeError(f"Cannot setup replica {self.idx}: actor is None") + + try: + # Call actor setup if it exists + if hasattr(self.actor, "setup"): + # TODO - should this be a standard in our Forge Actor(s)? + await self.actor.setup.call() + + # Transition to healthy state and start processing + self.state = ReplicaState.HEALTHY + self.start_processing() + logger.debug("Replica %d setup complete", self.idx) + + except Exception as e: + logger.error("Failed to setup replica %d: %s", self.idx, e) + self.state = ReplicaState.UNHEALTHY + raise + + # Request handling / processing related functionality + + def start_processing(self): + """Start the replica's processing loop if not already running.""" + if self._run_task is None or self._run_task.done(): + self._run_task = asyncio.create_task(self.run()) + logger.debug("Started processing loop for replica %d", self.idx) + + async def enqueue_request(self, request: ServiceRequest): + """Enqueues a request for processing by this replica.""" + if self.state == ReplicaState.STOPPED: + raise RuntimeError( + f"Replica {self.idx} is stopped and therefore will not accept requests." + ) + + # Accept requests in all other states - let the processing loop handle the rest + await self.request_queue.put(request) + + async def _process_single_request(self, request: ServiceRequest) -> bool: + """Processes a single request and returns success status. + + Returns: + bool: True if request succeeded, False if it failed + """ + start_time = time.time() + self.active_requests += 1 + + # Record request start for metrics + self.metrics.add_request_start(start_time) + + try: + # Get the actor and endpoint + actor = self.actor + endpoint_func = getattr(actor, request.function) + + # Execute the request + success = True + try: + result = await endpoint_func.call(*request.args, **request.kwargs) + # Unwrap ValueMesh if configured to return first rank result + if ( + self.return_first_rank_result + and hasattr(result, "_values") + and result._values + ): + result = result._values[0] + request.future.set_result(result) + except ActorError as e: + logger.warning("Got failure on replica %d. Error:\n%s", self.idx, e) + # The exception came from the actor. It itself is + # returned to be propagated through the services + # back to the caller. + request.future.set_result(e.exception) + + # TODO: we may want to conditionally mark the + # replica as failed here - i.e. where the actor itself + # can be healthy but the request failed. + self.mark_failed() + success = False + except Exception as e: + logger.debug( + "Got unexpected error on replica %d. Error:\n%s", self.idx, e + ) + self.mark_failed() + + # The exception was not from the actor - in this case + # we will signal back to the service (through set_exception) + # to retry on another healthy node. + request.future.set_exception(e) + success = False + + self.metrics.add_request_completion(start_time, success) + # Mark task as done + self.request_queue.task_done() + return success + + finally: + self.active_requests -= 1 + + async def run(self): + """Runs the main processing loop for the replica. + + Continuously processes requests from the queue while the replica is healthy. + Handles capacity management and graceful degradation on failures. 
+ """ + self._running = True + + try: + while self.state in (ReplicaState.HEALTHY, ReplicaState.RECOVERING): + try: + # Wait for a request with timeout to check health periodically + request = await asyncio.wait_for( + self.request_queue.get(), timeout=self._run_poll_rate_s + ) + + # Check if we have capacity - if we have too many ongoing, + # we will put the request back and wait. + if self.active_requests >= self.max_concurrent_requests: + await self.request_queue.put(request) + await asyncio.sleep(0.1) + continue + + # If we're recovering, reject the request + if self.state == ReplicaState.RECOVERING: + # This signals to the service to retry on another replica + request.future.set_exception( + RuntimeError(f"Replica {self.idx} is still recovering") + ) + self.request_queue.task_done() + continue + + # Process the request + asyncio.create_task(self._process_single_request(request)) + + except asyncio.TimeoutError: + # No requests, just continue checking for new ones + continue + + except Exception as e: + logger.error( + "Error in replica %d processing loop: %s", + self.idx, + e, + ) + self.state = ReplicaState.UNHEALTHY + break + + finally: + self._running = False + logger.debug("Replica %d stopped processing", self.idx) + + # Replica state management + + @property + def healthy(self) -> bool: + return self.state == ReplicaState.HEALTHY + + @property + def failed(self) -> bool: + """Check if the replica has failed and needs recovery.""" + return self.state in (ReplicaState.RECOVERING, ReplicaState.UNHEALTHY) + + def mark_failed(self): + """Mark the replica as failed, triggering recovery.""" + logger.debug("Marking replica %d as failed", self.idx) + self.state = ReplicaState.RECOVERING + + async def _ensure_healthy_proc_mesh(self): + """Ensure we have a healthy proc_mesh, recovering if necessary.""" + if self.failed: + await self._recover() + + async def _recover(self): + """ + Recover the replica by recreating the proc_mesh and respawning actors. + + This is the core recovery logic moved from RecoverableProcMesh. + """ + if self._recovery_task and not self._recovery_task.done(): + # Recovery already in progress, wait for it + await self._recovery_task + return + + logger.debug("Starting recovery for replica %d", self.idx) + self.state = ReplicaState.RECOVERING + + # Create the recovery task + self._recovery_task = asyncio.create_task(self._do_recovery()) + await self._recovery_task + + async def _do_recovery(self): + """Internal method that performs the actual recovery work.""" + old_proc_mesh = self.proc_mesh + self.proc_mesh = None + self.actor = None + + # Stop old proc_mesh if it exists + if old_proc_mesh is not None: + try: + await old_proc_mesh.stop() + logger.debug("Old proc_mesh stopped for replica %d", self.idx) + except Exception as e: + logger.warning( + "Error stopping old proc_mesh for replica %d: %s", self.idx, e + ) + + # Create new proc_mesh + try: + logger.debug("Creating new proc_mesh for replica %d", self.idx) + self.proc_mesh = await get_proc_mesh(process_config=self.proc_config) + self.state = ReplicaState.HEALTHY + logger.debug("Recovery completed successfully for replica %d", self.idx) + + except Exception as e: + logger.error("Recovery failed for replica %d: %s", self.idx, e) + self.state = ReplicaState.UNHEALTHY + raise + + async def stop(self): + """ + Stops the replica gracefully. + + Transitions to STOPPED state, stops the processing loop, and cleans up. + Fails any remaining requests in the queue. 
+ """ + logger.debug("Stopping replica %d", self.idx) + + # Transition to stopped state to signal the run loop to exit + self.state = ReplicaState.STOPPED + + # Wait for processor to finish if it's running + if self._running: + # Give it a moment to finish current request and exit gracefully + for _ in range(50): # Wait up to 5 seconds + if not self._running: + break + await asyncio.sleep(0.1) + + if self._running: + logger.warning("Replica %d processor didn't stop gracefully", self.idx) + + # Fail any remaining requests in the queue + failed_requests = [] + while not self.request_queue.empty(): + try: + request = self.request_queue.get_nowait() + failed_requests.append(request) + self.request_queue.task_done() + except asyncio.QueueEmpty: + break + + # Fail all the collected requests + for request in failed_requests: + if not request.future.done(): + request.future.set_exception( + RuntimeError(f"Replica {self.idx} is stopping") + ) + + logger.debug( + "Replica %d stopped, failed %d remaining requests", + self.idx, + len(failed_requests), + ) + + # Stop the proc_mesh + if self.proc_mesh: + try: + await self.proc_mesh.stop() + except Exception as e: + logger.warning( + "Error stopping proc_mesh for replica %d: %s", self.idx, e + ) + + # Metric-related getters + + @property + def load(self) -> int: + """Get current load (active requests + queue depth)""" + return self.active_requests + self.request_queue.qsize() + + @property + def capacity_utilization(self) -> float: + """Get current capacity utilization (0.0 to 1.0)""" + if self.max_concurrent_requests <= 0: + return 0.0 + return self.active_requests / self.max_concurrent_requests + + def can_accept_request(self) -> bool: + """Check if replica can accept a new request""" + return ( + self.state == ReplicaState.HEALTHY + and self.active_requests < self.max_concurrent_requests + ) + + def __repr__(self) -> str: + return ( + f"Replica(idx={self.idx}, state={self.state.value}, " + f"active={self.active_requests}/{self.max_concurrent_requests}, " + f"queue={self.request_queue.qsize()})" + ) diff --git a/src/forge/controller/service.py b/src/forge/controller/service.py index 13c58db36..f14c5b93c 100644 --- a/src/forge/controller/service.py +++ b/src/forge/controller/service.py @@ -32,95 +32,27 @@ ... result = await service.my_endpoint(arg1, arg2) """ - import asyncio import contextvars import logging import pprint -import time import uuid -from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, Dict, List, Optional +from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty -from monarch.actor import ActorError, ProcMesh -from forge.controller import RecoverableProcMesh +from forge.controller.replica import Replica, ReplicaMetrics, ServiceRequest from forge.types import ServiceConfig logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) - -# TODO - tie this into metric logger when it exists -@dataclass -class ReplicaMetrics: - """ - Metrics collection for a single replica instance. - - Tracks request counts, timing metrics, current state, and session assignments - for performance monitoring and autoscaling decisions. 
- - Attributes: - replica_idx: Unique identifier for this replica - total_requests: Total number of requests processed - successful_requests: Number of successfully completed requests - failed_requests: Number of failed requests - request_times: Sliding window of request start timestamps - request_latencies: Sliding window of request completion latencies - active_requests: Currently processing requests - queue_depth: Number of pending requests in queue - assigned_sessions: Number of sessions assigned to this replica - """ - - replica_idx: int - # Request metrics - total_requests: int = 0 - successful_requests: int = 0 - failed_requests: int = 0 - # Timing metrics (sliding window) - request_times: deque = field(default_factory=lambda: deque(maxlen=100)) - request_latencies: deque = field(default_factory=lambda: deque(maxlen=100)) - # Current state - active_requests: int = 0 - queue_depth: int = 0 - # Session metrics - assigned_sessions: int = 0 - - def add_request_start(self, timestamp: float): - """Record when a request starts processing.""" - self.request_times.append(timestamp) - self.total_requests += 1 - - def add_request_completion(self, start_time: float, success: bool): - """Record when a request completes.""" - latency = time.time() - start_time - self.request_latencies.append(latency) - if success: - self.successful_requests += 1 - else: - self.failed_requests += 1 - - def get_request_rate(self, window_seconds: float = 60.0) -> float: - """Get requests per second over the last window_seconds.""" - now = time.time() - cutoff = now - window_seconds - recent_requests = [t for t in self.request_times if t >= cutoff] - return len(recent_requests) / window_seconds if window_seconds > 0 else 0.0 - - def get_avg_latency(self, window_requests: int = 50) -> float: - """Get average latency over the last N requests.""" - if not self.request_latencies: - return 0.0 - recent_latencies = list(self.request_latencies)[-window_requests:] - return sum(recent_latencies) / len(recent_latencies) - - def get_capacity_utilization(self, max_concurrent: int) -> float: - """Get current capacity utilization (0.0 to 1.0).""" - return self.active_requests / max_concurrent if max_concurrent > 0 else 0.0 +P = ParamSpec("P") +R = TypeVar("R") +# TODO - tie this into metrics logger when it exists. 
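+# A possible shape for that integration, purely illustrative (metrics_logger
+# is hypothetical and does not exist yet):
+#
+#     metrics = service.get_metrics()
+#     metrics_logger.log_gauge("service/healthy_replicas", metrics.healthy_replicas)
+#     metrics_logger.log_gauge("service/request_rate", metrics.get_total_request_rate())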
 @dataclass
 class ServiceMetrics:
     """
@@ -153,79 +85,67 @@ def get_total_request_rate(self, window_seconds: float = 60.0) -> float:
             for metrics in self.replica_metrics.values()
         )
 
-    def get_avg_queue_depth(self) -> float:
+    def get_avg_queue_depth(self, replicas: List) -> float:
         """Get average queue depth across all healthy replicas."""
-        healthy_metrics = [
-            m
-            for m in self.replica_metrics.values()
-            if m.replica_idx < self.healthy_replicas
-        ]
-        if not healthy_metrics:
+        healthy_replicas = [r for r in replicas if r.healthy]
+        if not healthy_replicas:
             return 0.0
-        return sum(m.queue_depth for m in healthy_metrics) / len(healthy_metrics)
+        total_queue_depth = sum(r.request_queue.qsize() for r in healthy_replicas)
+        return total_queue_depth / len(healthy_replicas)
 
     def get_avg_capacity_utilization(self, replicas: List) -> float:
         """Get average capacity utilization across all healthy replicas."""
-        healthy_replicas = [r for r in replicas if r.proc_mesh.healthy]
+        healthy_replicas = [r for r in replicas if r.healthy]
         if not healthy_replicas:
             return 0.0
-
-        utilizations = []
-        for replica in healthy_replicas:
-            if replica.idx in self.replica_metrics:
-                metrics = self.replica_metrics[replica.idx]
-                utilization = metrics.get_capacity_utilization(
-                    replica.max_concurrent_requests
-                )
-                utilizations.append(utilization)
-
-        return sum(utilizations) / len(utilizations) if utilizations else 0.0
+        total_utilization = sum(r.capacity_utilization for r in healthy_replicas)
+        return total_utilization / len(healthy_replicas)
 
     def get_sessions_per_replica(self) -> float:
-        """Get average sessions per healthy replica."""
-        if self.healthy_replicas == 0:
+        """Get average sessions per replica."""
+        if self.total_replicas == 0:
             return 0.0
-        return self.total_sessions / self.healthy_replicas
+        return self.total_sessions / self.total_replicas
 
 
-@dataclass
-class Replica:
-    proc_mesh: RecoverableProcMesh
-    actor: Any
-    idx: int
-    request_queue: asyncio.Queue[dict] = field(default_factory=asyncio.Queue)
-    active_requests: int = 0
-    max_concurrent_requests: int = 10
-    _processor_running: bool = False
-    metadata: dict = field(default_factory=dict)
+# Context variable for session state
+_session_context = contextvars.ContextVar("session_context")
 
 
 @dataclass
 class Session:
+    """Simple session data holder."""
+
     session_id: str
 
 
-# Global context variable for session state
-# This is used to propagate session state across async tasks
-_session_context: contextvars.ContextVar[dict | None] = contextvars.ContextVar(
-    "session_context", default=None
-)
+class SessionContext:
+    """
+    Async context manager for stateful service sessions with automatic lifecycle management.
+
+    Provides a convenient way to maintain stateful connections to replicas across multiple
+    requests. Sessions ensure that all requests within the context are routed to the same
+    replica, enabling stateful interactions while handling the session lifecycle automatically.
+
+    Example:
+        >>> async with service.session():
+        ...     # All calls within this block use the same replica
+        ...     result1 = await service.my_endpoint.choose(arg1)
+        ...     result2 = await service.another_endpoint.choose(result1)
 
-class SessionContext:
-    """Context manager for service sessions using context variables."""
+    """
 
-    def __init__(self, service: "Service", **session_kwargs):
+    def __init__(self, service: "Service"):
         self.service = service
         self.session_id: str | None = None
-        self.session_kwargs = session_kwargs
         self._token = None
 
     async def __aenter__(self):
         """Start a session and set context variables."""
         self.session_id = await self.service.start_session()
         # Set context for this async task
-        context_value = {"session_id": self.session_id, "kwargs": self.session_kwargs}
+        context_value = {"session_id": self.session_id}
         self._token = _session_context.set(context_value)
         return self
 
@@ -238,6 +158,34 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
         self.session_id = None
 
 
+class ServiceEndpoint(Generic[P, R]):
+    """An endpoint object specific to services.
+
+    This loosely mimics the Endpoint APIs exposed in Monarch, with
+    a few key differences:
+    - Only choose and call are retained (dropping stream and call_one).
+    - Call returns a list directly rather than a ValueMesh.
+
+    These changes are made with Forge use cases in mind, but can
+    certainly be expanded/adapted in the future.
+
+    """
+
+    def __init__(self, service: "Service", endpoint_name: str):
+        self.service = service
+        self.endpoint_name = endpoint_name
+
+    async def choose(
+        self, sess_id: str | None = None, *args: P.args, **kwargs: P.kwargs
+    ) -> R:
+        """Chooses a replica to call based on context and load balancing strategy."""
+        return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs)
+
+    async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+        """Broadcasts a request to all healthy replicas and returns the results as a list."""
+        return await self.service._call_all(self.endpoint_name, *args, **kwargs)
+
+
 class Service:
     """
     Distributed Actor Service Controller
@@ -299,11 +247,6 @@ def __init__(self, cfg: ServiceConfig, actor_def, *actor_args, **actor_kwargs):
         # Initialize metrics collection
         self._metrics = ServiceMetrics()
-
-        # Autoscaling state
-        self._last_scale_up_time = 0.0
-        self._last_scale_down_time = 0.0
-        self._low_utilization_start_time = None
         self._health_task = None
         self._shutdown_requested = False
 
@@ -325,12 +268,11 @@ async def __initialize__(self):
         replicas = []
         num_replicas = self._cfg.num_replicas
         for i in range(num_replicas):
-            mesh = RecoverableProcMesh(proc_config=self._cfg.to_process_config())
             replica = Replica(
-                proc_mesh=mesh,
-                actor=None,
                 idx=len(self._replicas) + i,
+                proc_config=self._cfg.to_process_config(),
                 max_concurrent_requests=self._cfg.replica_max_concurrent_requests,
+                return_first_rank_result=self._cfg.return_first_rank_result,
             )
             replicas.append(replica)
 
@@ -356,13 +298,9 @@
         )
 
     def _add_endpoint_method(self, endpoint_name: str):
-        """Dynamically adds an endpoint method to this Service instance."""
-
-        async def endpoint_method(sess_id: str | None = None, *args, **kwargs):
-            return await self._call(sess_id, endpoint_name, *args, **kwargs)
-
-        # Set the method on this instance
-        setattr(self, endpoint_name, endpoint_method)
+        """Dynamically adds a ServiceEndpoint instance to this Service instance."""
+        endpoint = ServiceEndpoint(self, endpoint_name)
+        setattr(self, endpoint_name, endpoint)
 
     async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
         """
@@ -390,37 +328,30 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs):
         """
        # Check context 
variables for session state if no explicit sess_id if sess_id is None: - ctx = _session_context.get() + ctx = _session_context.get(None) if ctx: sess_id = ctx["session_id"] - routing_hints = ctx["kwargs"] - else: - routing_hints = {} - else: - routing_hints = {} - - replica = await self._get_replica(sess_id, **routing_hints) - - # Create a request object to queue - request = { - "sess_id": sess_id, - "function": function, - "args": args, - "kwargs": kwargs, - "future": asyncio.Future(), - } - # Queue the request - await replica.request_queue.put(request) - # Ensure the replica has a processor running - self._ensure_processor_running(replica) + replica = await self._get_replica(sess_id) + + # Create a ServiceRequest object to queue + request = ServiceRequest( + session_id=sess_id, + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: - return await request["future"] + return await request.future except Exception as e: # If the replica failed, try to retry once - if not replica.proc_mesh.healthy: + if not replica.healthy: logger.debug( "Replica %d failed during request, retrying on healthy replica", replica.idx, @@ -430,99 +361,56 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - def _ensure_processor_running(self, replica: Replica): - """Ensures a persistent processor is running for this replica.""" - if not replica._processor_running: - replica._processor_running = True - asyncio.create_task(self._persistent_processor(replica)) + async def _call_all(self, function: str, *args, **kwargs) -> List: + """ + Broadcasts a function call to all healthy replicas and returns results as a list. 
- async def _persistent_processor(self, replica: Replica): - """Persistent processor that continuously handles requests for a replica.""" - try: - while replica.proc_mesh.healthy: - try: - # Wait for a request with timeout to check health periodically - request = await asyncio.wait_for( - replica.request_queue.get(), timeout=1.0 - ) - - # Check if we have capacity - if replica.active_requests >= replica.max_concurrent_requests: - # Put the request back and wait - await replica.request_queue.put(request) - await asyncio.sleep(0.1) - continue - - # Process the request - asyncio.create_task(self._process_single_request(replica, request)) - - except asyncio.TimeoutError: - # No requests, continue to check health - continue - except Exception as e: - logger.error( - "Error in persistent processor for replica %d: %s", - replica.idx, - e, - ) - break - finally: - replica._processor_running = False - # Migrate any remaining requests to healthy replicas - await self._migrate_remaining_requests(replica) - - async def _process_single_request(self, replica: Replica, request: dict): - """Processes a single request.""" - start_time = time.time() - replica.active_requests += 1 - - # Get or create metrics for this replica - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.add_request_start(start_time) - replica_metrics.active_requests = replica.active_requests + Args: + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint - try: - # Get the actor and endpoint - actor = replica.actor - endpoint_func = getattr(actor, request["function"]) + Returns: + List of results from all healthy replicas + + Raises: + RuntimeError: If no healthy replicas are available + """ + healthy_replicas = [r for r in self._replicas if r.healthy] - # Execute the request - success = True + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for broadcast call") + + # Create requests for all healthy replicas + requests = [] + for replica in healthy_replicas: + request = ServiceRequest( + session_id=None, # Broadcast calls don't use sessions + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + requests.append((replica, request)) + + # Enqueue all requests + for replica, request in requests: + await replica.enqueue_request(request) + + # Wait for all results + results = [] + for replica, request in requests: try: - result = await endpoint_func.call(*request["args"], **request["kwargs"]) - if ( - self._cfg.return_first_rank_result - and hasattr(result, "_values") - and result._values - ): - result = result._values[0] - request["future"].set_result(result) - except ActorError as e: - logger.debug("Got failure on replica %d. Error:\n%s", replica.idx, e) - replica.proc_mesh.mark_failed() - # Unwrap the ActorError into its raw exception. - request["future"].set_result(e.exception) - success = False + result = await request.future + results.append(result) except Exception as e: - logger.debug( - "Got unexpected error on replica %d. 
Error:\n%s", replica.idx, e + logger.warning( + "Request to replica %d failed during broadcast: %s", replica.idx, e ) - replica.proc_mesh.mark_failed() - request["future"].set_result(e) - success = False + # Add None for failed replicas to maintain indexing + results.append(None) - # Record completion metrics - replica_metrics.add_request_completion(start_time, success) - - # Mark task as done - replica.request_queue.task_done() - - finally: - replica.active_requests -= 1 - replica_metrics.active_requests = replica.active_requests + return results async def _retry_request_on_healthy_replica( self, sess_id: str | None, function: str, *args, **kwargs @@ -558,13 +446,13 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): # Find healthy replicas healthy_replicas = [ - r for r in self._replicas if r.proc_mesh.healthy and r != failed_replica + r for r in self._replicas if r.healthy and r != failed_replica ] if not healthy_replicas: # No healthy replicas, fail all requests for request in migrated_requests: - request["future"].set_exception( + request.future.set_exception( RuntimeError("No healthy replicas available") ) return @@ -572,11 +460,10 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): # Distribute requests among healthy replicas for i, request in enumerate(migrated_requests): target_replica = healthy_replicas[i % len(healthy_replicas)] - await target_replica.request_queue.put(request) - self._ensure_processor_running(target_replica) + await target_replica.enqueue_request(request) # Update session mapping if needed - sess_id = request["sess_id"] + sess_id = request.session_id if ( sess_id in self._session_replica_map and self._session_replica_map[sess_id] == failed_replica.idx @@ -608,35 +495,20 @@ async def start_session(self) -> str: return sess_id - def session(self, **kwargs) -> SessionContext: + def session(self) -> SessionContext: """Returns a context manager for session-based calls.""" - return SessionContext(self, **kwargs) + return SessionContext(self) def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum( - 1 for r in self._replicas if r.proc_mesh.healthy - ) - - # Update queue depths for all replicas + self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + # Store direct references to replica metrics for aggregation + self._metrics.replica_metrics = {} for replica in self._replicas: - if replica.idx not in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica.idx] = ReplicaMetrics(replica.idx) - - replica_metrics = self._metrics.replica_metrics[replica.idx] - replica_metrics.queue_depth = replica.request_queue.qsize() - replica_metrics.active_requests = replica.active_requests - - # Update session assignments per replica - session_counts = defaultdict(int) - for sess_id, replica_idx in self._session_replica_map.items(): - session_counts[replica_idx] += 1 - - for replica_idx, count in session_counts.items(): - if replica_idx in self._metrics.replica_metrics: - self._metrics.replica_metrics[replica_idx].assigned_sessions = count + # Use the replica's own metrics directly + self._metrics.replica_metrics[replica.idx] = replica.metrics def get_metrics(self) -> ServiceMetrics: """ @@ -680,7 +552,7 @@ def get_metrics_summary(self) -> dict: "healthy_replicas": self._metrics.healthy_replicas, "total_replicas": self._metrics.total_replicas, 
"total_request_rate": self._metrics.get_total_request_rate(), - "avg_queue_depth": self._metrics.get_avg_queue_depth(), + "avg_queue_depth": self._metrics.get_avg_queue_depth(self._replicas), "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( self._replicas ), @@ -689,17 +561,26 @@ def get_metrics_summary(self) -> dict: "replicas": {}, } - for replica_idx, metrics in self._metrics.replica_metrics.items(): - summary["replicas"][replica_idx] = { + for replica in self._replicas: + metrics = replica.metrics + + # Count sessions assigned to this replica + assigned_sessions = sum( + 1 + for replica_idx in self._session_replica_map.values() + if replica_idx == replica.idx + ) + + summary["replicas"][replica.idx] = { "total_requests": metrics.total_requests, "successful_requests": metrics.successful_requests, "failed_requests": metrics.failed_requests, "request_rate": metrics.get_request_rate(), "avg_latency": metrics.get_avg_latency(), - "active_requests": metrics.active_requests, - "queue_depth": metrics.queue_depth, - "assigned_sessions": metrics.assigned_sessions, - "capacity_utilization": metrics.get_capacity_utilization(10), + "active_requests": replica.active_requests, # Get from replica + "queue_depth": replica.request_queue.qsize(), # Get from replica + "assigned_sessions": assigned_sessions, # Calculate from session map + "capacity_utilization": replica.capacity_utilization, # Get from replica } return summary @@ -749,7 +630,7 @@ async def _health_loop(self, poll_rate_s: float): # Check for failed replicas and recover them failed_replicas = [] for replica in self._replicas: - if replica.proc_mesh.failed: + if replica.failed: failed_replicas.append(replica) if any(failed_replicas): @@ -762,15 +643,9 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _custom_replica_routing( - self, sess_id: str | None, **kwargs - ) -> Optional[Replica]: - """Hook for custom routing logic. 
Override in subclasses to implement custom routing.""" - return None - def _get_next_replica(self) -> "Replica": """Get the next replica using round-robin selection.""" - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in self._replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for load balancing") @@ -780,7 +655,7 @@ def _get_next_replica(self) -> "Replica": def _get_least_loaded_replica(self) -> "Replica": """Get the replica with the lowest load.""" - healthy_replicas = [r for r in self._replicas if r.proc_mesh.healthy] + healthy_replicas = [r for r in self._replicas if r.healthy] if not healthy_replicas: raise RuntimeError("No healthy replicas available for session assignment") @@ -790,15 +665,8 @@ def get_load(replica: "Replica") -> int: return min(healthy_replicas, key=get_load) - async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": - """Get a replica for the given session ID, with optional custom routing hints.""" - # Try custom routing first if hints are provided - if kwargs: - custom_result = await self._custom_replica_routing(sess_id, **kwargs) - if custom_result is not None: - return custom_result - - # Default routing logic + async def _get_replica(self, sess_id: str | None) -> "Replica": + """Get a replica for the given session ID.""" if sess_id is None: # No session, use round-robin load balancing replica = self._get_next_replica() @@ -809,7 +677,7 @@ async def _get_replica(self, sess_id: str | None, **kwargs) -> "Replica": replica_idx = self._session_replica_map[sess_id] # Find the replica with this index for replica in self._replicas: - if replica.idx == replica_idx and replica.proc_mesh.healthy: + if replica.idx == replica_idx and replica.healthy: return replica # If the replica is no longer healthy, remove from session map and reassign del self._session_replica_map[sess_id] @@ -838,8 +706,9 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all replicas using their stop method await asyncio.gather( - *[replica.proc_mesh.stop() for replica in self._replicas], + *[replica.stop() for replica in self._replicas], return_exceptions=True, ) @@ -850,139 +719,32 @@ async def _maybe_init_replicas(self): logger.debug("Init replicas: %s", pprint.pformat(self._replicas_to_init)) - def _recover_hook( - replica: Replica, - ) -> Callable[[ProcMesh], Coroutine[Any, Any, None]]: - async def inner_hook(proc_mesh: ProcMesh) -> None: - if "name" in self._actor_kwargs: - actor_name = self._actor_kwargs.pop("name") - else: - actor_name = self._actor_def.__name__ - # TODO - expand support so name can stick within kwargs - actor = await proc_mesh.spawn( - actor_name, - self._actor_def, - *self._actor_args, - **self._actor_kwargs, - ) - replica.actor = actor - if hasattr(actor, "setup"): - await actor.setup.call() - - return inner_hook + # Initialize each replica (proc_mesh and actor spawning) + initialization_tasks = [] + for replica in self._replicas_to_init: + task = asyncio.create_task(self._init_single_replica(replica)) + initialization_tasks.append(task) - await asyncio.gather( - *[ - replica.proc_mesh.spawn(_recover_hook(replica)) - for replica in self._replicas_to_init - ] - ) + await asyncio.gather(*initialization_tasks, return_exceptions=True) self._replicas_to_init.clear() - async def _scale_up(self, num_replicas: int = 1): - """ - Scales up the service by adding new replicas. 
- - Creates new replica instances with their own process meshes and queues them - for initialization. The replicas will be initialized asynchronously by the - health loop to avoid blocking the scaling operation. - - Args: - num_replicas: Number of replicas to add (default: 1) - - Note: - Replicas are queued for initialization rather than initialized immediately - to prevent blocking during scaling operations. - """ - logger.debug("Scaling up with %d replicas.", num_replicas) - new_replicas = [] - for i in range(num_replicas): - mesh = RecoverableProcMesh( - self._cfg.procs_per_replica, - ) - replica = Replica( - proc_mesh=mesh, - actor=None, - idx=len(self._replicas) + i, - max_concurrent_requests=self._cfg.replica_max_concurrent_requests, - ) - new_replicas.append(replica) - - # Add to the initialization queue instead of initializing immediately - self._replicas_to_init.extend(new_replicas) - self._replicas.extend(new_replicas) - logger.debug( - "Queued %d replicas for initialization. Total replicas: %d", - num_replicas, - len(self._replicas), - ) - - async def _scale_down_replicas(self, num_replicas: int = 1): - """ - Scales down the service by intelligently removing replicas. - - Prioritizes removal of unhealthy replicas first, then selects healthy replicas - with the lowest load. Migrates all workload (sessions and queued requests) - from removed replicas to remaining healthy replicas. + async def _init_single_replica(self, replica: Replica): + """Initialize a single replica with proc_mesh and actor.""" + try: + # Initialize the proc_mesh + await replica.init_proc_mesh() - Args: - num_replicas: Number of replicas to remove (default: 1) - - Note: - # Test context manager usage - async with service.session(): - await service.incr() - await service.incr() - result = await service.value() - assert result == 2 - - Sessions are reassigned on their next request rather than immediately - to avoid disrupting active workloads. - """ - logger.debug("Scaling down by %d replicas.", num_replicas) - - # Find replicas to remove (prefer unhealthy ones first, then least loaded) - replicas_to_remove = [] - - # First, try to remove unhealthy replicas - unhealthy_replicas = [r for r in self._replicas if not r.proc_mesh.healthy] - for replica in unhealthy_replicas[:num_replicas]: - replicas_to_remove.append(replica) - - # If we need more, remove healthy replicas with least load - remaining_to_remove = num_replicas - len(replicas_to_remove) - if remaining_to_remove > 0: - healthy_replicas = [ - r - for r in self._replicas - if r.proc_mesh.healthy and r not in replicas_to_remove - ] - # Sort by load (queue depth + active requests) - healthy_replicas.sort( - key=lambda r: r.request_queue.qsize() + r.active_requests + # Spawn the actor using replica's method + await replica.spawn_actor( + self._actor_def, *self._actor_args, **self._actor_kwargs ) - for replica in healthy_replicas[:remaining_to_remove]: - replicas_to_remove.append(replica) - - # Migrate sessions and requests from replicas being removed - for replica in replicas_to_remove: - await self._migrate_replica_workload(replica) + logger.debug("Successfully initialized replica %d", replica.idx) - # Stop the replica - try: - await replica.proc_mesh.stop() - except Exception as e: - logger.warning("Error stopping replica %d: %s", replica.idx, e) - - # Remove from replicas list - self._replicas.remove(replica) - - # Update replica indices - for i, replica in enumerate(self._replicas): - replica.idx = i - - logger.debug("Scale down complete. 
Remaining replicas: %d", len(self._replicas)) + except Exception as e: + logger.error("Failed to initialize replica %d: %s", replica.idx, e) + # Mark as failed so it can be retried later + replica.mark_failed() async def _migrate_replica_workload(self, replica_to_remove: Replica): """Migrates all workload from a replica that's being removed.""" diff --git a/tests/test_service.py b/tests/test_service.py index 7283aeeee..791b6da23 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -66,8 +66,8 @@ async def test_basic_service_operations(): assert isinstance(session1, str) # Test endpoint calls - await service.incr(session1) - result = await service.value(session1) + await service.incr.choose(sess_id=session1) + result = await service.value.choose(sess_id=session1) assert result == 1 # Test session mapping @@ -90,9 +90,9 @@ async def test_sessionless_calls(): try: # Test sessionless calls - await service.incr() - await service.incr() - result = await service.value() + await service.incr.choose() + await service.incr.choose() + result = await service.value.choose() assert result is not None # No sessions should be created @@ -121,18 +121,18 @@ async def test_session_context_manager(): try: # Test context manager usage async with service.session(): - await service.incr() - await service.incr() - result = await service.value() + await service.incr.choose() + await service.incr.choose() + result = await service.value.choose() assert result == 2 # Test sequential context managers to avoid interference async def worker(increments: int): async with service.session(): - initial = await service.value() + initial = await service.value.choose() for _ in range(increments): - await service.incr() - final = await service.value() + await service.incr.choose() + final = await service.value.choose() return final - initial # Run sessions sequentially to avoid concurrent modification @@ -162,28 +162,28 @@ async def test_replica_failure_and_recovery(): try: # Create session and cause failure session = await service.start_session() - await service.incr(session) + await service.incr.choose(session) original_replica_idx = service._session_replica_map[session] # Cause failure - error_result = await service.fail_me(session) + error_result = await service.fail_me.choose(session) assert isinstance(error_result, RuntimeError) # Replica should be marked as failed failed_replica = service._replicas[original_replica_idx] - assert not failed_replica.proc_mesh.healthy + assert not failed_replica.healthy # Session should be reassigned on next call - await service.incr(session) + await service.incr.choose(session) new_replica_idx = service._session_replica_map[session] assert new_replica_idx != original_replica_idx # New sessions should avoid failed replica new_session = await service.start_session() - await service.incr(new_session) + await service.incr.choose(new_session) assigned_replica = service._replicas[service._session_replica_map[new_session]] - assert assigned_replica.proc_mesh.healthy + assert assigned_replica.healthy finally: await service.stop() @@ -204,12 +204,12 @@ async def test_metrics_collection(): session1 = await service.start_session() session2 = await service.start_session() - await service.incr(session1) - await service.incr(session1) - await service.incr(session2) + await service.incr.choose(session1) + await service.incr.choose(session1) + await service.incr.choose(session2) # Test failure metrics - error_result = await service.fail_me(session1) + error_result = await 
service.fail_me.choose(session1) assert isinstance(error_result, RuntimeError) # Get metrics @@ -256,18 +256,18 @@ async def test_session_stickiness(): session = await service.start_session() # Make multiple calls - await service.incr(session) - await service.incr(session) - await service.incr(session) + await service.incr.choose(session) + await service.incr.choose(session) + await service.incr.choose(session) # Should always route to same replica replica_idx = service._session_replica_map[session] - await service.incr(session) + await service.incr.choose(session) assert service._session_replica_map[session] == replica_idx # Verify counter was incremented correctly - result = await service.value(session) + result = await service.value.choose(session) assert result == 4 finally: @@ -284,16 +284,16 @@ async def test_load_balancing_multiple_sessions(): try: # Create sessions with some load to trigger distribution session1 = await service.start_session() - await service.incr(session1) # Load replica 0 + await service.incr.choose(session1) # Load replica 0 session2 = await service.start_session() - await service.incr(session2) # Should go to replica 1 (least loaded) + await service.incr.choose(session2) # Should go to replica 1 (least loaded) session3 = await service.start_session() - await service.incr(session3) # Should go to replica 0 or 1 based on load + await service.incr.choose(session3) # Should go to replica 0 or 1 based on load session4 = await service.start_session() - await service.incr(session4) # Should balance the load + await service.incr.choose(session4) # Should balance the load # Check that sessions are distributed (may not be perfectly even due to least-loaded logic) replica_assignments = [ @@ -333,10 +333,10 @@ async def test_concurrent_operations(): # Concurrent operations tasks = [ - service.incr(session), # Session call - service.incr(session), # Session call - service.incr(), # Sessionless call - service.incr(), # Sessionless call + service.incr.choose(session), # Session call + service.incr.choose(session), # Session call + service.incr.choose(), # Sessionless call + service.incr.choose(), # Sessionless call ] await asyncio.gather(*tasks) @@ -355,3 +355,111 @@ async def test_concurrent_operations(): finally: await service.stop() + + +# `call` endpoint tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_basic(): + """Test basic broadcast call functionality.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=10) + + try: + # Test broadcast call to all replicas + results = await service.incr.call() + + # Should get results from all healthy replicas + assert isinstance(results, list) + assert len(results) == 3 # All 3 replicas should respond + + # All results should be None (incr doesn't return anything) + assert all(result is None for result in results) + + # Test getting values from all replicas + values = await service.value.call() + assert isinstance(values, list) + assert len(values) == 3 + + # All replicas should have incremented from 10 to 11 + assert all(value == 11 for value in values) + + finally: + await service.stop() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_with_failed_replica(): + """Test broadcast call behavior when some replicas fail.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # First, cause one 
replica to fail by calling fail_me on a specific session + session = await service.start_session() + try: + await service.fail_me.choose(session) + except RuntimeError: + pass # Expected failure + + # Wait briefly for replica to be marked as failed + await asyncio.sleep(0.1) + + # Now test broadcast call - should only hit healthy replicas + results = await service.incr.call() + + # Should get results from healthy replicas only + assert isinstance(results, list) + # Results length should match number of healthy replicas (2 out of 3) + healthy_count = len([r for r in service._replicas if r.healthy]) + assert len(results) == healthy_count + + # Get values from all healthy replicas + values = await service.value.call() + assert len(values) == healthy_count + + # All healthy replicas should have incremented to 1 + assert all(value == 1 for value in values) + + finally: + await service.stop() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_broadcast_call_vs_choose(): + """Test that broadcast call hits all replicas while choose hits only one.""" + cfg = ServiceConfig(procs_per_replica=1, num_replicas=3) + service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0) + + try: + # Use broadcast call to increment all replicas + await service.incr.call() + + # Get values from all replicas + values_after_broadcast = await service.value.call() + assert len(values_after_broadcast) == 3 + assert all(value == 1 for value in values_after_broadcast) + + # Use choose to increment only one replica + await service.incr.choose() + + # Get values again - one replica should be at 2, others at 1 + values_after_choose = await service.value.call() + assert len(values_after_choose) == 3 + assert sorted(values_after_choose) == [1, 1, 2] # One replica incremented twice + + # Verify metrics show the correct number of requests + metrics = service.get_metrics_summary() + total_requests = sum( + replica_metrics["total_requests"] + for replica_metrics in metrics["replicas"].values() + ) + # incr.call() (3 requests) + value.call() (3 requests) + incr.choose() (1 request) + value.call() (3 requests) = 10 total + assert total_requests == 10 + + finally: + await service.stop()
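
Taken together, the endpoint objects introduced above replace direct endpoint
calls on the service. A condensed usage sketch (illustrative only; Counter is
the test actor used in tests/test_service.py, and all awaits assume a running
event loop):

    cfg = ServiceConfig(procs_per_replica=1, num_replicas=2)
    service = await spawn_service(service_cfg=cfg, actor_def=Counter, v=0)

    await service.incr.call()            # broadcast: every healthy replica
    one = await service.value.choose()   # load-balanced: exactly one replica

    async with service.session():        # sticky: choose() calls share a replica
        await service.incr.choose()
        two = await service.value.choose()

    await service.stop()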