diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index f0d8fca7b..d745f1556 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from .endpoint import service_endpoint, ServiceEndpointProperty from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState -from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter +from .router import Batcher, LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ @@ -24,4 +25,8 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", + "service_endpoint", + "ServiceEndpointProperty", + "Router", + "Batcher", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py new file mode 100644 index 000000000..c731017c6 --- /dev/null +++ b/src/forge/controller/service/endpoint.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Service endpoint management for the Forge framework. +""" + +from typing import Any, Callable, Generic, List, TypeVar + +from monarch._src.actor.endpoint import EndpointProperty + +from typing_extensions import ParamSpec + +from .router import RoundRobinRouter, Router + +P = ParamSpec("P") +R = TypeVar("R") +Propagator = Any + + +class ServiceEndpoint(Generic[P, R]): + """ + This extends Monarch's actor APIs for service endpoints. + - `route(*args, **kwargs)`: Routes the request to a single replica. 
+ - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. + + Monarch's native actor APIs do not apply for services. + """ + + def __init__( + self, + service, + endpoint_name: str, + ): + self.service = service + self.endpoint_name = endpoint_name + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.service._route(sess_id, self.endpoint_name, *args, **kwargs) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.service._fanout(self.endpoint_name, *args, **kwargs) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class ServiceEndpointV2(Generic[P, R]): + """An endpoint object specific to services. 
+ + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only route and fanout are supported; other Monarch-style calls raise. + - fanout returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + + def __init__(self, actor_mesh, endpoint_name: str): + self.actor_mesh = actor_mesh + self.endpoint_name = endpoint_name + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.actor_mesh.call.call_one( + sess_id, self.endpoint_name, *args, **kwargs + ) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.actor_mesh.call_all.call_one( + self.endpoint_name, *args, **kwargs + ) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()."
+ ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class ServiceEndpointProperty(EndpointProperty, Generic[P, R]): + """ + Extension of EndpointProperty that carries service-specific + routing and batching configuration. + + Inherits from EndpointProperty so the method is still registered as + a valid actor endpoint, while also attaching service-specific options + (router, batch_size, batch_timeout). + """ + + def __init__( + self, + method: Any, + propagator: Propagator, + explicit_response_port: bool, + *, + router: Callable[[], Router] = RoundRobinRouter, + batch_size: int = 1, + batch_timeout: float = 0.01, + ) -> None: + super().__init__(method, propagator, explicit_response_port) + self.router = router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + + +def service_endpoint( + *, + router: Callable[[], Router] = RoundRobinRouter, + batch_size: int = 1, + batch_timeout: float = 0.01, + propagate=None, + explicit_response_port=False, +): + """ + Marks an actor method as a service endpoint with routing and batching support. + + Example: + class MyForgeActor(ForgeActor): + @service_endpoint(router=RoundRobinRouter, batch_size=16, batch_timeout=0.05) + async def predict(self, x): ...
+ """ + + def decorator(method) -> ServiceEndpointProperty: + return ServiceEndpointProperty( + method, + propagator=propagate, + explicit_response_port=explicit_response_port, + router=router, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + + return decorator diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 5b7e2f884..8e4a0c461 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -11,20 +11,11 @@ """ import contextvars -import logging -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty -from .replica import Replica - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -P = ParamSpec("P") -R = TypeVar("R") +from .endpoint import ServiceEndpoint, ServiceEndpointProperty, ServiceEndpointV2 @dataclass @@ -77,94 +68,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.session_id = None -class ServiceEndpoint(Generic[P, R]): - """ - This extends Monarch's actor APIs for service endpoints. - - `route(*args, **kwargs)`: Routes the request to a single replica. - - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. - - Monarch's native actor APIs do not apply for services. 
- """ - - def __init__(self, service, endpoint_name: str): - self.service = service - self.endpoint_name = endpoint_name - - async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) - - async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.service.call_all(self.endpoint_name, *args, **kwargs) - return result - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use choose() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use call() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use a call_one() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use broadcast() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def generate(self, *args: P.args, **kwargs: P.kwargs): - raise NotImplementedError( - "You tried to use generate() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - -class ServiceEndpointV2(Generic[P, R]): - """An endpoint object specific to services. 
- - This loosely mimics the Endpoint APIs exposed in Monarch, with - a few key differences: - - Only choose and call are retained (dropping stream and call_one) - - Call returns a list directly rather than a ValueMesh. - - These changes are made with Forge use cases in mind, but can - certainly be expanded/adapted in the future. - - """ - - def __init__(self, actor_mesh, endpoint_name: str): - self.actor_mesh = actor_mesh - self.endpoint_name = endpoint_name - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.actor_mesh.call.call_one( - sess_id, self.endpoint_name, *args, **kwargs - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.actor_mesh.call_all.call_one( - self.endpoint_name, *args, **kwargs - ) - return result - - class ServiceInterface: """ A lightweight interface to the base Service class. 
@@ -182,10 +85,15 @@ def __init__(self, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpoint(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpoint(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -260,10 +168,15 @@ def __init__(self, _proc_mesh, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpointV2(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpointV2(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -306,17 +219,3 @@ def __getattr__(self, name: str): raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) - - -class Router(ABC): - """Abstract 
base class for routing logic.""" - - @abstractmethod - def get_replica( - self, - healthy_replicas: List[Replica], - sess_id: str | None = None, - session_map: Dict[str, int] | None = None, - ) -> Replica: - """Select a replica from the list based on routing logic.""" - pass diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index ae69d3df5..81c6de0e1 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -220,6 +220,11 @@ async def enqueue_request(self, request: ServiceRequest): # Accept requests in all other states - let the processing loop handle the rest await self.request_queue.put(request) + async def enqueue_batch(self, requests: list[ServiceRequest]): + """Enqueues a batch of requests for processing by this replica.""" + for req in requests: + await self.enqueue_request(req) + async def _process_single_request(self, request: ServiceRequest) -> bool: """Processes a single request and returns success status. diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 502402e36..11311f93f 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -4,16 +4,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ +import asyncio import logging -from typing import Dict, List +import time +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List -from .interface import Router -from .replica import Replica +from forge.controller.service.replica import Replica, ServiceRequest logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +class Router(ABC): + """Abstract base class for routing logic.""" + + @abstractmethod + def get_replica( + self, + healthy_replicas: List[Replica], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, + ) -> Replica: + """Select a replica from the list based on routing logic.""" + pass + + class RoundRobinRouter(Router): """Round-robin router for stateless requests.""" @@ -88,3 +105,128 @@ def get_replica( replica.idx, ) return replica + + +class Batcher: + """ + Asynchronous batching wrapper around a Router. + + Instead of selecting a replica immediately, incoming requests are enqueued + and grouped into batches. Once a batch is ready (either reaching the maximum + size or exceeding the maximum wait time), the batcher makes a single routing + decision using the inner router. All requests in that batch are then sent + to the same replica. + + This reduces router overhead by amortizing multiple requests into one decision. + + Args: + inner_router: The underlying Router used to pick a replica. + get_healthy_replicas: Callable that returns the current list of healthy replicas. + get_session_map: Callable that returns the session-to-replica mapping. + batch_size: Maximum number of requests to collect in a single batch + before routing (default: 16). + batch_timeout: Maximum time to wait (in seconds) before routing a batch, + even if batch_size is not reached (default: 0.01).
+ + Example: + rr_router = RoundRobinRouter() + batcher = Batcher( + rr_router, + get_healthy_replicas=service._get_healthy_replicas, + get_session_map=service._get_session_map, + batch_size=16, + batch_timeout=0.01, + ) + + request = ServiceRequest(...) + + # Enqueue a request and wait for the result of its future + await batcher.enqueue(request) + """ + + def __init__( + self, + inner_router: Router, + get_healthy_replicas: Callable[[], List["Replica"]], + get_session_map: Callable[[], Dict[str, int]], + batch_size: int = 16, + batch_timeout: float = 0.01, + ): + + self.inner_router = inner_router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self.get_healthy_replicas = get_healthy_replicas + self.get_session_map = get_session_map + + # Internal queue for batching routing requests + self._queue: asyncio.Queue = asyncio.Queue() + self._running = True # flag to control loop + # Background task that processes batches continuously + self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_size or batch_timeout is reached + 3. Make a single routing decision for the entire batch + 4. Enqueue the whole batch on that replica, failing its futures on error + + This process repeats indefinitely until the task is cancelled.
+ """ + while self._running: + + # Wait for first request + batch = [await self._queue.get()] + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self.batch_timeout - (time.monotonic() - start_time) + ) + req = await asyncio.wait_for( + self._queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch.append(req) + + if len(batch) >= self.batch_size: + break + except asyncio.TimeoutError: + break + + # Route and dispatch in one try block: a routing failure (e.g. no healthy + # replica) must fail the batch too, or its callers hang and this loop dies. + try: + session_map = self.get_session_map() + healthy_replicas = self.get_healthy_replicas() + # One routing decision for the whole batch + replica = self.inner_router.get_replica(healthy_replicas, None, session_map) + await replica.enqueue_batch(batch) + except Exception as e: + for req in batch: + if not req.future.done(): + req.future.set_exception(e) + + async def enqueue(self, request: ServiceRequest) -> Any: + """Enqueue the request for batching and await the result of its future.""" + # Queue the request for batching - this is non-blocking + self._queue.put_nowait(request) + + # Wait for the batch processor to resolve our future + return await request.future + + async def stop(self): + """Stop the batch loop gracefully.""" + self._running = False + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 0b655fb6a..6c182b5ae 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -40,14 +40,16 @@ from monarch.actor import Actor, endpoint +from forge.controller.service.endpoint import ServiceEndpointProperty + from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica 
import Replica, ServiceRequest - from forge.controller.service.router import ( - LeastLoadedRouter, + Batcher, RoundRobinRouter, + Router, SessionRouter, )
from forge.types import ServiceConfig @@ -108,7 +110,10 @@ async def __initialize__(self): # Initialize the routers self._default_router = RoundRobinRouter() - self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) + self._session_router = SessionRouter(fallback_router=self._default_router) + + # This keeps the map between the registered endpoints and the routers + self.routers: dict[str, Router | Batcher] = {} # Initialize all replicas replicas = [] @@ -138,14 +143,66 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + def _set_router( + self, endpoint_name: str, prop: ServiceEndpointProperty | None = None + ) -> None: + """ + Ensure a router exists for the given endpoint. + + - If a router is already set, raise an error. + - If a ServiceEndpointProperty is provided, construct its router/batcher + using the specified configuration. + - If not provided, fall back to a default round-robin router. + + Args: + endpoint_name: Name of the endpoint. + prop: Optional ServiceEndpointProperty object with router, batch_size, + and batch_timeout attributes. 
+ """ + + # If router already exists, raise an exception + if endpoint_name in self.routers: + raise AssertionError(f"Router already exists for endpoint: {endpoint_name}") + + # If config is missing or incomplete, use default router + if prop is None or not isinstance(prop, ServiceEndpointProperty): + return + + # Resolve base router + if not callable(prop.router): + raise ValueError(f"Router must be callable, got: {prop.router}") + else: + base_router = prop.router() # Call the router constructor + batch_size = prop.batch_size + batch_timeout = prop.batch_timeout + + if not isinstance(base_router, Router): + raise ValueError( + f"Router must be a Router instance, got: {type(base_router)}" + ) + + # Wrap in Batcher if batching requested + if batch_size > 1: + router = Batcher( + base_router, + get_healthy_replicas=self._get_healthy_replicas, + get_session_map=self._get_session_map, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + else: + router = base_router + + # Store and return + self.routers[endpoint_name] = router + + async def _route(self, sess_id: str | None, function: str, *args, **kwargs): """ Routes a function call to the appropriate replica with load balancing and fault tolerance. This is the core routing method that handles: - Session-based routing for stateful calls - - Round-robin load balancing for stateless calls - - Custom routing based on context hints + - Per-endpoint router selection (round-robin, least-loaded, batching, etc.) 
- Automatic retry on replica failures - Request queuing and processing @@ -168,7 +225,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): if ctx: sess_id = ctx["session_id"] - replica = await self._get_replica(sess_id) + router = self.routers.get(function, self._default_router) # Create a ServiceRequest object to queue request = ServiceRequest( @@ -179,8 +236,27 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): future=asyncio.Future(), ) - # Queue the request using replica's method - await replica.enqueue_request(request) + # Case: batching is enabled and no session ID (stateless calls only) + # Forward the request into the Batcher queue. The batcher will send + # the batched requests to the selected replica. + if sess_id is None and isinstance(router, Batcher): + await router.enqueue(request) + + else: + healthy_replicas = [r for r in self._replicas if r.healthy] + + # Select replica: + if sess_id is not None: + # Case 1: sticky sessions + replica = self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + else: + # Case 2: stateless routing via this endpoint's router (default if unset) + replica = router.get_replica(healthy_replicas) + + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: @@ -196,7 +272,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - async def call_all(self, function: str, *args, **kwargs) -> List: + async def _fanout(self, function: str, *args, **kwargs) -> List: """ Broadcasts a function call to all healthy replicas and returns results as a list.
@@ -211,7 +287,7 @@ async def call_all(self, function: str, *args, **kwargs) -> List: Raises: RuntimeError: If no healthy replicas are available """ - healthy_replicas = [r for r in self._replicas if r.healthy] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: raise RuntimeError("No healthy replicas available for broadcast call") @@ -256,7 +332,7 @@ async def _retry_request_on_healthy_replica( del self._session_replica_map[sess_id] # Retry the call (this will assign to a new healthy replica) - return await self._call(sess_id, function, *args, **kwargs) + return await self._route(sess_id, function, *args, **kwargs) async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" @@ -280,9 +356,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): ) # Find healthy replicas - healthy_replicas = [ - r for r in self._replicas if r.healthy and r != failed_replica - ] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: # No healthy replicas, fail all requests @@ -334,7 +408,7 @@ def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + self._metrics.healthy_replicas = len(self._get_healthy_replicas()) # Store direct references to replica metrics for aggregation self._metrics.replica_metrics = {} for replica in self._replicas: @@ -446,6 +520,13 @@ async def terminate_session(self, sess_id: str): # Update metrics self._update_service_metrics() + def _get_healthy_replicas(self) -> list[Replica]: + """Returns a list of healthy replicas.""" + return [r for r in self._replicas if r.healthy] + + def _get_session_map(self) -> Dict[str, int]: + return self._session_replica_map + async def _health_loop(self, poll_rate_s: float): 
"""Runs the health loop to monitor and recover replicas. @@ -474,17 +555,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _get_replica(self, sess_id: str | None) -> "Replica": - """Get a replica for the given session ID.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if sess_id is None: - # No session, use the default router - return self._default_router.get_replica(healthy_replicas) - - return self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop @@ -503,6 +573,15 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all batchers in routers + # Stop all batchers + batchers = [ + router for router in self.routers.values() if isinstance(router, Batcher) + ] + if batchers: + await asyncio.gather(*(b.stop() for b in batchers), return_exceptions=True) + logger.info("All batcher loop(s) stopped gracefully.") + # Stop all replicas using their stop method await asyncio.gather( *[replica.stop() for replica in self._replicas], @@ -582,7 +661,7 @@ async def _get_internal_state(self) -> dict: # Load balancing state # Service-level state "total_replicas": len(self._replicas), - "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), + "healthy_replica_count": len(self._get_healthy_replicas()), "shutdown_requested": self._shutdown_requested, # Metrics summary "total_sessions": len(self._active_sessions), diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py new file mode 100644 index 000000000..4ce497e77 --- /dev/null +++ b/tests/unit_tests/test_router.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+""" +Tests for router.py +""" + +import asyncio +import logging + +import pytest +from forge.controller import ForgeActor +from forge.controller.service import ( + Batcher, + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + service_endpoint, + SessionRouter, +) + +from forge.types import ProcessConfig +from monarch.actor import endpoint + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Counter(ForgeActor): + """Test actor that maintains a counter with various endpoints.""" + + def __init__(self, v: int): + self.v = v + + @endpoint + async def value(self) -> int: + """Get the current counter value.""" + return self.v + + @endpoint + async def fail_me(self): + """Endpoint that always fails to test error handling.""" + raise RuntimeError("I was asked to fail") + + @endpoint + async def add_to_value(self, amount: int, multiplier: int = 1) -> int: + """Add an amount (optionally multiplied) to the current value.""" + logger.info(f"adding {amount} with {multiplier}") + self.v += amount * multiplier + return self.v + + @endpoint + async def incr(self): + """Increment the counter.""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter, batch_size=3, batch_timeout=1) + async def rr_batch_incr_bsize3(self): + """Increment the round-robin counter with batching (batch size = 3).""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) + async def rr_batch_incr_bsize5(self): + """Increment the round-robin counter with batching (batch size = 5).""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter) + async def rr_batch_incr_bsize1(self): + """Increment the round-robin counter with default batch_size=1.""" + self.v += 1 + + +def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: + """Helper to build a replica with specified state and load.""" + replica = Replica( + idx=idx, + proc_config=ProcessConfig(), + actor_def=Counter, + actor_args=(), + 
actor_kwargs={}, + ) + replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY + replica.active_requests = load + return replica + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_as_actor_preserves_normal_usage(): + """Ensure that using `as_actor` does not break normal semantics.""" + service = await Counter.as_actor(5) + + try: + assert await service.value.choose() == 5 + + # Test increment + await service.rr_batch_incr_bsize3.choose() + assert await service.value.choose() == 6 + + finally: + await Counter.shutdown(service) + + +@pytest.mark.asyncio +async def test_service_endpoint_router_and_configurations(): + """ + Verify service endpoints are registered with correct router/batching configuration: + - rr_batch_incr_bsize1: plain RoundRobinRouter, no batching (batch_size=1, timeout=0.01) + - rr_batch_incr_bsize3: Batcher wrapping RoundRobinRouter (batch_size=3, timeout=1) + - incr: plain @endpoint, should not appear in service.routers + """ + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # --- rr_batch_incr_bsize1 --- + router1 = service.routers.get("rr_batch_incr_bsize1") + assert isinstance( + router1, RoundRobinRouter + ), f"Expected RoundRobinRouter, got {type(router1)}" + + prop1 = Counter.rr_batch_incr_bsize1 + assert prop1.batch_size == 1 + assert prop1.batch_timeout == 0.01 + + # --- rr_batch_incr_bsize3 --- + router3 = service.routers.get("rr_batch_incr_bsize3") + + assert isinstance(router3, Batcher), f"Expected Batcher, got {type(router3)}" + assert router3.batch_size == 3 + assert router3.batch_timeout == 1 + + # --- incr --- + assert ( + "incr" not in service.routers + ), "Plain @endpoint should not be in service.routers" + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_noncallable(): + """@service_endpoint with non-callable router should raise ValueError.""" + + class BadActor(ForgeActor): + 
@service_endpoint(router="roundrobin") # string, not callable + async def bad_endpoint(self): + return 42 + + with pytest.raises(ValueError, match="Router must be callable"): + # Triggers ServiceInterface._set_router during construction + await BadActor.options(num_replicas=1).as_service() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_wrong_return_type(): + """@service_endpoint with callable that doesn't return Router should raise ValueError.""" + + class NotARouter: + """Dummy class that is not a Router.""" + + class BadActor(ForgeActor): + @service_endpoint(router=NotARouter) # returns NotARouter + async def bad_endpoint(self): + return 123 + + with pytest.raises(ValueError, match="Router must be a Router instance"): + await BadActor.options(num_replicas=1).as_service() + + +# Router Tests + + +@pytest.mark.asyncio +async def test_session_router_fallback_rr_vs_ll(): + """Compare SessionRouter assignment with round-robin vs. least-loaded fallback routers.""" + # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = RoundRobinRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx != r2.idx + assert set(session_map.values()) == {0, 1} + + # With LeastLoadedRouter as fallback, r1 and r2 should be assigned to the same replica + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx == r2.idx == 0 + + +# Router integration tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def 
test_round_robin_router_distribution(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using route() + results = [] + for _ in range(6): + await service.incr.route() + values = await service.value.fanout() + results.append(values) + # Verify that requests were distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # 2 increments per replica (since 3 replicas, 6 calls) + final_values = results[-1] # last snapshot + assert sorted(final_values) == [2, 2, 2] + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution_with_batching(): + """Test that the RoundRobinRouter distributes batches of sessionless calls across replicas in round-robin order.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using route() + results = [] + tasks = [service.rr_batch_incr_bsize3.route() for _ in range(6)] + await asyncio.gather(*tasks) + # Verify that batches (not individual calls) were distributed round-robin + # With batch_size=3, the 6 concurrent calls are expected to form two batches of 3: + # each batch lands on one replica, so two replicas get +3 and the third stays at 0 + values = await service.value.fanout() + assert sorted(values) == [0, 3, 3] + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_session_router_assigns_and_updates_session_map_in_service(): + """Integration: Service with SessionRouter preserves sticky sessions.""" + service = await Counter.options( + procs=1, + num_replicas=2, + ).as_service(v=0) + + try: + # First call with sess_id -> assign a replica + await service.incr.route(sess_id="sess1") + values1 = await service.value.fanout() + + # Second call with same sess_id -> must hit same replica + await 
service.incr.route(sess_id="sess1") + values2 = await service.value.fanout() + + # Difference should only be on one replica (sticky session) + diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] + assert ( + sum(diffs) == 1 + ), f"Expected exactly one replica to increment, got {diffs}" + assert max(diffs) == 1 and min(diffs) == 0 + + # Session map in service should reflect assigned replica + assigned_idx = service._session_replica_map["sess1"] + assert values2[assigned_idx] == values1[assigned_idx] + 1 + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_independent_batchers_and_routers_per_endpoint(): + """Ensure multiple @service_endpoint endpoints coexist with independent routers/batchers.""" + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # --- First batch: rr_batch_incr_bsize3 (batch_size = 3) --- + tasks = [ + asyncio.create_task(service.rr_batch_incr_bsize3.route()) for _ in range(4) + ] + await asyncio.gather(*tasks) + + values = await service.value.fanout() + + # Expectation: + # - First 3 requests form one batch → sent to replica R1 (+3). + # - Remaining 1 request forms its own batch → goes to replica R2 (+1). + # So totals should be [3, 1] (order depends on round robin). + assert sum(values) == 4, f"Expected total=4, got {values}" + assert sorted(values) == [1, 3], f"Expected [1, 3], got {values}" + + # --- Second batch: rr_batch_incr_bsize5 (batch_size = 5) --- + tasks = [ + asyncio.create_task(service.rr_batch_incr_bsize5.route()) for _ in range(7) + ] + await asyncio.gather(*tasks) + + values = await service.value.fanout() + + # Expectation (RoundRobin between replicas): + # Starting from previous state (R1=3, R2=1): + # - Next 5 requests form one batch → go to R1 (+5). + # - Remaining 2 requests form their own batch → go to R2 (+2). + # + # Final totals: + # R1 = 3 (previous) + 5 = 8 + # R2 = 1 (previous) + 2 = 3 + # So distribution should be [3, 8]. 
+ assert sum(values) == 11, f"Expected total=11, got {values}" + assert sorted(values) == [3, 8], f"Expected [3, 8], got {values}" + + finally: + await service.shutdown() diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 31a912542..64882d243 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -13,15 +13,8 @@ import pytest from forge.controller import ForgeActor -from forge.controller.service import ( - LeastLoadedRouter, - Replica, - ReplicaState, - RoundRobinRouter, - ServiceConfig, - SessionRouter, -) -from forge.types import ProcessConfig +from forge.controller.service import ServiceConfig + from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) @@ -63,20 +56,6 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: return self.v -def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: - """Helper to build a replica with specified state and load.""" - replica = Replica( - idx=idx, - proc_config=ProcessConfig(), - actor_def=Counter, - actor_args=(), - actor_kwargs={}, - ) - replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY - replica.active_requests = load - return replica - - # Actor Tests @@ -754,95 +733,3 @@ async def test_broadcast_fanout_vs_route(): finally: await service.shutdown() - - -# Router Tests - - -@pytest.mark.asyncio -async def test_session_router_with_round_robin_fallback(): - """Switch fallback router to round-robin and verify assignment order.""" - # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = RoundRobinRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx != r2.idx - assert 
set(session_map.values()) == {0, 1} - - # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = LeastLoadedRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx == r2.idx == 0 - - -# Router integeration tests - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_round_robin_router_distribution(): - """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" - service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) - - try: - # Make multiple sessionless calls using route() - results = [] - for _ in range(6): - await service.incr.route() - values = await service.value.fanout() - print(values) - results.append(values) - print("results: ", results) - # Verify that requests were distributed round-robin - # Each call increments a single replica, so after 6 calls we expect: - # 2 increments per replica (since 3 replicas, 6 calls) - final_values = results[-1] # last snapshot - assert sorted(final_values) == [2, 2, 2] - - finally: - await service.shutdown() - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_session_router_assigns_and_updates_session_map_in_service(): - """Integration: Service with SessionRouter preserves sticky sessions.""" - # Use LeastLoaded as default, SessionRouter (with fallback) is always active - service = await Counter.options( - procs=1, - num_replicas=2, - ).as_service(v=0) - - try: - # First call with sess_id -> assign a replica - await service.incr.route(sess_id="sess1") - values1 = await service.value.fanout() - - # Second call with same sess_id -> must hit same replica - await service.incr.route(sess_id="sess1") - values2 = await service.value.fanout() - - # Difference should only be 
on one replica (sticky session) - diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] - assert ( - sum(diffs) == 1 - ), f"Expected exactly one replica to increment, got {diffs}" - assert max(diffs) == 1 and min(diffs) == 0 - - # Session map in service should reflect assigned replica - assigned_idx = service._session_replica_map["sess1"] - assert values2[assigned_idx] == values1[assigned_idx] + 1 - - finally: - await service.shutdown()