From b3252be6cf7b9921c0da4c62b447a9c948dd36ee Mon Sep 17 00:00:00 2001 From: DNXie Date: Fri, 3 Oct 2025 11:26:02 -0700 Subject: [PATCH 1/6] catch up where we left --- src/forge/controller/service/__init__.py | 7 +- src/forge/controller/service/endpoint.py | 199 +++++++++++++ src/forge/controller/service/interface.py | 139 ++------- src/forge/controller/service/replica.py | 14 +- src/forge/controller/service/router.py | 148 +++++++++- src/forge/controller/service/service.py | 135 +++++++-- tests/unit_tests/test_router.py | 327 ++++++++++++++++++++++ tests/unit_tests/test_service.py | 117 +------- 8 files changed, 814 insertions(+), 272 deletions(-) create mode 100644 src/forge/controller/service/endpoint.py create mode 100644 tests/unit_tests/test_router.py diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index f0d8fca7b..d745f1556 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from .endpoint import service_endpoint, ServiceEndpointProperty from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState -from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter +from .router import Batcher, LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ @@ -24,4 +25,8 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", + "service_endpoint", + "ServiceEndpointProperty", + "Router", + "Batcher", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py new file mode 100644 index 000000000..c731017c6 --- /dev/null +++ b/src/forge/controller/service/endpoint.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Service endpoint management for the Forge framework. +""" + +from typing import Any, Callable, Generic, List, TypeVar + +from monarch._src.actor.endpoint import EndpointProperty + +from typing_extensions import ParamSpec + +from .router import RoundRobinRouter, Router + +P = ParamSpec("P") +R = TypeVar("R") +Propagator = Any + + +class ServiceEndpoint(Generic[P, R]): + """ + This extends Monarch's actor APIs for service endpoints. + - `route(*args, **kwargs)`: Routes the request to a single replica. + - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. + + Monarch's native actor APIs do not apply for services. + """ + + def __init__( + self, + service, + endpoint_name: str, + ): + self.service = service + self.endpoint_name = endpoint_name + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.service._route(sess_id, self.endpoint_name, *args, **kwargs) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.service._fanout(self.endpoint_name, *args, **kwargs) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class ServiceEndpointV2(Generic[P, R]): + """An endpoint object specific to services. + + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only choose and call are retained (dropping stream and call_one) + - Call returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + + def __init__(self, actor_mesh, endpoint_name: str): + self.actor_mesh = actor_mesh + self.endpoint_name = endpoint_name + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.actor_mesh.call.call_one( + sess_id, self.endpoint_name, *args, **kwargs + ) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.actor_mesh.call_all.call_one( + self.endpoint_name, *args, **kwargs + ) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class ServiceEndpointProperty(EndpointProperty, Generic[P, R]): + """ + Extension of EndpointProperty that carries service-specific + routing and batching configuration. + + Inherits from EndpointProperty so the method is still registered as + a valid actor endpoint, while also attaching service-specific options + (router, batch_size, batch_timeout). + """ + + def __init__( + self, + method: Any, + propagator: Propagator, + explicit_response_port: bool, + *, + router: Callable[[], Router] = RoundRobinRouter, + batch_size: int = 1, + batch_timeout: float = 0.01, + ) -> None: + super().__init__(method, propagator, explicit_response_port) + self.router = router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + + +def service_endpoint( + *, + router: Callable[[], Router] = RoundRobinRouter, + batch_size: int = 1, + batch_timeout: float = 0.01, + propagate=None, + explicit_response_port=False, +): + """ + Marks an actor method as a service endpoint with batching routing support. + + Example: + class MyForgeActor(ForgeActor): + @service_endpoint(router=RoundRobinRouter(), batch_size=16, batch_timeout=0.05) + async def predict(self, x): ... + """ + + def decorator(method) -> ServiceEndpointProperty: + return ServiceEndpointProperty( + method, + propagator=propagate, + explicit_response_port=explicit_response_port, + router=router, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + + return decorator diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 5b7e2f884..8e4a0c461 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -11,20 +11,11 @@ """ import contextvars -import logging -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty -from .replica import Replica - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -P = ParamSpec("P") -R = TypeVar("R") +from .endpoint import ServiceEndpoint, ServiceEndpointProperty, ServiceEndpointV2 @dataclass @@ -77,94 +68,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.session_id = None -class ServiceEndpoint(Generic[P, R]): - """ - This extends Monarch's actor APIs for service endpoints. - - `route(*args, **kwargs)`: Routes the request to a single replica. - - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. - - Monarch's native actor APIs do not apply for services. - """ - - def __init__(self, service, endpoint_name: str): - self.service = service - self.endpoint_name = endpoint_name - - async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) - - async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.service.call_all(self.endpoint_name, *args, **kwargs) - return result - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use choose() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use call() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use a call_one() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use broadcast() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def generate(self, *args: P.args, **kwargs: P.kwargs): - raise NotImplementedError( - "You tried to use generate() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - -class ServiceEndpointV2(Generic[P, R]): - """An endpoint object specific to services. - - This loosely mimics the Endpoint APIs exposed in Monarch, with - a few key differences: - - Only choose and call are retained (dropping stream and call_one) - - Call returns a list directly rather than a ValueMesh. - - These changes are made with Forge use cases in mind, but can - certainly be expanded/adapted in the future. - - """ - - def __init__(self, actor_mesh, endpoint_name: str): - self.actor_mesh = actor_mesh - self.endpoint_name = endpoint_name - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.actor_mesh.call.call_one( - sess_id, self.endpoint_name, *args, **kwargs - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.actor_mesh.call_all.call_one( - self.endpoint_name, *args, **kwargs - ) - return result - - class ServiceInterface: """ A lightweight interface to the base Service class. @@ -182,10 +85,15 @@ def __init__(self, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpoint(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpoint(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -260,10 +168,15 @@ def __init__(self, _proc_mesh, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpointV2(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpointV2(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -306,17 +219,3 @@ def __getattr__(self, name: str): raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) - - -class Router(ABC): - """Abstract base class for routing logic.""" - - @abstractmethod - def get_replica( - self, - healthy_replicas: List[Replica], - sess_id: str | None = None, - session_map: Dict[str, int] | None = None, - ) -> Replica: - """Select a replica from the list based on routing logic.""" - pass diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 09b0a2ce6..81c6de0e1 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -9,7 +9,7 @@ import logging import time from collections import deque -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Optional @@ -159,10 +159,9 @@ async def initialize(self): # Deploy the actor and its underlying resources logger.debug(f"Launching actor for replica {self.idx}") - self.actor = await self.actor_def.launch( - *self.actor_args, - **self.actor_kwargs, - ) + self.actor = await self.actor_def.options( + **asdict(self.proc_config) + ).as_actor(*self.actor_args, **self.actor_kwargs) # Transition to healthy state and start processing self.state = ReplicaState.HEALTHY self.start_processing() @@ -221,6 +220,11 @@ async def enqueue_request(self, request: ServiceRequest): # Accept requests in all other states - let the processing loop handle the rest await self.request_queue.put(request) + async def enqueue_batch(self, requests: list[ServiceRequest]): + """Enqueues a batch of requests for processing by this replica.""" + for req in requests: + await self.enqueue_request(req) + async def _process_single_request(self, request: ServiceRequest) -> bool: """Processes a single request and returns success status. diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 502402e36..11311f93f 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -4,16 +4,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +import asyncio import logging -from typing import Dict, List +import time +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List -from .interface import Router -from .replica import Replica +from forge.controller.service.replica import Replica, ServiceRequest logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +class Router(ABC): + """Abstract base class for routing logic.""" + + @abstractmethod + def get_replica( + self, + healthy_replicas: List[Replica], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, + ) -> Replica: + """Select a replica from the list based on routing logic.""" + pass + + class RoundRobinRouter(Router): """Round-robin router for stateless requests.""" @@ -88,3 +105,128 @@ def get_replica( replica.idx, ) return replica + + +class Batcher: + """ + Asynchronous batching wrapper around a Router. + + Instead of selecting a replica immediately, incoming requests are enqueued + and grouped into batches. Once a batch is ready (either reaching the maximum + size or exceeding the maximum wait time), the batcher makes a single routing + decision using the inner router. All requests in that batch are then resolved + with the same replica. + + This reduces router overhead by amortizing multiple requests into one decision. + + Args: + inner_router: The underlying Router used to pick a replica. + get_healthy_replicas: Callable that returns the current list of healthy replicas. + get_session_map: Callable that returns the session-to-replica mapping. + batch_size: Maximum number of requests to collect in a single batch + before routing (default: 8). + batch_timeout: Maximum time to wait (in seconds) before routing a batch, + even if batch_size is not reached (default: 0.01). + + Example: + rr_router = RoundRobinRouter() + batcher = Batcher( + rr_router, + get_healthy_replicas=service._get_healthy_replicas, + get_session_map=service._get_session_map, + batch_size=16, + batch_timeout=0.01, + ) + + request = ServiceRequest(...) + + # Enqueue a request to be sent to a replica + await batcher.enqueue(request) + """ + + def __init__( + self, + inner_router: Router, + get_healthy_replicas: Callable[[], List["Replica"]], + get_session_map: Callable[[], Dict[str, int]], + batch_size: int = 16, + batch_timeout: float = 0.01, + ): + + self.inner_router = inner_router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self.get_healthy_replicas = get_healthy_replicas + self.get_session_map = get_session_map + + # Internal queue for batching routing requests + self._queue: asyncio.Queue = asyncio.Queue() + self._running = True # flag to control loop + # Background task that processes batches continuously + self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_size or batch_timeout is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + """ + while self._running: + + # Wait for first request + batch = [await self._queue.get()] + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self.batch_timeout - (time.monotonic() - start_time) + ) + req = await asyncio.wait_for( + self._queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch.append(req) + + if len(batch) >= self.batch_size: + break + except asyncio.TimeoutError: + break + + session_map = self.get_session_map() + healthy_replicas = self.get_healthy_replicas() + + # One routing decision for the whole batch + replica = self.inner_router.get_replica(healthy_replicas, None, session_map) + + # Send whole batch to replica + try: + await replica.enqueue_batch(batch) + except Exception as e: + for req in batch: + req.future.set_exception(e) + + async def enqueue(self, request: ServiceRequest) -> Any: + """Enqueue request and wait until batch assigns a replica.""" + # Queue the request for batching - this is non-blocking + self._queue.put_nowait(request) + + # Wait for the batch processor to resolve our future + return await request.future + + async def stop(self): + """Stop the batch loop gracefully.""" + self._running = False + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 0b655fb6a..6c182b5ae 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -40,14 +40,16 @@ from monarch.actor import Actor, endpoint +from forge.controller.service.endpoint import ServiceEndpointProperty + from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica import Replica, ServiceRequest - from forge.controller.service.router import ( - LeastLoadedRouter, + Batcher, RoundRobinRouter, + Router, SessionRouter, ) from forge.types import ServiceConfig @@ -108,7 +110,10 @@ async def __initialize__(self): # Initialize the routers self._default_router = RoundRobinRouter() - self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) + self._session_router = SessionRouter(fallback_router=self._default_router) + + # This keeps the map between the registered endpoints and the routers + self.routers: dict[str, Router | Batcher] = {} # Initialize all replicas replicas = [] @@ -138,14 +143,66 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + def _set_router( + self, endpoint_name: str, prop: ServiceEndpointProperty | None = None + ) -> None: + """ + Ensure a router exists for the given endpoint. + + - If a router is already set, raise an error. + - If a ServiceEndpointProperty is provided, construct its router/batcher + using the specified configuration. + - If not provided, fall back to a default round-robin router. + + Args: + endpoint_name: Name of the endpoint. + prop: Optional ServiceEndpointProperty object with router, batch_size, + and batch_timeout attributes. + """ + + # If router already exists, raise an exception + if endpoint_name in self.routers: + raise AssertionError(f"Router already exists for endpoint: {endpoint_name}") + + # If config is missing or incomplete, use default router + if prop is None or not isinstance(prop, ServiceEndpointProperty): + return + + # Resolve base router + if not callable(prop.router): + raise ValueError(f"Router must be callable, got: {prop.router}") + else: + base_router = prop.router() # Call the router constructor + batch_size = prop.batch_size + batch_timeout = prop.batch_timeout + + if not isinstance(base_router, Router): + raise ValueError( + f"Router must be a Router instance, got: {type(base_router)}" + ) + + # Wrap in Batcher if batching requested + if batch_size > 1: + router = Batcher( + base_router, + get_healthy_replicas=self._get_healthy_replicas, + get_session_map=self._get_session_map, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + else: + router = base_router + + # Store and return + self.routers[endpoint_name] = router + + async def _route(self, sess_id: str | None, function: str, *args, **kwargs): """ Routes a function call to the appropriate replica with load balancing and fault tolerance. This is the core routing method that handles: - Session-based routing for stateful calls - - Round-robin load balancing for stateless calls - - Custom routing based on context hints + - Per-endpoint router selection (round-robin, least-loaded, batching, etc.) - Automatic retry on replica failures - Request queuing and processing @@ -168,7 +225,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): if ctx: sess_id = ctx["session_id"] - replica = await self._get_replica(sess_id) + router = self.routers.get(function, self._default_router) # Create a ServiceRequest object to queue request = ServiceRequest( @@ -179,8 +236,27 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): future=asyncio.Future(), ) - # Queue the request using replica's method - await replica.enqueue_request(request) + # Case: batching is enabled and no session ID (stateless calls only) + # Forward the request into the Batcher queue. The batcher will send + # the batched requests to the selected replica. + if sess_id is None and isinstance(router, Batcher): + await router.enqueue(request) + + else: + healthy_replicas = [r for r in self._replicas if r.healthy] + + # Select replica: + if sess_id is not None: + # Case 1: sticky sessions + replica = self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + else: + # Case 2: stateless routing + replica = self._default_router.get_replica(healthy_replicas) + + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: @@ -196,7 +272,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - async def call_all(self, function: str, *args, **kwargs) -> List: + async def _fanout(self, function: str, *args, **kwargs) -> List: """ Broadcasts a function call to all healthy replicas and returns results as a list. @@ -211,7 +287,7 @@ async def call_all(self, function: str, *args, **kwargs) -> List: Raises: RuntimeError: If no healthy replicas are available """ - healthy_replicas = [r for r in self._replicas if r.healthy] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: raise RuntimeError("No healthy replicas available for broadcast call") @@ -256,7 +332,7 @@ async def _retry_request_on_healthy_replica( del self._session_replica_map[sess_id] # Retry the call (this will assign to a new healthy replica) - return await self._call(sess_id, function, *args, **kwargs) + return await self._route(sess_id, function, *args, **kwargs) async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" @@ -280,9 +356,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): ) # Find healthy replicas - healthy_replicas = [ - r for r in self._replicas if r.healthy and r != failed_replica - ] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: # No healthy replicas, fail all requests @@ -334,7 +408,7 @@ def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + self._metrics.healthy_replicas = len(self._get_healthy_replicas()) # Store direct references to replica metrics for aggregation self._metrics.replica_metrics = {} for replica in self._replicas: @@ -446,6 +520,13 @@ async def terminate_session(self, sess_id: str): # Update metrics self._update_service_metrics() + def _get_healthy_replicas(self) -> list[Replica]: + """Returns a list of healthy replicas.""" + return [r for r in self._replicas if r.healthy] + + def _get_session_map(self) -> Dict[str, int]: + return self._session_replica_map + async def _health_loop(self, poll_rate_s: float): """Runs the health loop to monitor and recover replicas. @@ -474,17 +555,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _get_replica(self, sess_id: str | None) -> "Replica": - """Get a replica for the given session ID.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if sess_id is None: - # No session, use the default router - return self._default_router.get_replica(healthy_replicas) - - return self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop @@ -503,6 +573,15 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all batchers in routers + # Stop all batchers + batchers = [ + router for router in self.routers.values() if isinstance(router, Batcher) + ] + if batchers: + await asyncio.gather(*(b.stop() for b in batchers), return_exceptions=True) + logger.info("All batcher loop(s) stopped gracefully.") + # Stop all replicas using their stop method await asyncio.gather( *[replica.stop() for replica in self._replicas], @@ -582,7 +661,7 @@ async def _get_internal_state(self) -> dict: # Load balancing state # Service-level state "total_replicas": len(self._replicas), - "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), + "healthy_replica_count": len(self._get_healthy_replicas()), "shutdown_requested": self._shutdown_requested, # Metrics summary "total_sessions": len(self._active_sessions), diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py new file mode 100644 index 000000000..4ce497e77 --- /dev/null +++ b/tests/unit_tests/test_router.py @@ -0,0 +1,327 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Tests for router.py +""" + +import asyncio +import logging + +import pytest +from forge.controller import ForgeActor +from forge.controller.service import ( + Batcher, + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + service_endpoint, + SessionRouter, +) + +from forge.types import ProcessConfig +from monarch.actor import endpoint + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Counter(ForgeActor): + """Test actor that maintains a counter with various endpoints.""" + + def __init__(self, v: int): + self.v = v + + @endpoint + async def value(self) -> int: + """Get the current counter value.""" + return self.v + + @endpoint + async def fail_me(self): + """Endpoint that always fails to test error handling.""" + raise RuntimeError("I was asked to fail") + + @endpoint + async def add_to_value(self, amount: int, multiplier: int = 1) -> int: + """Add an amount (optionally multiplied) to the current value.""" + logger.info(f"adding {amount} with {multiplier}") + self.v += amount * multiplier + return self.v + + @endpoint + async def incr(self): + """Increment the counter.""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter, batch_size=3, batch_timeout=1) + async def rr_batch_incr_bsize3(self): + """Increment the round-robin counter with batching (batch size = 3).""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) + async def rr_batch_incr_bsize5(self): + """Increment the round-robin counter with batching (batch size = 5).""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter) + async def rr_batch_incr_bsize1(self): + """Increment the round-robin counter with default batch_size=1.""" + self.v += 1 + + +def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: + """Helper to build a replica with specified state and load.""" + replica = Replica( + idx=idx, + proc_config=ProcessConfig(), + actor_def=Counter, + actor_args=(), + actor_kwargs={}, + ) + replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY + replica.active_requests = load + return replica + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_as_actor_preserves_normal_usage(): + """Ensure that using `as_actor` does not break normal semantics.""" + service = await Counter.as_actor(5) + + try: + assert await service.value.choose() == 5 + + # Test increment + await service.rr_batch_incr_bsize3.choose() + assert await service.value.choose() == 6 + + finally: + await Counter.shutdown(service) + + +@pytest.mark.asyncio +async def test_service_endpoint_router_and_configurations(): + """ + Verify service endpoints are registered with correct router/batching configuration: + - rr_batch_incr_bsize1: plain RoundRobinRouter, no batching (batch_size=1, timeout=0.01) + - rr_batch_incr_bsize3: Batcher wrapping RoundRobinRouter (batch_size=3, timeout=1) + - incr: plain @endpoint, should not appear in service.routers + """ + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # --- rr_batch_incr_bsize1 --- + router1 = service.routers.get("rr_batch_incr_bsize1") + assert isinstance( + router1, RoundRobinRouter + ), f"Expected RoundRobinRouter, got {type(router1)}" + + prop1 = Counter.rr_batch_incr_bsize1 + assert prop1.batch_size == 1 + assert prop1.batch_timeout == 0.01 + + # --- rr_batch_incr_bsize3 --- + router3 = service.routers.get("rr_batch_incr_bsize3") + + assert isinstance(router3, Batcher), f"Expected Batcher, got {type(router3)}" + assert router3.batch_size == 3 + assert router3.batch_timeout == 1 + + # --- incr --- + assert ( + "incr" not in service.routers + ), "Plain @endpoint should not be in service.routers" + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_noncallable(): + """@service_endpoint with non-callable router should raise ValueError.""" + + class BadActor(ForgeActor): + @service_endpoint(router="roundrobin") # string, not callable + async def bad_endpoint(self): + return 42 + + with pytest.raises(ValueError, match="Router must be callable"): + # Triggers ServiceInterface._set_router during construction + await BadActor.options(num_replicas=1).as_service() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_wrong_return_type(): + """@service_endpoint with callable that doesn't return Router should raise ValueError.""" + + class NotARouter: + """Dummy class that is not a Router.""" + + class BadActor(ForgeActor): + @service_endpoint(router=NotARouter) # returns NotARouter + async def bad_endpoint(self): + return 123 + + with pytest.raises(ValueError, match="Router must be a Router instance"): + await BadActor.options(num_replicas=1).as_service() + + +# Router Tests + + +@pytest.mark.asyncio +async def test_session_router_fallback_rr_vs_ll(): + """Switch fallback router to round-robin and verify assignment order.""" + # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = RoundRobinRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx != r2.idx + assert set(session_map.values()) == {0, 1} + + # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx == r2.idx == 0 + + +# Router integeration tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using route() + results = [] + for _ in range(6): + await service.incr.route() + values = await service.value.fanout() + results.append(values) + # Verify that requests were distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # 2 increments per replica (since 3 replicas, 6 calls) + final_values = results[-1] # last snapshot + assert sorted(final_values) == [2, 2, 2] + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution_with_batching(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas with batch routing.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using route() + results = [] + tasks = [service.rr_batch_incr_bsize3.route() for _ in range(6)] + await asyncio.gather(*tasks) + # Verify that requests were distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # 2 increments per replica (since 3 replicas, 6 calls) + values = await service.value.fanout() + assert sorted(values) == [0, 3, 3] + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_session_router_assigns_and_updates_session_map_in_service(): + """Integration: Service with SessionRouter preserves sticky sessions.""" + service = await Counter.options( + procs=1, + num_replicas=2, + ).as_service(v=0) + + try: + # First call with sess_id -> assign a replica + await service.incr.route(sess_id="sess1") + values1 = await service.value.fanout() + + # Second call with same sess_id -> must hit same replica + await service.incr.route(sess_id="sess1") + values2 = await service.value.fanout() + + # Difference should only be on one replica (sticky session) + diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] + assert ( + sum(diffs) == 1 + ), f"Expected exactly one replica to increment, got {diffs}" + assert max(diffs) == 1 and min(diffs) == 0 + + # Session map in service should reflect assigned replica + assigned_idx = service._session_replica_map["sess1"] + assert values2[assigned_idx] == values1[assigned_idx] + 1 + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_independent_batchers_and_routers_per_endpoint(): + """Ensure multiple @service_endpoint endpoints coexist with independent routers/batchers.""" + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # --- First batch: rr_batch_incr_bsize3 (batch_size = 3) --- + tasks = [ + asyncio.create_task(service.rr_batch_incr_bsize3.route()) for _ in range(4) + ] + await asyncio.gather(*tasks) + + values = await service.value.fanout() + + # Expectation: + # - First 3 requests form one batch → sent to replica R1 (+3). + # - Remaining 1 request forms its own batch → goes to replica R2 (+1). + # So totals should be [3, 1] (order depends on round robin). + assert sum(values) == 4, f"Expected total=4, got {values}" + assert sorted(values) == [1, 3], f"Expected [1, 3], got {values}" + + # --- Second batch: rr_batch_incr_bsize5 (batch_size = 5) --- + tasks = [ + asyncio.create_task(service.rr_batch_incr_bsize5.route()) for _ in range(7) + ] + await asyncio.gather(*tasks) + + values = await service.value.fanout() + + # Expectation (RoundRobin between replicas): + # Starting from previous state (R1=3, R2=1): + # - Next 5 requests form one batch → go to R1 (+5). + # - Remaining 2 requests form their own batch → go to R2 (+2). + # + # Final totals: + # R1 = 3 (previous) + 5 = 8 + # R2 = 1 (previous) + 2 = 3 + # So distribution should be [3, 8]. + assert sum(values) == 11, f"Expected total=11, got {values}" + assert sorted(values) == [3, 8], f"Expected [4, 8], got {values}" + + finally: + await service.shutdown() diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 31a912542..64882d243 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -13,15 +13,8 @@ import pytest from forge.controller import ForgeActor -from forge.controller.service import ( - LeastLoadedRouter, - Replica, - ReplicaState, - RoundRobinRouter, - ServiceConfig, - SessionRouter, -) -from forge.types import ProcessConfig +from forge.controller.service import ServiceConfig + from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) @@ -63,20 +56,6 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: return self.v -def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: - """Helper to build a replica with specified state and load.""" - replica = Replica( - idx=idx, - proc_config=ProcessConfig(), - actor_def=Counter, - actor_args=(), - actor_kwargs={}, - ) - replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY - replica.active_requests = load - return replica - - # Actor Tests @@ -754,95 +733,3 @@ async def test_broadcast_fanout_vs_route(): finally: await service.shutdown() - - -# Router Tests - - -@pytest.mark.asyncio -async def test_session_router_with_round_robin_fallback(): - """Switch fallback router to round-robin and verify assignment order.""" - # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = RoundRobinRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx != r2.idx - assert set(session_map.values()) == {0, 1} - - # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = LeastLoadedRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx == r2.idx == 0 - - -# Router integeration tests - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_round_robin_router_distribution(): - """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" - service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) - - try: - # Make multiple sessionless calls using route() - results = [] - for _ in range(6): - await service.incr.route() - values = await service.value.fanout() - print(values) - results.append(values) - print("results: ", results) - # Verify that requests were distributed round-robin - # Each call increments a single replica, so after 6 calls we expect: - # 2 increments per replica (since 3 replicas, 6 calls) - final_values = results[-1] # last snapshot - assert sorted(final_values) == [2, 2, 2] - - finally: - await service.shutdown() - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_session_router_assigns_and_updates_session_map_in_service(): - """Integration: Service with SessionRouter preserves sticky sessions.""" - # Use LeastLoaded as default, SessionRouter (with fallback) is always active - service = await Counter.options( - procs=1, - num_replicas=2, - ).as_service(v=0) - - try: - # First call with sess_id -> assign a replica - await service.incr.route(sess_id="sess1") - values1 = await service.value.fanout() - - # Second call with same sess_id -> must hit same replica - await service.incr.route(sess_id="sess1") - values2 = await service.value.fanout() - - # Difference should only be on one replica (sticky session) - diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] - assert ( - sum(diffs) == 1 - ), f"Expected exactly one replica to increment, got {diffs}" - assert max(diffs) == 1 and min(diffs) == 0 - - # Session map in service should reflect assigned replica - assigned_idx = service._session_replica_map["sess1"] - assert values2[assigned_idx] == values1[assigned_idx] + 1 - - finally: - await service.shutdown() From 6e351aafea5835e56f0fce00e9ab7a8faab3d183 Mon Sep 17 00:00:00 2001 From: DNXie Date: Sun, 5 Oct 2025 12:08:17 -0700 Subject: [PATCH 2/6] make batcher process one request per batch; TODO: kwargs and update tests --- src/forge/controller/service/replica.py | 12 ++-- src/forge/controller/service/router.py | 88 +++++++++++++++++++------ src/forge/controller/service/service.py | 37 ++++++----- 3 files changed, 92 insertions(+), 45 deletions(-) diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 81c6de0e1..9be98708a 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -159,9 +159,10 @@ async def initialize(self): # Deploy the actor and its underlying resources logger.debug(f"Launching actor for replica {self.idx}") - self.actor = await self.actor_def.options( - **asdict(self.proc_config) - ).as_actor(*self.actor_args, **self.actor_kwargs) + self.actor = await self.actor_def.launch( + *self.actor_args, + **self.actor_kwargs, + ) # Transition to healthy state and start processing self.state = ReplicaState.HEALTHY self.start_processing() @@ -220,11 +221,6 @@ async def enqueue_request(self, request: ServiceRequest): # Accept requests in all other states - let the processing loop handle the rest await self.request_queue.put(request) - async def enqueue_batch(self, requests: list[ServiceRequest]): - """Enqueues a batch of requests for processing by this replica.""" - for req in requests: - await self.enqueue_request(req) - async def _process_single_request(self, request: ServiceRequest) -> bool: """Processes a single request and returns success status. diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 11311f93f..276e45bff 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -146,13 +146,14 @@ class Batcher: def __init__( self, + function: str, inner_router: Router, get_healthy_replicas: Callable[[], List["Replica"]], get_session_map: Callable[[], Dict[str, int]], batch_size: int = 16, batch_timeout: float = 0.01, ): - + self.function = function self.inner_router = inner_router self.batch_size = batch_size self.batch_timeout = batch_timeout @@ -164,21 +165,31 @@ def __init__( self._running = True # flag to control loop # Background task that processes batches continuously self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) + # Maximum number of routing attempts per batch (1 initial + 1 retry if replica fails) + self._num_attempts = 2 - async def _batch_loop(self): + async def _batch_loop(self) -> None: """Background task that continuously processes batches of routing requests. This is the core batching logic that runs in a separate asyncio task. - It collects requests from the queue and processes them in batches based - on size and time constraints. + Each iteration collects individual (function, args, kwargs, future) entries from + the internal queue and merges them into a single `ServiceRequest` that is then + dispatched to one replica. The loop follows these steps: - 1. Wait for the first request to start a new batch - 2. Collect additional requests until batch_size or batch_timeout is reached + 1. Wait for the first queued call to start a new batch + 2. Continue collecting until batch_size or batch_timeout is reached 3. Make a single routing decision for the entire batch 4. Fulfill all futures with the selected replica This process repeats indefinitely until the task is cancelled. + + Returns: + None + + Raises: + RuntimeError: If no healthy replicas are available + Exception: Any exception raised by the actor endpoint """ while self._running: @@ -191,10 +202,10 @@ async def _batch_loop(self): timeout = max( 0, self.batch_timeout - (time.monotonic() - start_time) ) - req = await asyncio.wait_for( + nxt = await asyncio.wait_for( self._queue.get(), timeout ) # wait for timeout or until self._queue.get() finishes - batch.append(req) + batch.append(nxt) if len(batch) >= self.batch_size: break @@ -207,20 +218,59 @@ async def _batch_loop(self): # One routing decision for the whole batch replica = self.inner_router.get_replica(healthy_replicas, None, session_map) - # Send whole batch to replica - try: - await replica.enqueue_batch(batch) - except Exception as e: - for req in batch: - req.future.set_exception(e) + # Merge args for batched call + inputs = [b[1][0] for b in batch] - async def enqueue(self, request: ServiceRequest) -> Any: - """Enqueue request and wait until batch assigns a replica.""" - # Queue the request for batching - this is non-blocking - self._queue.put_nowait(request) + # One request for the entire batch + batch_req = ServiceRequest( + session_id=None, + function=self.function, + args=(inputs,), + kwargs={}, + future=asyncio.Future(), + ) + + for attempt in range(self._num_attempts): + try: + # Send whole batch to replica + await replica.enqueue_request(batch_req) + results = await batch_req.future + # Normalize result shape. + # The actor endpoint is expected to return one result per input + # If it instead returns a single scalar, replicate that scalar across + # all callers so that every waiting future gets a value. + if not isinstance(results, list) or len(results) != len(batch): + results = [results] * len(batch) + + for (_, _, _, f), r in zip(batch, results): + f.set_result(r) + break + except Exception as e: + if attempt < self._num_attempts - 1 and not replica.healthy: + logger.debug( + f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" + ) + healthy_replicas = self.get_healthy_replicas() + session_map = self.get_session_map() + if not healthy_replicas: + raise RuntimeError("No healthy replicas available") from e + replica = self.inner_router.get_replica( + healthy_replicas, None, session_map + ) + continue + else: + for _, _, _, f in batch: + if not f.done(): + f.set_exception(e) + + async def route(self, function: str, args: tuple, kwargs: dict) -> Any: + """Add (args, kwargs) pair to queue, return a Future resolved when batch completes.""" + # Queue the request for batching + fut = asyncio.Future() + self._queue.put_nowait((function, args, kwargs, fut)) # Wait for the batch processor to resolve our future - return await request.future + return await fut async def stop(self): """Stop the batch loop gracefully.""" diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 6c182b5ae..d1b05284b 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -184,6 +184,7 @@ def _set_router( # Wrap in Batcher if batching requested if batch_size > 1: router = Batcher( + endpoint_name, base_router, get_healthy_replicas=self._get_healthy_replicas, get_session_map=self._get_session_map, @@ -227,6 +228,13 @@ async def _route(self, sess_id: str | None, function: str, *args, **kwargs): router = self.routers.get(function, self._default_router) + # Case 1: batching is enabled and no session ID (stateless calls only) + # Forward the request into the Batcher queue. The batcher will send + # the batched requests to the selected replica. + if sess_id is None and isinstance(router, Batcher): + return await router.route(function, args, kwargs) + + # Case 2: route a single request to a replica # Create a ServiceRequest object to queue request = ServiceRequest( session_id=sess_id, @@ -236,27 +244,20 @@ async def _route(self, sess_id: str | None, function: str, *args, **kwargs): future=asyncio.Future(), ) - # Case: batching is enabled and no session ID (stateless calls only) - # Forward the request into the Batcher queue. The batcher will send - # the batched requests to the selected replica. - if sess_id is None and isinstance(router, Batcher): - await router.enqueue(request) + healthy_replicas = [r for r in self._replicas if r.healthy] + # Select replica: + if sess_id is not None: + # Case 2.1: sticky sessions + replica = self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) else: - healthy_replicas = [r for r in self._replicas if r.healthy] - - # Select replica: - if sess_id is not None: - # Case 1: sticky sessions - replica = self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - else: - # Case 2: stateless routing - replica = self._default_router.get_replica(healthy_replicas) + # Case 2.2: stateless routing + replica = self._default_router.get_replica(healthy_replicas) - # Queue the request using replica's method - await replica.enqueue_request(request) + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: From e7653b511d5b116cf31c4484b97518f49cb19f51 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 6 Oct 2025 10:16:00 -0700 Subject: [PATCH 3/6] add tmp test --- tests/unit_tests/test_router.py | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 4ce497e77..6fc74e1a2 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -325,3 +325,54 @@ async def test_independent_batchers_and_routers_per_endpoint(): finally: await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_rr_batch_incr_bsize5_behaves_like_normal_incr(): + """Ensure rr_batch_incr_bsize5 (batch_size=5) behaves like a normal incr endpoint for single calls.""" + service = await Counter.options(procs=1, num_replicas=1).as_service(v=5) + + try: + # Initial value + assert await service.value.route() == 5 + + # Call batched increment once + await service.rr_batch_incr_bsize5.route() + + # Should increment exactly once + assert await service.value.route() == 6 + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +async def test_service_endpoint_batching_preserves_order(): + class MyActor(ForgeActor): + def __init__(self): + self._num_calls = 0 + self._sum = 0 + + @endpoint + async def get_num_calls(self): + return self._num_calls + + @endpoint + async def get_sum(self): + return self._sum + + @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) + async def test(self, inputs: list[int]): + self._num_calls += 1 + self._sum += sum(inputs) + return inputs + + service = await MyActor.options(num_replicas=2, procs=1).as_service() + try: + results = await asyncio.gather(*[service.test.route(i) for i in range(5)]) + assert results == [0, 1, 2, 3, 4] + assert await service.get_num_calls.route() == 1 + assert sorted(await service.get_sum.fanout()) == [0, 10] + finally: + await service.shutdown() From aea611621f33dc84ce74af1e10007a63aac2d10b Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 6 Oct 2025 19:12:25 -0700 Subject: [PATCH 4/6] add a todo --- src/forge/controller/service/router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 276e45bff..b07ef9c05 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -197,6 +197,7 @@ async def _batch_loop(self) -> None: batch = [await self._queue.get()] start_time = time.monotonic() + # TODO (dxie): consider making timeout adaptive based on replica load. while True: try: timeout = max( From 664ef41efb7d04e00491ee0b6c46c5048948d068 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 8 Oct 2025 11:57:03 -0700 Subject: [PATCH 5/6] update tests and code --- src/forge/controller/service/router.py | 30 +++- tests/unit_tests/test_router.py | 221 ++++++++++++++----------- 2 files changed, 143 insertions(+), 108 deletions(-) diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index b07ef9c05..6c179ec10 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -220,13 +220,18 @@ async def _batch_loop(self) -> None: replica = self.inner_router.get_replica(healthy_replicas, None, session_map) # Merge args for batched call - inputs = [b[1][0] for b in batch] + if batch and len(batch[0][1]) > 0: + # Normal case: endpoints expect positional arguments + merged_args = [list(items) for items in zip(*[b[1] for b in batch])] + args = tuple(merged_args) + else: + # No-arg case: just one batched call, no inputs to merge + args = () - # One request for the entire batch batch_req = ServiceRequest( session_id=None, function=self.function, - args=(inputs,), + args=args, kwargs={}, future=asyncio.Future(), ) @@ -234,17 +239,24 @@ async def _batch_loop(self) -> None: for attempt in range(self._num_attempts): try: # Send whole batch to replica + logger.debug( + f"[Batcher] enqueue request executing function={batch_req.function}, args={batch_req.args}, kwargs={batch_req.kwargs}" + ) await replica.enqueue_request(batch_req) results = await batch_req.future - # Normalize result shape. - # The actor endpoint is expected to return one result per input - # If it instead returns a single scalar, replicate that scalar across - # all callers so that every waiting future gets a value. - if not isinstance(results, list) or len(results) != len(batch): + logger.debug(f"[Batcher] results: {results}") + # Normalize result shape + if isinstance(results, (list, tuple)): + results = list(results) + if len(results) != len(batch): + results = [results] * len(batch) + else: results = [results] * len(batch) + # Fulfill each individual Future from the batch for (_, _, _, f), r in zip(batch, results): - f.set_result(r) + if not f.done(): + f.set_result(r) break except Exception as e: diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 6fc74e1a2..9d0522f93 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -34,6 +34,7 @@ class Counter(ForgeActor): def __init__(self, v: int): self.v = v + self._num_calls = 0 # number of calls to endpoint functions @endpoint async def value(self) -> int: @@ -46,31 +47,35 @@ async def fail_me(self): raise RuntimeError("I was asked to fail") @endpoint - async def add_to_value(self, amount: int, multiplier: int = 1) -> int: - """Add an amount (optionally multiplied) to the current value.""" - logger.info(f"adding {amount} with {multiplier}") - self.v += amount * multiplier - return self.v + async def get_num_calls(self): + """Get the number of calls to endpoint functions.""" + return self._num_calls @endpoint async def incr(self): """Increment the counter.""" + self._num_calls += 1 self.v += 1 @service_endpoint(router=RoundRobinRouter, batch_size=3, batch_timeout=1) async def rr_batch_incr_bsize3(self): """Increment the round-robin counter with batching (batch size = 3).""" + self._num_calls += 1 self.v += 1 @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) - async def rr_batch_incr_bsize5(self): + async def rr_batch_incr_bsize5(self, inputs: list[int]) -> list[int]: """Increment the round-robin counter with batching (batch size = 5).""" - self.v += 1 + self._num_calls += 1 + self.v += sum(inputs) + return inputs @service_endpoint(router=RoundRobinRouter) - async def rr_batch_incr_bsize1(self): - """Increment the round-robin counter with default batch_size=1.""" - self.v += 1 + async def rr_batch_incr_bsize1(self, inputs: list[int]) -> list[int]: + """Increment the round-robin counter with batching (batch size = 1).""" + self._num_calls += 1 + self._sum += sum(inputs) + return inputs def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: @@ -87,21 +92,7 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: return replica -@pytest.mark.asyncio -@pytest.mark.timeout(10) -async def test_service_as_actor_preserves_normal_usage(): - """Ensure that using `as_actor` does not break normal semantics.""" - service = await Counter.as_actor(5) - - try: - assert await service.value.choose() == 5 - - # Test increment - await service.rr_batch_incr_bsize3.choose() - assert await service.value.choose() == 6 - - finally: - await Counter.shutdown(service) +# Cnofig tests @pytest.mark.asyncio @@ -227,27 +218,6 @@ async def test_round_robin_router_distribution(): await service.shutdown() -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_round_robin_router_distribution_with_batching(): - """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas with batch routing.""" - service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) - - try: - # Make multiple sessionless calls using route() - results = [] - tasks = [service.rr_batch_incr_bsize3.route() for _ in range(6)] - await asyncio.gather(*tasks) - # Verify that requests were distributed round-robin - # Each call increments a single replica, so after 6 calls we expect: - # 2 increments per replica (since 3 replicas, 6 calls) - values = await service.value.fanout() - assert sorted(values) == [0, 3, 3] - - finally: - await service.shutdown() - - @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_session_router_assigns_and_updates_session_map_in_service(): @@ -281,50 +251,24 @@ async def test_session_router_assigns_and_updates_session_map_in_service(): await service.shutdown() -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_independent_batchers_and_routers_per_endpoint(): - """Ensure multiple @service_endpoint endpoints coexist with independent routers/batchers.""" - service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) - - try: - # --- First batch: rr_batch_incr_bsize3 (batch_size = 3) --- - tasks = [ - asyncio.create_task(service.rr_batch_incr_bsize3.route()) for _ in range(4) - ] - await asyncio.gather(*tasks) - - values = await service.value.fanout() +# Batcher tests - # Expectation: - # - First 3 requests form one batch → sent to replica R1 (+3). - # - Remaining 1 request forms its own batch → goes to replica R2 (+1). - # So totals should be [3, 1] (order depends on round robin). - assert sum(values) == 4, f"Expected total=4, got {values}" - assert sorted(values) == [1, 3], f"Expected [1, 3], got {values}" - # --- Second batch: rr_batch_incr_bsize5 (batch_size = 5) --- - tasks = [ - asyncio.create_task(service.rr_batch_incr_bsize5.route()) for _ in range(7) - ] - await asyncio.gather(*tasks) +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_as_actor_preserves_normal_usage(): + """Ensure that using `as_actor` does not break normal semantics.""" + service = await Counter.as_actor(5) - values = await service.value.fanout() + try: + assert await service.value.choose() == 5 - # Expectation (RoundRobin between replicas): - # Starting from previous state (R1=3, R2=1): - # - Next 5 requests form one batch → go to R1 (+5). - # - Remaining 2 requests form their own batch → go to R2 (+2). - # - # Final totals: - # R1 = 3 (previous) + 5 = 8 - # R2 = 1 (previous) + 2 = 3 - # So distribution should be [3, 8]. - assert sum(values) == 11, f"Expected total=11, got {values}" - assert sorted(values) == [3, 8], f"Expected [4, 8], got {values}" + # Test increment + await service.rr_batch_incr_bsize3.choose() + assert await service.value.choose() == 6 finally: - await service.shutdown() + await Counter.shutdown(service) @pytest.mark.timeout(10) @@ -338,7 +282,7 @@ async def test_rr_batch_incr_bsize5_behaves_like_normal_incr(): assert await service.value.route() == 5 # Call batched increment once - await service.rr_batch_incr_bsize5.route() + await service.rr_batch_incr_bsize5.route(1) # Should increment exactly once assert await service.value.route() == 6 @@ -348,31 +292,110 @@ async def test_rr_batch_incr_bsize5_behaves_like_normal_incr(): @pytest.mark.asyncio +@pytest.mark.timeout(10) async def test_service_endpoint_batching_preserves_order(): - class MyActor(ForgeActor): + """Ensure that batching preserves the order of calls.""" + service = await Counter.options(num_replicas=2, procs=1).as_service(0) + try: + results = await asyncio.gather( + *[service.rr_batch_incr_bsize5.route(i) for i in range(5)] + ) + assert results == [0, 1, 2, 3, 4] + assert await service.get_num_calls.route() == 1 + assert sorted(await service.value.fanout()) == [0, 10] + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_endpoint_multiple_batches(): + """ + Verify that batching correctly splits requests into two batches — + one triggered by reaching the batch size limit and another by the batch timeout. + """ + service = await Counter.options(num_replicas=2, procs=1).as_service(0) + try: + # Enqueue 7 calls → expect two batches (5 + 2) + results = await asyncio.gather( + *[service.rr_batch_incr_bsize5.route(i) for i in range(7)] + ) + # Verify all individual results were returned in order + assert results == [0, 1, 2, 3, 4, 5, 6] + # Each replica should have executed one batch (round-robin) + assert await service.get_num_calls.fanout() == [1, 1] + + # Replica values reflect the sum of their respective batch inputs + # first batch: [0, 1, 2, 3, 4] → 10 + # second batch: [5, 6] → 11 + assert sorted(await service.value.fanout()) == [10, 11] + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_round_robin_batcher_distribution_no_args(): + """ + Verify that the batching system correctly handles endpoints with **zero arguments** + and that the RoundRobinRouter distributes such batched calls evenly across replicas. + """ + + # --- Launch service with 3 replicas --- + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Enqueue 5 no-arg batched calls + await asyncio.gather(*[service.rr_batch_incr_bsize3.route() for _ in range(5)]) + + # Check that two replicas incremented their counters once + values = await service.value.fanout() + assert sorted(values) == [0, 1, 1], f"Unexpected replica values: {values}" + + # Ensure exactly 2 actor invocations occurred (2 batches total) + num_calls = await service.get_num_calls.fanout() + assert sum(num_calls) == 2, f"Expected 2 batches, got {sum(num_calls)}" + + finally: + await service.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.timeout(10) +async def test_service_endpoint_batching_multi_arg_merge(): + """Ensure that batching merges multiple argument lists correctly.""" + + class MultiArgActor(ForgeActor): def __init__(self): self._num_calls = 0 - self._sum = 0 @endpoint async def get_num_calls(self): return self._num_calls - @endpoint - async def get_sum(self): - return self._sum - - @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) - async def test(self, inputs: list[int]): + @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.1) + async def multi_args_sum(self, v1: list[int], v2: list[str]) -> list[str]: + """ + Endpoint that accepts multiple argument lists. + Should be invoked once per batch. + """ self._num_calls += 1 - self._sum += sum(inputs) - return inputs + # Combine corresponding elements + return [f"{x}:{y}" for x, y in zip(v1, v2)] + + service = await MultiArgActor.options(num_replicas=2, procs=1).as_service() - service = await MyActor.options(num_replicas=2, procs=1).as_service() try: - results = await asyncio.gather(*[service.test.route(i) for i in range(5)]) - assert results == [0, 1, 2, 3, 4] + # 5 requests will fill one batch of size 5 + results = await asyncio.gather( + *[service.multi_args_sum.route(i, str(i)) for i in range(5)] + ) + + # Expect exactly one actor invocation assert await service.get_num_calls.route() == 1 - assert sorted(await service.get_sum.fanout()) == [0, 10] + + # Expect results correspond to all merged pairs + assert results == ["0:0", "1:1", "2:2", "3:3", "4:4"] + finally: await service.shutdown() From abd29a572c158a918bb08b9d84596b62d1ae7ccd Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 8 Oct 2025 12:31:59 -0700 Subject: [PATCH 6/6] clean up the code --- src/forge/controller/service/endpoint.py | 2 +- src/forge/controller/service/replica.py | 2 +- src/forge/controller/service/router.py | 48 ++++++++++-------------- 3 files changed, 21 insertions(+), 31 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index c731017c6..ba29034b0 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -182,7 +182,7 @@ def service_endpoint( Example: class MyForgeActor(ForgeActor): - @service_endpoint(router=RoundRobinRouter(), batch_size=16, batch_timeout=0.05) + @service_endpoint(router=RoundRobinRouter, batch_size=16, batch_timeout=0.05) async def predict(self, x): ... """ diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 9be98708a..09b0a2ce6 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -9,7 +9,7 @@ import logging import time from collections import deque -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from enum import Enum from typing import Optional diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 6c179ec10..76456b0f1 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -138,10 +138,8 @@ class Batcher: batch_timeout=0.01, ) - request = ServiceRequest(...) - - # Enqueue a request to be sent to a replica - await batcher.enqueue(request) + # Enqueue a endpoint call to be sent to a replica + results = await batcher.route(function, args, kwargs) """ def __init__( @@ -213,12 +211,6 @@ async def _batch_loop(self) -> None: except asyncio.TimeoutError: break - session_map = self.get_session_map() - healthy_replicas = self.get_healthy_replicas() - - # One routing decision for the whole batch - replica = self.inner_router.get_replica(healthy_replicas, None, session_map) - # Merge args for batched call if batch and len(batch[0][1]) > 0: # Normal case: endpoints expect positional arguments @@ -228,29 +220,34 @@ async def _batch_loop(self) -> None: # No-arg case: just one batched call, no inputs to merge args = () - batch_req = ServiceRequest( - session_id=None, - function=self.function, - args=args, - kwargs={}, - future=asyncio.Future(), - ) - for attempt in range(self._num_attempts): + session_map = self.get_session_map() + healthy_replicas = self.get_healthy_replicas() + + # One routing decision for the whole batch + replica = self.inner_router.get_replica( + healthy_replicas, None, session_map + ) + + batch_req = ServiceRequest( + session_id=None, + function=self.function, + args=args, + kwargs={}, + future=asyncio.Future(), + ) + try: # Send whole batch to replica - logger.debug( - f"[Batcher] enqueue request executing function={batch_req.function}, args={batch_req.args}, kwargs={batch_req.kwargs}" - ) await replica.enqueue_request(batch_req) results = await batch_req.future - logger.debug(f"[Batcher] results: {results}") # Normalize result shape if isinstance(results, (list, tuple)): results = list(results) if len(results) != len(batch): results = [results] * len(batch) else: + # scalar result, broadcast to batch size results = [results] * len(batch) # Fulfill each individual Future from the batch @@ -264,13 +261,6 @@ async def _batch_loop(self) -> None: logger.debug( f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" ) - healthy_replicas = self.get_healthy_replicas() - session_map = self.get_session_map() - if not healthy_replicas: - raise RuntimeError("No healthy replicas available") from e - replica = self.inner_router.get_replica( - healthy_replicas, None, session_map - ) continue else: for _, _, _, f in batch: