diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index 0fbce6a22..05bf1a318 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -7,7 +7,7 @@ from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics -from .service import Service, ServiceConfig +from .service import Service, ServiceActor, ServiceConfig from .spawn import shutdown_service, spawn_service __all__ = [ @@ -19,6 +19,7 @@ "ServiceMetrics", "Session", "SessionContext", + "ServiceActor", "spawn_service", "shutdown_service", ] diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index da614a16a..a8d844cb9 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -87,6 +87,35 @@ class ServiceEndpoint(Generic[P, R]): """ + def __init__(self, service, endpoint_name: str): + self.service = service + self.endpoint_name = endpoint_name + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.service.call_all(self.endpoint_name, *args, **kwargs) + return result + + +class ServiceEndpointV2(Generic[P, R]): + """An endpoint object specific to services. + + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only choose and call are retained (dropping stream and call_one) + - Call returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + def __init__(self, actor_mesh, endpoint_name: str): self.actor_mesh = actor_mesh self.endpoint_name = endpoint_name @@ -108,6 +137,70 @@ async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: class ServiceInterface: + """ + A lightweight interface to the base Service class. + + This is a temporary workaround until Monarch supports nested + actors. + + """ + + def __init__(self, _service, actor_def): + self._service = _service + self.actor_def = actor_def + + # Dynamically create ServiceEndpoint objects for user's actor endpoints + # Inspect the actor_def directly to find endpoints + for attr_name in dir(actor_def): + attr_value = getattr(actor_def, attr_name) + if isinstance(attr_value, EndpointProperty): + # Create a ServiceEndpoint that will route through the Service Actor + endpoint = ServiceEndpoint(self._service, attr_name) + setattr(self, attr_name, endpoint) + + # Session management methods - handled by ServiceInterface + async def start_session(self) -> str: + """Starts a new session for stateful request handling.""" + return await self._service.start_session() + + async def terminate_session(self, sess_id: str): + """Terminates an active session and cleans up associated resources.""" + return await self._service.terminate_session(sess_id) + + def session(self) -> "SessionContext": + """Returns a context manager for session-based calls.""" + return SessionContext(self) + + async def get_metrics(self): + """Get comprehensive service metrics for monitoring and analysis.""" + return self._service.get_metrics() + + async def get_metrics_summary(self): + """Get a summary of key metrics for monitoring and debugging.""" + return self._service.get_metrics_summary() + + # Testing method - forwarded to Service Actor + async def _get_internal_state(self): + """ + Get comprehensive internal state for testing purposes. + + Returns: + dict: Complete internal state including sessions, replicas, and metrics + """ + return await self._service._get_internal_state() + + def __getattr__(self, name: str): + """Forward all other attribute access to the underlying Service Actor.""" + _service = object.__getattribute__(self, "_service") + # Forward everything else to the _service + if hasattr(_service, name): + return getattr(_service, name) + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + +class ServiceInterfaceV2: """ A lightweight interface to a Service Actor running on a single-node mesh. @@ -134,7 +227,7 @@ def __init__(self, _proc_mesh, _service, actor_def): attr_value = getattr(actor_def, attr_name) if isinstance(attr_value, EndpointProperty): # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpoint(self._service, attr_name) + endpoint = ServiceEndpointV2(self._service, attr_name) setattr(self, attr_name, endpoint) # Session management methods - handled by ServiceInterface @@ -160,23 +253,18 @@ async def get_metrics_summary(self): return await self._service.get_metrics_summary.call_one() # Testing method - forwarded to Service Actor - def _get_internal_state(self): + async def _get_internal_state(self): """ Get comprehensive internal state for testing purposes. Returns: dict: Complete internal state including sessions, replicas, and metrics """ - return self._service._get_internal_state.call_one() + return await self._service._get_internal_state.call_one() def __getattr__(self, name: str): """Forward all other attribute access to the underlying Service Actor.""" - try: - _service = object.__getattribute__(self, "_service") - except AttributeError: - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) + _service = object.__getattribute__(self, "_service") # Forward everything else to the _service if hasattr(_service, name): return getattr(_service, name) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 439d357dd..ede58c821 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -50,7 +50,563 @@ logger.setLevel(logging.DEBUG) -class Service(Actor): +class Service: + """ + Actor-less implementation of Service. + + This is a temporary solution to disable nested actors, which is proving + problematic temporarily. + + Args: + cfg: Service configuration including number of replicas, GPUs per replica, and health polling rate + actor_def: Actor class definition to instantiate on each replica + *actor_args: Positional arguments passed to actor constructor + **actor_kwargs: Keyword arguments passed to actor constructor + + Attributes: + _cfg: Service configuration + _replicas: List of managed replica instances + _active_sessions: Currently active sessions + _metrics: Aggregated service and replica metrics + _endpoints: Dynamically registered actor endpoints + """ + + def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): + self._cfg = cfg + self._replicas = [] + self._actor_def = actor_def + self._actor_kwargs = actor_kwargs + + self._active_sessions = [] + self._id_session_map = {} + self._session_replica_map: Dict[str, int] = {} + self._next_replica_idx = 0 # For round-robin load balancing + + # Initialize metrics collection + self._metrics = ServiceMetrics() + self._health_task = None + self._shutdown_requested = False + + # Replica initialization queue + self._replicas_to_recover = [] + + async def __initialize__(self): + """Initializes the service and starts the health loop.""" + logger.debug(f"Starting service up with {self._cfg.num_replicas} replicas.") + replicas = [] + num_replicas = self._cfg.num_replicas + for i in range(num_replicas): + replica = Replica( + idx=len(self._replicas) + i, + proc_config=self._cfg.to_process_config(), + max_concurrent_requests=self._cfg.replica_max_concurrent_requests, + return_first_rank_result=self._cfg.return_first_rank_result, + actor_def=self._actor_def, + actor_kwargs=self._actor_kwargs, + ) + replicas.append(replica) + + logger.debug( + f"Queued {num_replicas} replicas for initialization. Total replicas: {len(self._replicas)}" + ) + + # Initialize all replicas in parallel + await asyncio.gather(*[r.initialize() for r in replicas]) + self._replicas = replicas + + # Start the health loop in the background + self._health_task = asyncio.create_task( + self._health_loop(poll_rate_s=self._cfg.health_poll_rate) + ) + + async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + """ + Routes a function call to the appropriate replica with load balancing and fault tolerance. + + This is the core routing method that handles: + - Session-based routing for stateful calls + - Round-robin load balancing for stateless calls + - Custom routing based on context hints + - Automatic retry on replica failures + - Request queuing and processing + + Args: + sess_id: Optional session ID for stateful routing + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint + + Returns: + The result from the actor endpoint execution + + Raises: + RuntimeError: If no healthy replicas are available + Exception: Any exception raised by the actor endpoint + """ + # Check context variables for session state if no explicit sess_id + if sess_id is None: + ctx = _session_context.get(None) + if ctx: + sess_id = ctx["session_id"] + + replica = await self._get_replica(sess_id) + + # Create a ServiceRequest object to queue + request = ServiceRequest( + session_id=sess_id, + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + + # Queue the request using replica's method + await replica.enqueue_request(request) + + # Wait for the result + try: + return await request.future + except Exception as e: + # If the replica failed, try to retry once + if not replica.healthy: + logger.debug( + f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" + ) + return await self._retry_request_on_healthy_replica( + sess_id, function, *args, **kwargs + ) + raise + + async def call_all(self, function: str, *args, **kwargs) -> List: + """ + Broadcasts a function call to all healthy replicas and returns results as a list. + + Args: + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint + + Returns: + List of results from all healthy replicas + + Raises: + RuntimeError: If no healthy replicas are available + """ + healthy_replicas = [r for r in self._replicas if r.healthy] + + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for broadcast call") + + # Create requests for all healthy replicas + requests = [] + for replica in healthy_replicas: + request = ServiceRequest( + session_id=None, # Broadcast calls don't use sessions + function=function, + args=args, + kwargs=kwargs, + future=asyncio.Future(), + ) + requests.append((replica, request)) + + # Enqueue all requests + for replica, request in requests: + await replica.enqueue_request(request) + + # Wait for all results + results = [] + for replica, request in requests: + try: + result = await request.future + results.append(result) + except Exception as e: + logger.warning( + f"Request to replica {replica.idx} failed during broadcast: {e}" + ) + # Add None for failed replicas to maintain indexing + results.append(None) + + return results + + async def _retry_request_on_healthy_replica( + self, sess_id: str | None, function: str, *args, **kwargs + ): + """Retries a failed request on a healthy replica.""" + # Force reassignment to a healthy replica (only for session-based calls) + if sess_id is not None and sess_id in self._session_replica_map: + del self._session_replica_map[sess_id] + + # Retry the call (this will assign to a new healthy replica) + return await self._call(sess_id, function, *args, **kwargs) + + async def _migrate_remaining_requests(self, failed_replica: Replica): + """Migrates remaining requests from a failed replica to healthy replicas.""" + migrated_requests = [] + + # Collect all remaining requests + while not failed_replica.request_queue.empty(): + try: + request = failed_replica.request_queue.get_nowait() + migrated_requests.append(request) + except asyncio.QueueEmpty: + break + + if not migrated_requests: + return + + logger.debug( + "Migrating %d requests from failed replica %d", + len(migrated_requests), + failed_replica.idx, + ) + + # Find healthy replicas + healthy_replicas = [ + r for r in self._replicas if r.healthy and r != failed_replica + ] + + if not healthy_replicas: + # No healthy replicas, fail all requests + for request in migrated_requests: + request.future.set_exception( + RuntimeError("No healthy replicas available") + ) + return + + # Distribute requests among healthy replicas + for i, request in enumerate(migrated_requests): + target_replica = healthy_replicas[i % len(healthy_replicas)] + await target_replica.enqueue_request(request) + + # Update session mapping if needed + sess_id = request.session_id + if ( + sess_id in self._session_replica_map + and self._session_replica_map[sess_id] == failed_replica.idx + ): + self._session_replica_map[sess_id] = target_replica.idx + + async def start_session(self) -> str: + """ + Starts a new session for stateful request handling. + + Sessions enable request affinity to specific replicas, maintaining state + consistency for workloads that require it. Each session gets a unique ID + and is automatically assigned to the least loaded replica. + + Returns: + str: Unique session identifier for use in subsequent requests + + Example: + >>> session_id = await service.start_session() + >>> result = await service.my_endpoint(session_id, arg1, arg2) + >>> await service.terminate_session(session_id) + """ + sess_id = str(uuid.uuid4()) + session = Session(session_id=sess_id) + self._active_sessions.append(session) + + # Update metrics + self._update_service_metrics() + + return sess_id + + def _update_service_metrics(self): + """Updates service-level metrics.""" + self._metrics.total_sessions = len(self._active_sessions) + self._metrics.total_replicas = len(self._replicas) + self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + # Store direct references to replica metrics for aggregation + self._metrics.replica_metrics = {} + for replica in self._replicas: + # Use the replica's own metrics directly + self._metrics.replica_metrics[replica.idx] = replica.metrics + + def get_metrics(self) -> ServiceMetrics: + """ + Get comprehensive service metrics for monitoring and analysis. + + Returns detailed metrics including per-replica performance data, + service-wide aggregations, and health status information. + + Returns: + ServiceMetrics: Complete metrics object with replica and service data + + Example: + >>> metrics = service.get_metrics() + >>> print(f"Request rate: {metrics.get_total_request_rate():.1f} req/s") + >>> print(f"Queue depth: {metrics.get_avg_queue_depth():.1f}") + """ + self._update_service_metrics() + return self._metrics + + def get_metrics_summary(self) -> dict: + """ + Get a summary of key metrics for monitoring and debugging. + + Provides a structured summary of service and replica metrics in a format + suitable for monitoring dashboards, logging, or debugging purposes. + + Returns: + dict: Structured metrics summary with service and per-replica data + + Example: + >>> summary = service.get_metrics_summary() + >>> print(f"Healthy replicas: {summary['service']['healthy_replicas']}") + >>> for idx, metrics in summary['replicas'].items(): + ... print(f"Replica {idx}: {metrics['request_rate']:.1f} req/s") + """ + self._update_service_metrics() + + summary = { + "service": { + "total_sessions": self._metrics.total_sessions, + "healthy_replicas": self._metrics.healthy_replicas, + "total_replicas": self._metrics.total_replicas, + "total_request_rate": self._metrics.get_total_request_rate(), + "avg_queue_depth": self._metrics.get_avg_queue_depth(self._replicas), + "avg_capacity_utilization": self._metrics.get_avg_capacity_utilization( + self._replicas + ), + "sessions_per_replica": self._metrics.get_sessions_per_replica(), + }, + "replicas": {}, + } + + for replica in self._replicas: + metrics = replica.metrics + + # Count sessions assigned to this replica + assigned_sessions = sum( + 1 + for replica_idx in self._session_replica_map.values() + if replica_idx == replica.idx + ) + + summary["replicas"][replica.idx] = { + "total_requests": metrics.total_requests, + "successful_requests": metrics.successful_requests, + "failed_requests": metrics.failed_requests, + "request_rate": metrics.get_request_rate(), + "avg_latency": metrics.get_avg_latency(), + "active_requests": replica.active_requests, # Get from replica + "queue_depth": replica.qsize(), + "assigned_sessions": assigned_sessions, # Calculate from session map + "capacity_utilization": replica.capacity_utilization, # Get from replica + } + + return summary + + async def terminate_session(self, sess_id: str): + """ + Terminates an active session and cleans up associated resources. + + Removes the session from active tracking, clears replica assignments, + and updates service metrics. Sessions should be terminated when no + longer needed to free up resources. + + Args: + sess_id: The unique session identifier to terminate + + Example: + >>> session_id = await service.start_session() + >>> # ... use session for requests ... + >>> await service.terminate_session(session_id) + """ + logger.debug("Terminating session %s", sess_id) + + # Remove from active sessions + self._active_sessions = [ + s for s in self._active_sessions if s.session_id != sess_id + ] + + # Remove from session-replica mapping + if sess_id in self._session_replica_map: + del self._session_replica_map[sess_id] + + # Update metrics + self._update_service_metrics() + + async def _health_loop(self, poll_rate_s: float): + """Runs the health loop to monitor and recover replicas. + + This loop continuously checks the health of replicas and recovers + failed replicas by reinitializing their proc_meshes. It also + periodically updates service metrics to reflect the current state. + + """ + while not self._shutdown_requested: + # Process any replicas that need recovery + await self._recover_replicas() + + # Check for failed replicas and recover them + failed_replicas = [] + for replica in self._replicas: + if replica.failed: + failed_replicas.append(replica) + + if any(failed_replicas): + logger.debug( + "[HEALTH LOOP] Detected %d failed replicas: %s", + len(failed_replicas), + pprint.pformat(failed_replicas), + ) + self._replicas_to_recover.extend(failed_replicas) + + await asyncio.sleep(poll_rate_s) + + def _get_next_replica(self) -> "Replica": + """Get the next replica using round-robin selection.""" + healthy_replicas = [r for r in self._replicas if r.healthy] + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for load balancing") + + # Simple round-robin + self._next_replica_idx = (self._next_replica_idx + 1) % len(healthy_replicas) + return healthy_replicas[self._next_replica_idx] + + def _get_least_loaded_replica(self) -> "Replica": + """Get the replica with the lowest load.""" + healthy_replicas = [r for r in self._replicas if r.healthy] + if not healthy_replicas: + raise RuntimeError("No healthy replicas available for session assignment") + + # Use the replica's current_load property + return min(healthy_replicas, key=lambda replica: replica.current_load) + + async def _get_replica(self, sess_id: str | None) -> "Replica": + """Get a replica for the given session ID.""" + if sess_id is None: + # No session, use round-robin load balancing + replica = self._get_next_replica() + return replica + + # Session-based routing + if sess_id in self._session_replica_map: + replica_idx = self._session_replica_map[sess_id] + # Find the replica with this index + for replica in self._replicas: + if replica.idx == replica_idx and replica.healthy: + return replica + # If the replica is no longer healthy, remove from session map and reassign + del self._session_replica_map[sess_id] + + # New session, assign to least loaded replica + replica = self._get_least_loaded_replica() + self._session_replica_map[sess_id] = replica.idx + logger.debug("Assigning session %s to replica %d", sess_id, replica.idx) + return replica + + async def stop(self): + logger.debug("Stopping service...") + # Signal shutdown to health loop + self._shutdown_requested = True + + # Wait for health loop to finish gracefully + if self._health_task is not None: + try: + await asyncio.wait_for(self._health_task, timeout=5.0) + logger.info("Health loop stopped gracefully.") + except asyncio.TimeoutError: + logger.warning("Health loop didn't stop gracefully, cancelling...") + self._health_task.cancel() + try: + await self._health_task + except asyncio.CancelledError: + logger.info("Health loop task cancelled.") + + # Stop all replicas using their stop method + await asyncio.gather( + *[replica.stop() for replica in self._replicas], + return_exceptions=True, + ) + + async def _recover_replicas(self): + """Recovers unhealthy queued replicas.""" + if not self._replicas_to_recover: + return + + logger.debug( + "Recovering replicas: %s", pprint.pformat(self._replicas_to_recover) + ) + + async def _recover(replica): + """Recover a single replica.""" + try: + await replica.recover() + logger.debug("Successfully recovered replica %d", replica.idx) + except Exception as e: + logger.error("Failed to recover replica %d: %s", replica.idx, e) + replica.mark_failed() + + recovery_tasks = [ + asyncio.create_task(_recover(replica)) + for replica in self._replicas_to_recover + ] + + await asyncio.gather(*recovery_tasks, return_exceptions=True) + self._replicas_to_recover.clear() + + async def _migrate_replica_workload(self, replica_to_remove: Replica): + """Migrates all workload from a replica that's being removed.""" + # Migrate queued requests + await self._migrate_remaining_requests(replica_to_remove) + + # Reassign sessions to other replicas + sessions_to_reassign = [ + sess_id + for sess_id, replica_idx in self._session_replica_map.items() + if replica_idx == replica_to_remove.idx + ] + + for sess_id in sessions_to_reassign: + del self._session_replica_map[sess_id] + logger.debug("Session %s will be reassigned on next request", sess_id) + + async def _get_internal_state(self) -> dict: + """ + Gets comprehensive internal state for testing purposes. + + This is intended for testing/debugging only, it should not + be relied upon in actual production code. + """ + # Ensure metrics are up to date + self._update_service_metrics() + + return { + # Session management state + "session_replica_map": dict(self._session_replica_map), # Copy for safety + "active_sessions": [s.session_id for s in self._active_sessions], + "id_session_map": dict(self._id_session_map), # Copy for safety + # Replica state + "replicas": [ + { + "idx": replica.idx, + "state": replica.state.value, + "healthy": replica.healthy, + "failed": replica.failed, + "active_requests": replica.active_requests, + "queue_size": replica.request_queue.qsize(), + "capacity_utilization": replica.capacity_utilization, + } + for replica in self._replicas + ], + # Load balancing state + "next_replica_idx": self._next_replica_idx, + # Service-level state + "total_replicas": len(self._replicas), + "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), + "shutdown_requested": self._shutdown_requested, + # Metrics summary + "total_sessions": len(self._active_sessions), + "replica_count": len(self._replicas), + } + + def __repr__(self): + return f"Service(actor={self._actor_def.__name__})" + + +class ServiceActor(Actor): """ Distributed Actor Service Controller @@ -72,6 +628,7 @@ class Service(Actor): """ def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict): + self._service = Service() self._cfg = cfg self._replicas = [] self._actor_def = actor_def diff --git a/src/forge/controller/service/spawn.py b/src/forge/controller/service/spawn.py index b8bd3b5bc..554f1df0e 100644 --- a/src/forge/controller/service/spawn.py +++ b/src/forge/controller/service/spawn.py @@ -11,9 +11,9 @@ from monarch.actor import proc_mesh from forge.controller import ForgeActor -from forge.controller.service import Service, ServiceConfig +from forge.controller.service import Service, ServiceActor, ServiceConfig -from forge.controller.service.interface import ServiceInterface +from forge.controller.service.interface import ServiceInterface, ServiceInterfaceV2 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -24,6 +24,42 @@ async def spawn_service( ) -> ServiceInterface: """Spawns a service based on the actor class. + Args: + service_cfg: Service configuration + actor_def: Actor class definition + **actor_kwargs: Keyword arguments to pass to actor constructor + + Returns: + A ServiceInterface that provides access to the Service Actor + """ + # Assert that actor_def is a subclass of ForgeActor + if not issubclass(actor_def, ForgeActor): + raise TypeError( + f"actor_def must be a subclass of ForgeActor, got {type(actor_def).__name__}" + ) + + # Create a single-node proc_mesh and actor_mesh for the Service Actor + logger.info("Spawning Service Actor for %s", actor_def.__name__) + service = Service(service_cfg, actor_def, actor_kwargs) + await service.__initialize__() + # Return the ServiceInterface that wraps the proc_mesh, actor_mesh, and actor_def + return ServiceInterface(service, actor_def) + + +async def shutdown_service(service: ServiceInterface) -> None: + """Shuts down the service. + + Implemented in this way to avoid actors overriding stop() unintentionally. + + """ + await service._service.stop() + + +async def spawn_service_v2( + service_cfg: ServiceConfig, actor_def: Type[ForgeActor], **actor_kwargs +) -> ServiceInterfaceV2: + """Spawns a service based on the actor class. + Args: service_cfg: Service configuration actor_def: Actor class definition @@ -42,15 +78,15 @@ async def spawn_service( logger.info("Spawning Service Actor for %s", actor_def.__name__) m = await proc_mesh(gpus=1) service_actor = await m.spawn( - "service", Service, service_cfg, actor_def, actor_kwargs + "service", ServiceActor, service_cfg, actor_def, actor_kwargs ) await service_actor.__initialize__.call_one() # Return the ServiceInterface that wraps the proc_mesh, actor_mesh, and actor_def - return ServiceInterface(m, service_actor, actor_def) + return ServiceInterfaceV2(m, service_actor, actor_def) -async def shutdown_service(service: ServiceInterface) -> None: +async def shutdown_service_v2(service: ServiceInterfaceV2) -> None: """Shuts down the service. Implemented in this way to avoid actors overriding stop() unintentionally.