diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index f0d8fca7b..d745f1556 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from .endpoint import service_endpoint, ServiceEndpointProperty from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState -from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter +from .router import Batcher, LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ @@ -24,4 +25,8 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", + "service_endpoint", + "ServiceEndpointProperty", + "Router", + "Batcher", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py new file mode 100644 index 000000000..ba29034b0 --- /dev/null +++ b/src/forge/controller/service/endpoint.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Service endpoint management for the Forge framework. +""" + +from typing import Any, Callable, Generic, List, TypeVar + +from monarch._src.actor.endpoint import EndpointProperty + +from typing_extensions import ParamSpec + +from .router import RoundRobinRouter, Router + +P = ParamSpec("P") +R = TypeVar("R") +Propagator = Any + + +class ServiceEndpoint(Generic[P, R]): + """ + This extends Monarch's actor APIs for service endpoints. + - `route(*args, **kwargs)`: Routes the request to a single replica. 
+ - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. + + Monarch's native actor APIs do not apply for services. + """ + + def __init__( + self, + service, + endpoint_name: str, + ): + self.service = service + self.endpoint_name = endpoint_name + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.service._route(sess_id, self.endpoint_name, *args, **kwargs) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.service._fanout(self.endpoint_name, *args, **kwargs) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class ServiceEndpointV2(Generic[P, R]): + """An endpoint object specific to services. 
+ + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only route and fanout are supported (choose, call, call_one, broadcast and generate raise NotImplementedError) + - fanout returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + + def __init__(self, actor_mesh, endpoint_name: str): + self.actor_mesh = actor_mesh + self.endpoint_name = endpoint_name + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.actor_mesh.call.call_one( + sess_id, self.endpoint_name, *args, **kwargs + ) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.actor_mesh.call_all.call_one( + self.endpoint_name, *args, **kwargs + ) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()."
+ ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class ServiceEndpointProperty(EndpointProperty, Generic[P, R]): + """ + Extension of EndpointProperty that carries service-specific + routing and batching configuration. + + Inherits from EndpointProperty so the method is still registered as + a valid actor endpoint, while also attaching service-specific options + (router, batch_size, batch_timeout). + """ + + def __init__( + self, + method: Any, + propagator: Propagator, + explicit_response_port: bool, + *, + router: Callable[[], Router] = RoundRobinRouter, + batch_size: int = 1, + batch_timeout: float = 0.01, + ) -> None: + super().__init__(method, propagator, explicit_response_port) + self.router = router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + + +def service_endpoint( + *, + router: Callable[[], Router] = RoundRobinRouter, + batch_size: int = 1, + batch_timeout: float = 0.01, + propagate=None, + explicit_response_port=False, +): + """ + Marks an actor method as a service endpoint with batching routing support. + + Example: + class MyForgeActor(ForgeActor): + @service_endpoint(router=RoundRobinRouter, batch_size=16, batch_timeout=0.05) + async def predict(self, x): ... 
+ """ + + def decorator(method) -> ServiceEndpointProperty: + return ServiceEndpointProperty( + method, + propagator=propagate, + explicit_response_port=explicit_response_port, + router=router, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + + return decorator diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 5b7e2f884..8e4a0c461 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -11,20 +11,11 @@ """ import contextvars -import logging -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty -from .replica import Replica - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -P = ParamSpec("P") -R = TypeVar("R") +from .endpoint import ServiceEndpoint, ServiceEndpointProperty, ServiceEndpointV2 @dataclass @@ -77,94 +68,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.session_id = None -class ServiceEndpoint(Generic[P, R]): - """ - This extends Monarch's actor APIs for service endpoints. - - `route(*args, **kwargs)`: Routes the request to a single replica. - - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. - - Monarch's native actor APIs do not apply for services. 
- """ - - def __init__(self, service, endpoint_name: str): - self.service = service - self.endpoint_name = endpoint_name - - async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) - - async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.service.call_all(self.endpoint_name, *args, **kwargs) - return result - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use choose() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use call() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use a call_one() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use broadcast() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def generate(self, *args: P.args, **kwargs: P.kwargs): - raise NotImplementedError( - "You tried to use generate() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - -class ServiceEndpointV2(Generic[P, R]): - """An endpoint object specific to services. 
- - This loosely mimics the Endpoint APIs exposed in Monarch, with - a few key differences: - - Only choose and call are retained (dropping stream and call_one) - - Call returns a list directly rather than a ValueMesh. - - These changes are made with Forge use cases in mind, but can - certainly be expanded/adapted in the future. - - """ - - def __init__(self, actor_mesh, endpoint_name: str): - self.actor_mesh = actor_mesh - self.endpoint_name = endpoint_name - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.actor_mesh.call.call_one( - sess_id, self.endpoint_name, *args, **kwargs - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.actor_mesh.call_all.call_one( - self.endpoint_name, *args, **kwargs - ) - return result - - class ServiceInterface: """ A lightweight interface to the base Service class. 
@@ -182,10 +85,15 @@ def __init__(self, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpoint(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpoint(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -260,10 +168,15 @@ def __init__(self, _proc_mesh, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpointV2(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpointV2(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -306,17 +219,3 @@ def __getattr__(self, name: str): raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) - - -class Router(ABC): - """Abstract 
base class for routing logic.""" - - @abstractmethod - def get_replica( - self, - healthy_replicas: List[Replica], - sess_id: str | None = None, - session_map: Dict[str, int] | None = None, - ) -> Replica: - """Select a replica from the list based on routing logic.""" - pass diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 502402e36..76456b0f1 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -4,16 +4,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +import asyncio import logging -from typing import Dict, List +import time +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List -from .interface import Router -from .replica import Replica +from forge.controller.service.replica import Replica, ServiceRequest logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +class Router(ABC): + """Abstract base class for routing logic.""" + + @abstractmethod + def get_replica( + self, + healthy_replicas: List[Replica], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, + ) -> Replica: + """Select a replica from the list based on routing logic.""" + pass + + class RoundRobinRouter(Router): """Round-robin router for stateless requests.""" @@ -88,3 +105,181 @@ def get_replica( replica.idx, ) return replica + + +class Batcher: + """ + Asynchronous batching wrapper around a Router. + + Instead of selecting a replica immediately, incoming requests are enqueued + and grouped into batches. Once a batch is ready (either reaching the maximum + size or exceeding the maximum wait time), the batcher makes a single routing + decision using the inner router. All requests in that batch are then resolved + with the same replica. + + This reduces router overhead by amortizing multiple requests into one decision. 
+ + Args: + inner_router: The underlying Router used to pick a replica. + get_healthy_replicas: Callable that returns the current list of healthy replicas. + get_session_map: Callable that returns the session-to-replica mapping. + batch_size: Maximum number of requests to collect in a single batch + before routing (default: 16). + batch_timeout: Maximum time to wait (in seconds) before routing a batch, + even if batch_size is not reached (default: 0.01). + + Example: + rr_router = RoundRobinRouter() + batcher = Batcher( + function, rr_router, + get_healthy_replicas=service._get_healthy_replicas, + get_session_map=service._get_session_map, + batch_size=16, + batch_timeout=0.01, + ) + + # Enqueue an endpoint call to be sent to a replica + results = await batcher.route(function, args, kwargs) + """ + + def __init__( + self, + function: str, + inner_router: Router, + get_healthy_replicas: Callable[[], List["Replica"]], + get_session_map: Callable[[], Dict[str, int]], + batch_size: int = 16, + batch_timeout: float = 0.01, + ): + self.function = function + self.inner_router = inner_router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self.get_healthy_replicas = get_healthy_replicas + self.get_session_map = get_session_map + + # Internal queue for batching routing requests + self._queue: asyncio.Queue = asyncio.Queue() + self._running = True # flag to control loop + # Background task that processes batches continuously + self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) + # Maximum number of routing attempts per batch (1 initial + 1 retry if replica fails) + self._num_attempts = 2 + + async def _batch_loop(self) -> None: + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task.
+ Each iteration collects individual (function, args, kwargs, future) entries from + the internal queue and merges them into a single `ServiceRequest` that is then + dispatched to one replica. + + The loop follows these steps: + 1. Wait for the first queued call to start a new batch + 2. Continue collecting until batch_size or batch_timeout is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + + Returns: + None + + Raises: + RuntimeError: If no healthy replicas are available + Exception: Any exception raised by the actor endpoint + """ + while self._running: + + # Wait for first request + batch = [await self._queue.get()] + start_time = time.monotonic() + + # TODO (dxie): consider making timeout adaptive based on replica load. + while True: + try: + timeout = max( + 0, self.batch_timeout - (time.monotonic() - start_time) + ) + nxt = await asyncio.wait_for( + self._queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch.append(nxt) + + if len(batch) >= self.batch_size: + break + except asyncio.TimeoutError: + break + + # Merge args for batched call + if batch and len(batch[0][1]) > 0: + # Normal case: endpoints expect positional arguments + merged_args = [list(items) for items in zip(*[b[1] for b in batch])] + args = tuple(merged_args) + else: + # No-arg case: just one batched call, no inputs to merge + args = () + + for attempt in range(self._num_attempts): + session_map = self.get_session_map() + healthy_replicas = self.get_healthy_replicas() + + # One routing decision for the whole batch + replica = self.inner_router.get_replica( + healthy_replicas, None, session_map + ) + + batch_req = ServiceRequest( + session_id=None, + function=self.function, + args=args, + kwargs={}, + future=asyncio.Future(), + ) + + try: + # Send whole batch to replica + await replica.enqueue_request(batch_req) + results = 
await batch_req.future + # Normalize result shape + if isinstance(results, (list, tuple)): + results = list(results) + if len(results) != len(batch): + results = [results] * len(batch) + else: + # scalar result, broadcast to batch size + results = [results] * len(batch) + + # Fulfill each individual Future from the batch + for (_, _, _, f), r in zip(batch, results): + if not f.done(): + f.set_result(r) + break + + except Exception as e: + if attempt < self._num_attempts - 1 and not replica.healthy: + logger.debug( + f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" + ) + continue + else: + for _, _, _, f in batch: + if not f.done(): + f.set_exception(e) + + async def route(self, function: str, args: tuple, kwargs: dict) -> Any: + """Add (args, kwargs) pair to queue, return a Future resolved when batch completes.""" + # Queue the request for batching + fut = asyncio.Future() + self._queue.put_nowait((function, args, kwargs, fut)) + # Wait for the batch processor to resolve our future + return await fut + + async def stop(self): + """Stop the batch loop gracefully.""" + self._running = False + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 0b655fb6a..d1b05284b 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -40,14 +40,16 @@ from monarch.actor import Actor, endpoint +from forge.controller.service.endpoint import ServiceEndpointProperty + from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica import Replica, ServiceRequest - from forge.controller.service.router import ( - LeastLoadedRouter, + Batcher, RoundRobinRouter, + Router, SessionRouter, ) from forge.types import ServiceConfig @@ -108,7 +110,10 @@ async 
def __initialize__(self): # Initialize the routers self._default_router = RoundRobinRouter() - self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) + self._session_router = SessionRouter(fallback_router=self._default_router) + + # This keeps the map between the registered endpoints and the routers + self.routers: dict[str, Router | Batcher] = {} # Initialize all replicas replicas = [] @@ -138,14 +143,67 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + def _set_router( + self, endpoint_name: str, prop: ServiceEndpointProperty | None = None + ) -> None: + """ + Ensure a router exists for the given endpoint. + + - If a router is already set, raise an error. + - If a ServiceEndpointProperty is provided, construct its router/batcher + using the specified configuration. + - If not provided, fall back to a default round-robin router. + + Args: + endpoint_name: Name of the endpoint. + prop: Optional ServiceEndpointProperty object with router, batch_size, + and batch_timeout attributes. 
+ """ + + # If router already exists, raise an exception + if endpoint_name in self.routers: + raise AssertionError(f"Router already exists for endpoint: {endpoint_name}") + + # If config is missing or incomplete, use default router + if prop is None or not isinstance(prop, ServiceEndpointProperty): + return + + # Resolve base router + if not callable(prop.router): + raise ValueError(f"Router must be callable, got: {prop.router}") + else: + base_router = prop.router() # Call the router constructor + batch_size = prop.batch_size + batch_timeout = prop.batch_timeout + + if not isinstance(base_router, Router): + raise ValueError( + f"Router must be a Router instance, got: {type(base_router)}" + ) + + # Wrap in Batcher if batching requested + if batch_size > 1: + router = Batcher( + endpoint_name, + base_router, + get_healthy_replicas=self._get_healthy_replicas, + get_session_map=self._get_session_map, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + else: + router = base_router + + # Store and return + self.routers[endpoint_name] = router + + async def _route(self, sess_id: str | None, function: str, *args, **kwargs): """ Routes a function call to the appropriate replica with load balancing and fault tolerance. This is the core routing method that handles: - Session-based routing for stateful calls - - Round-robin load balancing for stateless calls - - Custom routing based on context hints + - Per-endpoint router selection (round-robin, least-loaded, batching, etc.) - Automatic retry on replica failures - Request queuing and processing @@ -168,8 +226,15 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): if ctx: sess_id = ctx["session_id"] - replica = await self._get_replica(sess_id) + router = self.routers.get(function, self._default_router) + # Case 1: batching is enabled and no session ID (stateless calls only) + # Forward the request into the Batcher queue. 
The batcher will send + # the batched requests to the selected replica. + if sess_id is None and isinstance(router, Batcher): + return await router.route(function, args, kwargs) + + # Case 2: route a single request to a replica # Create a ServiceRequest object to queue request = ServiceRequest( session_id=sess_id, @@ -179,6 +244,18 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): future=asyncio.Future(), ) + healthy_replicas = [r for r in self._replicas if r.healthy] + + # Select replica: + if sess_id is not None: + # Case 2.1: sticky sessions + replica = self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + else: + # Case 2.2: stateless routing + replica = self._default_router.get_replica(healthy_replicas) + # Queue the request using replica's method await replica.enqueue_request(request) @@ -196,7 +273,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - async def call_all(self, function: str, *args, **kwargs) -> List: + async def _fanout(self, function: str, *args, **kwargs) -> List: """ Broadcasts a function call to all healthy replicas and returns results as a list. 
@@ -211,7 +288,7 @@ async def call_all(self, function: str, *args, **kwargs) -> List: Raises: RuntimeError: If no healthy replicas are available """ - healthy_replicas = [r for r in self._replicas if r.healthy] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: raise RuntimeError("No healthy replicas available for broadcast call") @@ -256,7 +333,7 @@ async def _retry_request_on_healthy_replica( del self._session_replica_map[sess_id] # Retry the call (this will assign to a new healthy replica) - return await self._call(sess_id, function, *args, **kwargs) + return await self._route(sess_id, function, *args, **kwargs) async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" @@ -280,9 +357,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): ) # Find healthy replicas - healthy_replicas = [ - r for r in self._replicas if r.healthy and r != failed_replica - ] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: # No healthy replicas, fail all requests @@ -334,7 +409,7 @@ def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + self._metrics.healthy_replicas = len(self._get_healthy_replicas()) # Store direct references to replica metrics for aggregation self._metrics.replica_metrics = {} for replica in self._replicas: @@ -446,6 +521,13 @@ async def terminate_session(self, sess_id: str): # Update metrics self._update_service_metrics() + def _get_healthy_replicas(self) -> list[Replica]: + """Returns a list of healthy replicas.""" + return [r for r in self._replicas if r.healthy] + + def _get_session_map(self) -> Dict[str, int]: + return self._session_replica_map + async def _health_loop(self, poll_rate_s: float): 
"""Runs the health loop to monitor and recover replicas. @@ -474,17 +556,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _get_replica(self, sess_id: str | None) -> "Replica": - """Get a replica for the given session ID.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if sess_id is None: - # No session, use the default router - return self._default_router.get_replica(healthy_replicas) - - return self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop @@ -503,6 +574,15 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all batchers in routers + # Stop all batchers + batchers = [ + router for router in self.routers.values() if isinstance(router, Batcher) + ] + if batchers: + await asyncio.gather(*(b.stop() for b in batchers), return_exceptions=True) + logger.info("All batcher loop(s) stopped gracefully.") + # Stop all replicas using their stop method await asyncio.gather( *[replica.stop() for replica in self._replicas], @@ -582,7 +662,7 @@ async def _get_internal_state(self) -> dict: # Load balancing state # Service-level state "total_replicas": len(self._replicas), - "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), + "healthy_replica_count": len(self._get_healthy_replicas()), "shutdown_requested": self._shutdown_requested, # Metrics summary "total_sessions": len(self._active_sessions), diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py new file mode 100644 index 000000000..9d0522f93 --- /dev/null +++ b/tests/unit_tests/test_router.py @@ -0,0 +1,401 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+""" +Tests for router.py +""" + +import asyncio +import logging + +import pytest +from forge.controller import ForgeActor +from forge.controller.service import ( + Batcher, + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + service_endpoint, + SessionRouter, +) + +from forge.types import ProcessConfig +from monarch.actor import endpoint + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Counter(ForgeActor): + """Test actor that maintains a counter with various endpoints.""" + + def __init__(self, v: int): + self.v = v + self._num_calls = 0 # number of calls to endpoint functions + + @endpoint + async def value(self) -> int: + """Get the current counter value.""" + return self.v + + @endpoint + async def fail_me(self): + """Endpoint that always fails to test error handling.""" + raise RuntimeError("I was asked to fail") + + @endpoint + async def get_num_calls(self): + """Get the number of calls to endpoint functions.""" + return self._num_calls + + @endpoint + async def incr(self): + """Increment the counter.""" + self._num_calls += 1 + self.v += 1 + + @service_endpoint(router=RoundRobinRouter, batch_size=3, batch_timeout=1) + async def rr_batch_incr_bsize3(self): + """Increment the round-robin counter with batching (batch size = 3).""" + self._num_calls += 1 + self.v += 1 + + @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) + async def rr_batch_incr_bsize5(self, inputs: list[int]) -> list[int]: + """Increment the round-robin counter with batching (batch size = 5).""" + self._num_calls += 1 + self.v += sum(inputs) + return inputs + + @service_endpoint(router=RoundRobinRouter) + async def rr_batch_incr_bsize1(self, inputs: list[int]) -> list[int]: + """Increment the round-robin counter with batching (batch size = 1).""" + self._num_calls += 1 + self._sum += sum(inputs) + return inputs + + +def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: + """Helper to build a 
replica with specified state and load.""" + replica = Replica( + idx=idx, + proc_config=ProcessConfig(), + actor_def=Counter, + actor_args=(), + actor_kwargs={}, + ) + replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY + replica.active_requests = load + return replica + + +# Config tests + + +@pytest.mark.asyncio +async def test_service_endpoint_router_and_configurations(): + """ + Verify service endpoints are registered with correct router/batching configuration: + - rr_batch_incr_bsize1: plain RoundRobinRouter, no batching (batch_size=1, timeout=0.01) + - rr_batch_incr_bsize3: Batcher wrapping RoundRobinRouter (batch_size=3, timeout=1) + - incr: plain @endpoint, should not appear in service.routers + """ + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # --- rr_batch_incr_bsize1 --- + router1 = service.routers.get("rr_batch_incr_bsize1") + assert isinstance( + router1, RoundRobinRouter + ), f"Expected RoundRobinRouter, got {type(router1)}" + + prop1 = Counter.rr_batch_incr_bsize1 + assert prop1.batch_size == 1 + assert prop1.batch_timeout == 0.01 + + # --- rr_batch_incr_bsize3 --- + router3 = service.routers.get("rr_batch_incr_bsize3") + + assert isinstance(router3, Batcher), f"Expected Batcher, got {type(router3)}" + assert router3.batch_size == 3 + assert router3.batch_timeout == 1 + + # --- incr --- + assert ( + "incr" not in service.routers + ), "Plain @endpoint should not be in service.routers" + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_noncallable(): + """@service_endpoint with non-callable router should raise ValueError.""" + + class BadActor(ForgeActor): + @service_endpoint(router="roundrobin") # string, not callable + async def bad_endpoint(self): + return 42 + + with pytest.raises(ValueError, match="Router must be callable"): + # Triggers ServiceInterface._set_router during construction + await
BadActor.options(num_replicas=1).as_service() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_wrong_return_type(): + """@service_endpoint with callable that doesn't return Router should raise ValueError.""" + + class NotARouter: + """Dummy class that is not a Router.""" + + class BadActor(ForgeActor): + @service_endpoint(router=NotARouter) # returns NotARouter + async def bad_endpoint(self): + return 123 + + with pytest.raises(ValueError, match="Router must be a Router instance"): + await BadActor.options(num_replicas=1).as_service() + + +# Router Tests + + +@pytest.mark.asyncio +async def test_session_router_fallback_rr_vs_ll(): + """Switch fallback router to round-robin and verify assignment order.""" + # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = RoundRobinRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx != r2.idx + assert set(session_map.values()) == {0, 1} + + # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx == r2.idx == 0 + + +# Router integration tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using 
route() + results = [] + for _ in range(6): + await service.incr.route() + values = await service.value.fanout() + results.append(values) + # Verify that requests were distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # 2 increments per replica (since 3 replicas, 6 calls) + final_values = results[-1] # last snapshot + assert sorted(final_values) == [2, 2, 2] + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_session_router_assigns_and_updates_session_map_in_service(): + """Integration: Service with SessionRouter preserves sticky sessions.""" + service = await Counter.options( + procs=1, + num_replicas=2, + ).as_service(v=0) + + try: + # First call with sess_id -> assign a replica + await service.incr.route(sess_id="sess1") + values1 = await service.value.fanout() + + # Second call with same sess_id -> must hit same replica + await service.incr.route(sess_id="sess1") + values2 = await service.value.fanout() + + # Difference should only be on one replica (sticky session) + diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] + assert ( + sum(diffs) == 1 + ), f"Expected exactly one replica to increment, got {diffs}" + assert max(diffs) == 1 and min(diffs) == 0 + + # Session map in service should reflect assigned replica + assigned_idx = service._session_replica_map["sess1"] + assert values2[assigned_idx] == values1[assigned_idx] + 1 + + finally: + await service.shutdown() + + +# Batcher tests + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_as_actor_preserves_normal_usage(): + """Ensure that using `as_actor` does not break normal semantics.""" + service = await Counter.as_actor(5) + + try: + assert await service.value.choose() == 5 + + # Test increment + await service.rr_batch_incr_bsize3.choose() + assert await service.value.choose() == 6 + + finally: + await Counter.shutdown(service) + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio 
+async def test_rr_batch_incr_bsize5_behaves_like_normal_incr(): + """Ensure rr_batch_incr_bsize5 (batch_size=5) behaves like a normal incr endpoint for single calls.""" + service = await Counter.options(procs=1, num_replicas=1).as_service(v=5) + + try: + # Initial value + assert await service.value.route() == 5 + + # Call batched increment once + await service.rr_batch_incr_bsize5.route(1) + + # Should increment exactly once + assert await service.value.route() == 6 + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_endpoint_batching_preserves_order(): + """Ensure that batching preserves the order of calls.""" + service = await Counter.options(num_replicas=2, procs=1).as_service(0) + try: + results = await asyncio.gather( + *[service.rr_batch_incr_bsize5.route(i) for i in range(5)] + ) + assert results == [0, 1, 2, 3, 4] + assert await service.get_num_calls.route() == 1 + assert sorted(await service.value.fanout()) == [0, 10] + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_endpoint_multiple_batches(): + """ + Verify that batching correctly splits requests into two batches — + one triggered by reaching the batch size limit and another by the batch timeout. 
+ """ + service = await Counter.options(num_replicas=2, procs=1).as_service(0) + try: + # Enqueue 7 calls → expect two batches (5 + 2) + results = await asyncio.gather( + *[service.rr_batch_incr_bsize5.route(i) for i in range(7)] + ) + # Verify all individual results were returned in order + assert results == [0, 1, 2, 3, 4, 5, 6] + # Each replica should have executed one batch (round-robin) + assert await service.get_num_calls.fanout() == [1, 1] + + # Replica values reflect the sum of their respective batch inputs + # first batch: [0, 1, 2, 3, 4] → 10 + # second batch: [5, 6] → 11 + assert sorted(await service.value.fanout()) == [10, 11] + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_round_robin_batcher_distribution_no_args(): + """ + Verify that the batching system correctly handles endpoints with **zero arguments** + and that the RoundRobinRouter distributes such batched calls evenly across replicas. + """ + + # --- Launch service with 3 replicas --- + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Enqueue 5 no-arg batched calls + await asyncio.gather(*[service.rr_batch_incr_bsize3.route() for _ in range(5)]) + + # Check that two replicas incremented their counters once + values = await service.value.fanout() + assert sorted(values) == [0, 1, 1], f"Unexpected replica values: {values}" + + # Ensure exactly 2 actor invocations occurred (2 batches total) + num_calls = await service.get_num_calls.fanout() + assert sum(num_calls) == 2, f"Expected 2 batches, got {sum(num_calls)}" + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_endpoint_batching_multi_arg_merge(): + """Ensure that batching merges multiple argument lists correctly.""" + + class MultiArgActor(ForgeActor): + def __init__(self): + self._num_calls = 0 + + @endpoint + async def get_num_calls(self): + return self._num_calls + + 
@service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.1) + async def multi_args_sum(self, v1: list[int], v2: list[str]) -> list[str]: + """ + Endpoint that accepts multiple argument lists. + Should be invoked once per batch. + """ + self._num_calls += 1 + # Combine corresponding elements + return [f"{x}:{y}" for x, y in zip(v1, v2)] + + service = await MultiArgActor.options(num_replicas=2, procs=1).as_service() + + try: + # 5 requests will fill one batch of size 5 + results = await asyncio.gather( + *[service.multi_args_sum.route(i, str(i)) for i in range(5)] + ) + + # Expect exactly one actor invocation + assert await service.get_num_calls.route() == 1 + + # Expect results correspond to all merged pairs + assert results == ["0:0", "1:1", "2:2", "3:3", "4:4"] + + finally: + await service.shutdown() diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 31a912542..64882d243 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -13,15 +13,8 @@ import pytest from forge.controller import ForgeActor -from forge.controller.service import ( - LeastLoadedRouter, - Replica, - ReplicaState, - RoundRobinRouter, - ServiceConfig, - SessionRouter, -) -from forge.types import ProcessConfig +from forge.controller.service import ServiceConfig + from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) @@ -63,20 +56,6 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: return self.v -def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: - """Helper to build a replica with specified state and load.""" - replica = Replica( - idx=idx, - proc_config=ProcessConfig(), - actor_def=Counter, - actor_args=(), - actor_kwargs={}, - ) - replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY - replica.active_requests = load - return replica - - # Actor Tests @@ -754,95 +733,3 @@ async def test_broadcast_fanout_vs_route(): 
finally: await service.shutdown() - - -# Router Tests - - -@pytest.mark.asyncio -async def test_session_router_with_round_robin_fallback(): - """Switch fallback router to round-robin and verify assignment order.""" - # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = RoundRobinRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx != r2.idx - assert set(session_map.values()) == {0, 1} - - # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = LeastLoadedRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx == r2.idx == 0 - - -# Router integeration tests - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_round_robin_router_distribution(): - """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" - service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) - - try: - # Make multiple sessionless calls using route() - results = [] - for _ in range(6): - await service.incr.route() - values = await service.value.fanout() - print(values) - results.append(values) - print("results: ", results) - # Verify that requests were distributed round-robin - # Each call increments a single replica, so after 6 calls we expect: - # 2 increments per replica (since 3 replicas, 6 calls) - final_values = results[-1] # last snapshot - assert sorted(final_values) == [2, 2, 2] - - finally: - await service.shutdown() - - -@pytest.mark.timeout(10) 
-@pytest.mark.asyncio -async def test_session_router_assigns_and_updates_session_map_in_service(): - """Integration: Service with SessionRouter preserves sticky sessions.""" - # Use LeastLoaded as default, SessionRouter (with fallback) is always active - service = await Counter.options( - procs=1, - num_replicas=2, - ).as_service(v=0) - - try: - # First call with sess_id -> assign a replica - await service.incr.route(sess_id="sess1") - values1 = await service.value.fanout() - - # Second call with same sess_id -> must hit same replica - await service.incr.route(sess_id="sess1") - values2 = await service.value.fanout() - - # Difference should only be on one replica (sticky session) - diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] - assert ( - sum(diffs) == 1 - ), f"Expected exactly one replica to increment, got {diffs}" - assert max(diffs) == 1 and min(diffs) == 0 - - # Session map in service should reflect assigned replica - assigned_idx = service._session_replica_map["sess1"] - assert values2[assigned_idx] == values1[assigned_idx] + 1 - - finally: - await service.shutdown()