From 2c349fdc9f876257d3e241b5432184c61a57a6d1 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 24 Sep 2025 11:45:45 -0700 Subject: [PATCH 01/32] router access latest healith_replicas and sessionmap --- src/forge/controller/service/router.py | 133 ++++++++++++++++++- tests/unit_tests/test_service.py | 170 +++++++++++++++++++++++++ 2 files changed, 302 insertions(+), 1 deletion(-) diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 502402e36..80faadc16 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Dict, List +from typing import Callable, Dict, List from .interface import Router from .replica import Replica @@ -88,3 +88,134 @@ def get_replica( replica.idx, ) return replica + + +class BatchRouter(Router): + """ + Router wrapper that batches routing decisions. + Uses an inner router to pick the replica for each batch. + + Args: + inner_router: The underlying Router instance used to make routing decisions + batch_max_size: Maximum number of requests to collect in a single batch (default: 8) + batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01) + + Example: + rr_router = RoundRobinRouter() + batch_router = BatchRouter(rr_router, batch_max_size=16, batch_max_wait_s=0.02) + + replica = await batch_router.get_replica(healthy_replicas, sess_id, session_map) + """ + + def __init__( + self, + inner_router: Router, + batch_max_size: int = 8, + batch_max_wait_s: float = 0.01, + get_healthy_replicas: Optional[Callable[[], List["Replica"]]] = None, + session_map: Optional[Dict[str, int]] = None, + ): + + self.inner_router = inner_router + self.batch_max_size = batch_max_size + self.batch_max_wait_s = batch_max_wait_s + self.get_healthy_replicas = get_healthy_replicas + self.session_map = session_map + + # Internal queue for batching routing requests + self._queue: asyncio.Queue = asyncio.Queue() + self._running = True # flag to control loop + # Background task that processes batches continuously + self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_max_size or batch_max_wait_s is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + """ + while self._running: + batch = [] + futs = [] + sess_ids = [] + + # Wait for first request + fut, healthy_replicas, sess_id, session_map = await self._queue.get() + batch.append((healthy_replicas, sess_id, session_map)) + futs.append(fut) + sess_ids.append(sess_id) + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self.batch_max_wait_s - (time.monotonic() - start_time) + ) + ( + fut, + healthy_replicas, + sess_id, + session_map, + ) = await asyncio.wait_for( + self._queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch.append((healthy_replicas, sess_id, session_map)) + futs.append(fut) + sess_ids.append(sess_id) + + if len(batch) >= self.batch_max_size: + break + except asyncio.TimeoutError: + break + + if self.session_map is not None: + session_map = self.session_map + else: + session_map = batch[-1][2] # use most recent session map + if self.get_healthy_replicas is not None: + healthy_replicas = self.get_healthy_replicas() + else: + healthy_replicas = batch[-1][0] # use most recent replica state + # Check if any replicas have become unhealthy + healthy_replicas = [r for r in healthy_replicas if r.healthy] + + # One routing decision for the whole batch + replica = await self.inner_router.get_replica( + healthy_replicas, None, session_map + ) + + # Fulfill all futures with the chosen replica + for fut in futs: + fut.set_result(replica) + + async def get_replica( + self, + healthy_replicas: List[Replica], + sess_id: Optional[str] = None, + session_map: Optional[Dict[str, int]] = None, + ) -> Replica: + """Enqueue request and wait until batch assigns a replica.""" + fut = asyncio.Future() + # Queue the request for batching - this is non-blocking + self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map)) + + # Wait for the batch processor to resolve our future + return await fut + + async def shutdown(self): + """Stop the batch loop gracefully.""" + self._running = False + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 31a912542..4249f9096 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -786,6 +786,176 @@ async def test_session_router_with_round_robin_fallback(): assert r1.idx == r2.idx == 0 +@pytest.mark.asyncio +async def test_batching_router_batchsize_with_roundrobin(): + """Batch should flush when max batch size is reached using RoundRobinRouter.""" + replicas = [make_replica(0), make_replica(1)] + batch_size = 3 + + router = BatchRouter( + RoundRobinRouter(), + batch_max_size=batch_size, + batch_max_wait_s=0.5, # long enough to not trigger timeout + ) + + try: + # Enqueue `batch_size + 1` requests to force batch flush + tasks = [ + asyncio.create_task(router.get_replica(replicas)) + for _ in range(batch_size + 1) + ] + results = await asyncio.gather(*tasks) + + # Check all results are healthy replicas + assert all(r.state == ReplicaState.HEALTHY for r in results) + + # Check results only use existing replica indices + indices = {r.idx for r in results} + assert indices.issubset({0, 1}) + + # Ensure batch queue is empty after flush + assert router._queue.qsize() == 0 + finally: + router.shutdown() + + +@pytest.mark.asyncio +async def test_batching_router_skips_unhealthy_replicas(): + """If a replica becomes unhealthy before batch dispatch, it should be skipped.""" + replicas = [make_replica(0, load=0), make_replica(1, load=10)] + + router = BatchRouter( + LeastLoadedRouter(), + batch_max_size=4, + batch_max_wait_s=0.5, + ) + try: + # Start two requests that will form a batch + tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(2)] + + # While they are waiting, mark replica 0 (least loaded) as unhealthy + await asyncio.sleep(0.01) + replicas[0].state = ReplicaState.UNHEALTHY + + results = await asyncio.gather(*tasks) + + # All results must be the *healthy* replica (idx=1) + assert all(r.idx == 1 for r in results) + assert results[0].state == ReplicaState.HEALTHY + finally: + router.shutdown() + + +@pytest.mark.asyncio +async def test_batching_router_two_batches_timing(): + """Test that two sequential batches are processed independently with proper timing.""" + import time + + replicas = [make_replica(0, load=5), make_replica(1, load=10)] + batch_wait_time = 0.05 # 50ms timeout + + router = BatchRouter( + LeastLoadedRouter(), + batch_max_size=3, + batch_max_wait_s=batch_wait_time, + ) + try: + # First batch: 2 requests that will timeout + start_time = time.time() + + # Create first batch tasks + first_batch_tasks = [ + asyncio.create_task(router.get_replica(replicas)) for _ in range(2) + ] + + # Wait for first batch to complete (should timeout after batch_wait_time) + first_results = await asyncio.gather(*first_batch_tasks) + first_batch_duration = time.time() - start_time + + # Verify first batch took approximately the timeout duration (tighter tolerance) + assert ( + batch_wait_time <= first_batch_duration < batch_wait_time + 0.01 + ) # 10ms tolerance on 50ms timeout + + # Verify first batch results (should pick lowest load replica) + assert all(r.idx == 0 for r in first_results) # replica 0 has lower load + assert all(r.state == ReplicaState.HEALTHY for r in first_results) + + # Second batch: 2 more requests (new timing cycle should start) + second_batch_start = time.time() + + # Create second batch tasks + second_batch_tasks = [ + asyncio.create_task(router.get_replica(replicas)) for _ in range(2) + ] + + # Wait for second batch to complete + second_results = await asyncio.gather(*second_batch_tasks) + second_batch_duration = time.time() - second_batch_start + + # Verify second batch also took approximately the timeout duration (tighter tolerance) + assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01 + + # Verify second batch results + assert all(r.idx == 0 for r in second_results) # should still pick lowest load + assert all(r.state == ReplicaState.HEALTHY for r in second_results) + + # Ensure batch queue is empty after both batches + assert router._queue.qsize() == 0 + finally: + router.shutdown() + + +@pytest.mark.asyncio +async def test_batchrouter_callable_updates(): + """Test that callables reflect updates after a batch is processed.""" + + # Initial replicas and session map + replicas = [make_replica(0, load=0), make_replica(1, load=10)] + session_map = {} + + # Define dynamic callable for healthy replicas + def get_healthy_replicas(): + # Always return only replicas whose state is HEALTHY + return [r for r in replicas if r.healthy] + + # Wrap inner router to spy on session_map received + class SpyRouter(LeastLoadedRouter): + async def get_replica(self, healthy_replicas, sess_id, session_map_arg): + # Save the session_map passed in + self.last_session_map = session_map_arg + return await super().get_replica(healthy_replicas, sess_id, session_map_arg) + + # Router using callables + spy_inner_router = SpyRouter() + router = BatchRouter( + spy_inner_router, + batch_max_size=2, + batch_max_wait_s=0.05, + get_healthy_replicas=get_healthy_replicas, + session_map=session_map, + ) + + try: + # Mark replica 0 unhealthy *after* we enqueue the first batch + tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(1)] + await asyncio.sleep(0.001) + replicas[0].state = ReplicaState.UNHEALTHY + session_map["s1"] = 42 # simulate an update while batch is pending + + # Wait for batch to complete + results = await asyncio.gather(*tasks) + + # Verify router used healthy replica (idx=1) instead of least-loaded one (idx=0) + assert all(r.idx == 1 for r in results) + + # Verify router actually received the updated session_map + assert spy_inner_router.last_session_map["s1"] == 42 + + finally: + await router.shutdown() + + # Router integeration tests From d35193587e901b4c32f445d051c6e2133d3b5f39 Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 24 Sep 2025 12:22:43 -0700 Subject: [PATCH 02/32] fix test --- src/forge/controller/service/__init__.py | 3 ++- src/forge/controller/service/router.py | 3 ++- tests/unit_tests/test_service.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index f0d8fca7b..803e06dd3 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -7,7 +7,7 @@ from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState -from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter +from .router import BatchRouter, LeastLoadedRouter, RoundRobinRouter, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ @@ -24,4 +24,5 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", + "BatchRouter", ] diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 80faadc16..5b2804575 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import logging -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional from .interface import Router from .replica import Replica diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 4249f9096..3154dee88 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -14,6 +14,7 @@ import pytest from forge.controller import ForgeActor from forge.controller.service import ( + BatchRouter, LeastLoadedRouter, Replica, ReplicaState, From 1c7efacc7df8f1be87a2437e7f746de4bec11b5a Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 24 Sep 2025 14:30:26 -0700 Subject: [PATCH 03/32] add batch routing logic to service + test case --- src/forge/controller/service/__init__.py | 3 +- src/forge/controller/service/router.py | 135 +-------------- src/forge/controller/service/service.py | 100 +++++++++-- src/forge/types.py | 2 + tests/unit_tests/test_service.py | 203 ++++------------------- 5 files changed, 123 insertions(+), 320 deletions(-) diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index 803e06dd3..f0d8fca7b 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -7,7 +7,7 @@ from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState -from .router import BatchRouter, LeastLoadedRouter, RoundRobinRouter, SessionRouter +from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ @@ -24,5 +24,4 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", - "BatchRouter", ] diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 5b2804575..1dd0308e0 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -4,9 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import asyncio + import logging -from typing import Callable, Dict, List, Optional +from typing import Dict, List from .interface import Router from .replica import Replica @@ -89,134 +89,3 @@ def get_replica( replica.idx, ) return replica - - -class BatchRouter(Router): - """ - Router wrapper that batches routing decisions. - Uses an inner router to pick the replica for each batch. - - Args: - inner_router: The underlying Router instance used to make routing decisions - batch_max_size: Maximum number of requests to collect in a single batch (default: 8) - batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01) - - Example: - rr_router = RoundRobinRouter() - batch_router = BatchRouter(rr_router, batch_max_size=16, batch_max_wait_s=0.02) - - replica = await batch_router.get_replica(healthy_replicas, sess_id, session_map) - """ - - def __init__( - self, - inner_router: Router, - batch_max_size: int = 8, - batch_max_wait_s: float = 0.01, - get_healthy_replicas: Optional[Callable[[], List["Replica"]]] = None, - session_map: Optional[Dict[str, int]] = None, - ): - - self.inner_router = inner_router - self.batch_max_size = batch_max_size - self.batch_max_wait_s = batch_max_wait_s - self.get_healthy_replicas = get_healthy_replicas - self.session_map = session_map - - # Internal queue for batching routing requests - self._queue: asyncio.Queue = asyncio.Queue() - self._running = True # flag to control loop - # Background task that processes batches continuously - self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) - - async def _batch_loop(self): - """Background task that continuously processes batches of routing requests. - - This is the core batching logic that runs in a separate asyncio task. - It collects requests from the queue and processes them in batches based - on size and time constraints. - - The loop follows these steps: - 1. Wait for the first request to start a new batch - 2. Collect additional requests until batch_max_size or batch_max_wait_s is reached - 3. Make a single routing decision for the entire batch - 4. Fulfill all futures with the selected replica - - This process repeats indefinitely until the task is cancelled. - """ - while self._running: - batch = [] - futs = [] - sess_ids = [] - - # Wait for first request - fut, healthy_replicas, sess_id, session_map = await self._queue.get() - batch.append((healthy_replicas, sess_id, session_map)) - futs.append(fut) - sess_ids.append(sess_id) - start_time = time.monotonic() - - while True: - try: - timeout = max( - 0, self.batch_max_wait_s - (time.monotonic() - start_time) - ) - ( - fut, - healthy_replicas, - sess_id, - session_map, - ) = await asyncio.wait_for( - self._queue.get(), timeout - ) # wait for timeout or until self._queue.get() finishes - batch.append((healthy_replicas, sess_id, session_map)) - futs.append(fut) - sess_ids.append(sess_id) - - if len(batch) >= self.batch_max_size: - break - except asyncio.TimeoutError: - break - - if self.session_map is not None: - session_map = self.session_map - else: - session_map = batch[-1][2] # use most recent session map - if self.get_healthy_replicas is not None: - healthy_replicas = self.get_healthy_replicas() - else: - healthy_replicas = batch[-1][0] # use most recent replica state - # Check if any replicas have become unhealthy - healthy_replicas = [r for r in healthy_replicas if r.healthy] - - # One routing decision for the whole batch - replica = await self.inner_router.get_replica( - healthy_replicas, None, session_map - ) - - # Fulfill all futures with the chosen replica - for fut in futs: - fut.set_result(replica) - - async def get_replica( - self, - healthy_replicas: List[Replica], - sess_id: Optional[str] = None, - session_map: Optional[Dict[str, int]] = None, - ) -> Replica: - """Enqueue request and wait until batch assigns a replica.""" - fut = asyncio.Future() - # Queue the request for batching - this is non-blocking - self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map)) - - # Wait for the batch processor to resolve our future - return await fut - - async def shutdown(self): - """Stop the batch loop gracefully.""" - self._running = False - self._batch_task.cancel() - try: - await self._batch_task - except asyncio.CancelledError: - pass diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 0b655fb6a..327ffda3d 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -35,6 +35,7 @@ import asyncio import logging import pprint +import time import uuid from typing import Dict, List @@ -110,6 +111,13 @@ async def __initialize__(self): self._default_router = RoundRobinRouter() self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) + # Batching + self._max_batch_size = self._cfg.max_batch_size + self._batch_max_wait_s = self._cfg.batch_max_wait_s + self._batch_task: asyncio.Task | None = None + self._running_batch_loop = False + self._batch_queue: asyncio.Queue = asyncio.Queue() + # Initialize all replicas replicas = [] num_replicas = self._cfg.num_replicas @@ -138,6 +146,60 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) + # Start batch loop if batching enabled + if self._max_batch_size > 1: + self._running_batch_loop = True + self._batch_task = asyncio.create_task(self._batch_loop()) + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_max_size or batch_max_wait_s is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + """ + while self._running_batch_loop: + batch_futs = [] + + # Wait for first request + fut = await self._batch_queue.get() + batch_futs.append(fut) + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self._batch_max_wait_s - (time.monotonic() - start_time) + ) + fut = await asyncio.wait_for( + self._batch_queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch_futs.append(fut) + + if len(batch_futs) >= self._max_batch_size: + break + except asyncio.TimeoutError: + break + + healthy_replicas = self._get_healthy_replicas() + + # One routing decision for the whole batch + replica = self._default_router.get_replica( + healthy_replicas, None, self._session_replica_map + ) + + # Fulfill all futures with the chosen replica + for fut in batch_futs: + fut.set_result(replica) + async def _call(self, sess_id: str | None, function: str, *args, **kwargs): """ Routes a function call to the appropriate replica with load balancing and fault tolerance. @@ -211,7 +273,7 @@ async def call_all(self, function: str, *args, **kwargs) -> List: Raises: RuntimeError: If no healthy replicas are available """ - healthy_replicas = [r for r in self._replicas if r.healthy] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: raise RuntimeError("No healthy replicas available for broadcast call") @@ -280,9 +342,7 @@ async def _migrate_remaining_requests(self, failed_replica: Replica): ) # Find healthy replicas - healthy_replicas = [ - r for r in self._replicas if r.healthy and r != failed_replica - ] + healthy_replicas = self._get_healthy_replicas() if not healthy_replicas: # No healthy replicas, fail all requests @@ -334,7 +394,7 @@ def _update_service_metrics(self): """Updates service-level metrics.""" self._metrics.total_sessions = len(self._active_sessions) self._metrics.total_replicas = len(self._replicas) - self._metrics.healthy_replicas = sum(1 for r in self._replicas if r.healthy) + self._metrics.healthy_replicas = len(self._get_healthy_replicas()) # Store direct references to replica metrics for aggregation self._metrics.replica_metrics = {} for replica in self._replicas: @@ -446,6 +506,10 @@ async def terminate_session(self, sess_id: str): # Update metrics self._update_service_metrics() + def _get_healthy_replicas(self) -> list[Replica]: + """Returns a list of healthy replicas.""" + return [r for r in self._replicas if r.healthy] + async def _health_loop(self, poll_rate_s: float): """Runs the health loop to monitor and recover replicas. @@ -476,14 +540,24 @@ async def _health_loop(self, poll_rate_s: float): async def _get_replica(self, sess_id: str | None) -> "Replica": """Get a replica for the given session ID.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - if sess_id is None: - # No session, use the default router - return self._default_router.get_replica(healthy_replicas) - return self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) + if sess_id: + # Stateful routing always uses session router + healthy_replicas = self._get_healthy_replicas() + return self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + + # Stateless: batching + if self._max_batch_size > 1: + fut = asyncio.Future() + healthy_replicas = self._get_healthy_replicas() + self._batch_queue.put_nowait(fut) + return await fut + else: + # No batching, pick immediately + healthy_replicas = self._get_healthy_replicas() + return self._default_router.get_replica(healthy_replicas) async def stop(self): logger.debug("Stopping service...") @@ -582,7 +656,7 @@ async def _get_internal_state(self) -> dict: # Load balancing state # Service-level state "total_replicas": len(self._replicas), - "healthy_replica_count": sum(1 for r in self._replicas if r.healthy), + "healthy_replica_count": len(self._get_healthy_replicas()), "shutdown_requested": self._shutdown_requested, # Metrics summary "total_sessions": len(self._active_sessions), diff --git a/src/forge/types.py b/src/forge/types.py index cc41d2185..adb49c364 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -118,6 +118,8 @@ class ServiceConfig: health_poll_rate: float = 0.2 replica_max_concurrent_requests: int = 10 return_first_rank_result: bool = True + max_batch_size: int = 1 + batch_max_wait_s: float = 0.01 def to_process_config(self) -> ProcessConfig: """Extract ProcessConfig from this ServiceConfig. diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 3154dee88..7cb96667b 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Tests for service.py +Tests for service.py and router.py """ import asyncio @@ -14,7 +14,6 @@ import pytest from forge.controller import ForgeActor from forge.controller.service import ( - BatchRouter, LeastLoadedRouter, Replica, ReplicaState, @@ -787,176 +786,6 @@ async def test_session_router_with_round_robin_fallback(): assert r1.idx == r2.idx == 0 -@pytest.mark.asyncio -async def test_batching_router_batchsize_with_roundrobin(): - """Batch should flush when max batch size is reached using RoundRobinRouter.""" - replicas = [make_replica(0), make_replica(1)] - batch_size = 3 - - router = BatchRouter( - RoundRobinRouter(), - batch_max_size=batch_size, - batch_max_wait_s=0.5, # long enough to not trigger timeout - ) - - try: - # Enqueue `batch_size + 1` requests to force batch flush - tasks = [ - asyncio.create_task(router.get_replica(replicas)) - for _ in range(batch_size + 1) - ] - results = await asyncio.gather(*tasks) - - # Check all results are healthy replicas - assert all(r.state == ReplicaState.HEALTHY for r in results) - - # Check results only use existing replica indices - indices = {r.idx for r in results} - assert indices.issubset({0, 1}) - - # Ensure batch queue is empty after flush - assert router._queue.qsize() == 0 - finally: - router.shutdown() - - -@pytest.mark.asyncio -async def test_batching_router_skips_unhealthy_replicas(): - """If a replica becomes unhealthy before batch dispatch, it should be skipped.""" - replicas = [make_replica(0, load=0), make_replica(1, load=10)] - - router = BatchRouter( - LeastLoadedRouter(), - batch_max_size=4, - batch_max_wait_s=0.5, - ) - try: - # Start two requests that will form a batch - tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(2)] - - # While they are waiting, mark replica 0 (least loaded) as unhealthy - await asyncio.sleep(0.01) - replicas[0].state = ReplicaState.UNHEALTHY - - results = await asyncio.gather(*tasks) - - # All results must be the *healthy* replica (idx=1) - assert all(r.idx == 1 for r in results) - assert results[0].state == ReplicaState.HEALTHY - finally: - router.shutdown() - - -@pytest.mark.asyncio -async def test_batching_router_two_batches_timing(): - """Test that two sequential batches are processed independently with proper timing.""" - import time - - replicas = [make_replica(0, load=5), make_replica(1, load=10)] - batch_wait_time = 0.05 # 50ms timeout - - router = BatchRouter( - LeastLoadedRouter(), - batch_max_size=3, - batch_max_wait_s=batch_wait_time, - ) - try: - # First batch: 2 requests that will timeout - start_time = time.time() - - # Create first batch tasks - first_batch_tasks = [ - asyncio.create_task(router.get_replica(replicas)) for _ in range(2) - ] - - # Wait for first batch to complete (should timeout after batch_wait_time) - first_results = await asyncio.gather(*first_batch_tasks) - first_batch_duration = time.time() - start_time - - # Verify first batch took approximately the timeout duration (tighter tolerance) - assert ( - batch_wait_time <= first_batch_duration < batch_wait_time + 0.01 - ) # 10ms tolerance on 50ms timeout - - # Verify first batch results (should pick lowest load replica) - assert all(r.idx == 0 for r in first_results) # replica 0 has lower load - assert all(r.state == ReplicaState.HEALTHY for r in first_results) - - # Second batch: 2 more requests (new timing cycle should start) - second_batch_start = time.time() - - # Create second batch tasks - second_batch_tasks = [ - asyncio.create_task(router.get_replica(replicas)) for _ in range(2) - ] - - # Wait for second batch to complete - second_results = await asyncio.gather(*second_batch_tasks) - second_batch_duration = time.time() - second_batch_start - - # Verify second batch also took approximately the timeout duration (tighter tolerance) - assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01 - - # Verify second batch results - assert all(r.idx == 0 for r in second_results) # should still pick lowest load - assert all(r.state == ReplicaState.HEALTHY for r in second_results) - - # Ensure batch queue is empty after both batches - assert router._queue.qsize() == 0 - finally: - router.shutdown() - - -@pytest.mark.asyncio -async def test_batchrouter_callable_updates(): - """Test that callables reflect updates after a batch is processed.""" - - # Initial replicas and session map - replicas = [make_replica(0, load=0), make_replica(1, load=10)] - session_map = {} - - # Define dynamic callable for healthy replicas - def get_healthy_replicas(): - # Always return only replicas whose state is HEALTHY - return [r for r in replicas if r.healthy] - - # Wrap inner router to spy on session_map received - class SpyRouter(LeastLoadedRouter): - async def get_replica(self, healthy_replicas, sess_id, session_map_arg): - # Save the session_map passed in - self.last_session_map = session_map_arg - return await super().get_replica(healthy_replicas, sess_id, session_map_arg) - - # Router using callables - spy_inner_router = SpyRouter() - router = BatchRouter( - spy_inner_router, - batch_max_size=2, - batch_max_wait_s=0.05, - get_healthy_replicas=get_healthy_replicas, - session_map=session_map, - ) - - try: - # Mark replica 0 unhealthy *after* we enqueue the first batch - tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(1)] - await asyncio.sleep(0.001) - replicas[0].state = ReplicaState.UNHEALTHY - session_map["s1"] = 42 # simulate an update while batch is pending - - # Wait for batch to complete - results = await asyncio.gather(*tasks) - - # Verify router used healthy replica (idx=1) instead of least-loaded one (idx=0) - assert all(r.idx == 1 for r in results) - - # Verify router actually received the updated session_map - assert spy_inner_router.last_session_map["s1"] == 42 - - finally: - await router.shutdown() - - # Router integeration tests @@ -1017,3 +846,33 @@ async def test_session_router_assigns_and_updates_session_map_in_service(): finally: await service.shutdown() + + +@pytest.mark.asyncio +async def test_service_stateless_batch_flush_max_size(): + """Batch should flush when max batch size is reached using default router (RoundRobin).""" + # Create a service with 2 replicas and batching enabled + service = await Counter.options( + procs=1, num_replicas=2, max_batch_size=3, batch_max_wait_s=0.5 + ).as_service(2) + try: + # Enqueue batch_size + 1 requests to force a batch flush + tasks = [asyncio.create_task(service._get_replica(None)) for _ in range(4)] + results = await asyncio.gather(*tasks) + + # Check that results are only healthy replicas + assert all(r.healthy for r in results) + + replica_indices = {r.idx for r in results[:3]} # first batch of 3 + assert len(replica_indices) == 1 # all went to the same replica + + # Last request (4th) should be routed after timeout + last_request = results[3] + assert ( + last_request.idx not in replica_indices + ) # Last request (4th) should be assigned to a different replica (RoundRobin) + + # After flush, batch queue should be empty + assert service._batch_queue.qsize() == 0 + finally: + await service.stop() From 821714c727c2b557450ac0c1d83ea978c8dcfdd9 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 16:04:08 -0700 Subject: [PATCH 04/32] moving endpoint logic to endpoint.py; add decorator for service_endpoint --- src/forge/controller/service/__init__.py | 3 + src/forge/controller/service/endpoint.py | 262 ++++++++++++++++++++++ src/forge/controller/service/interface.py | 129 ++--------- src/forge/controller/service/router.py | 16 +- 4 files changed, 300 insertions(+), 110 deletions(-) create mode 100644 src/forge/controller/service/endpoint.py diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index f0d8fca7b..ed54d0046 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from .endpoint import BatchedServiceEndpoint, service_endpoint from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState @@ -24,4 +25,6 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", + "BatchedServiceEndpoint", + "service_endpoint", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py new file mode 100644 index 000000000..a50ccf778 --- /dev/null +++ b/src/forge/controller/service/endpoint.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Service endpoint management for the Forge framework. +""" + +import asyncio +from typing import Generic, List, TypeVar + +from monarch.actor import endpoint + +from typing_extensions import ParamSpec + +from .router import LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter + +P = ParamSpec("P") +R = TypeVar("R") + + +class ServiceEndpoint(Generic[P, R]): + """ + This extends Monarch's actor APIs for service endpoints. + - `route(*args, **kwargs)`: Routes the request to a single replica. + - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. + + Monarch's native actor APIs do not apply for services. + """ + + def __init__( + self, + service, + endpoint_name: str, + router: str = "round_robin", + batch_size: int = 1, + batch_timeout: float = 0.1, + ): + self.service = service + self.endpoint_name = endpoint_name + + self.router = self._resolve_router(router) + self.session_router = SessionRouter(fallback_router=self.router) + + def _resolve_router(self, router_name: str) -> Router: + """Convert a router name into a router object. + + Args: + router_name (str): a router name. Supported routers: "round_robin", "leastloaded". + + Returns: + Router: A Router object. + """ + if router_name == "round_robin": + return RoundRobinRouter() + if router_name == "leastloaded": + return LeastLoadedRouter() + raise ValueError( + f"Unknown router name: {router_name}. Supported routers: 'round_robin', 'leastloaded'." + ) + + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) + + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.service.call_all(self.endpoint_name, *args, **kwargs) + return result + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + +class BatchedServiceEndpoint(ServiceEndpoint[P, R]): + """ + A ServiceEndpoint that supports request batch routing. + + Args: + router: The underlying Router instance used to make routing decisions + session_router: The fallback Router for session-based routing. + batch_max_size: Maximum number of requests to collect in a single batch (default: 8) + batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01) + + Features: + - Maintains a batch queue + - Spawns a background task to group requests into batches + """ + + def __init__( + self, + service, + endpoint_name: str, + router: str = "round_robin", + session_router: str = "leastloaded", + batch_size: int = 1, + batch_timeout: float = 0.1, + ): + + super().__init__(service, endpoint_name) + + self.router = self._resolve_router(router) + self.session_router = SessionRouter( + fallback_router=self._resolve_router(session_router) + ) + + self.batch_size = batch_size + self.batch_timeout = batch_timeout + + self.batch_queue: asyncio.Queue = asyncio.Queue() + self.running_batch_loop = False + # if self.batch_size > 1: + # self.running_batch_loop = True + # self.batch_task = asyncio.create_task(self._batch_loop()) + + # async def _batch_loop(self): + # while self.running_batch_loop: + # batch_futs = [] + + # fut = await self.batch_queue.get() + # batch_futs.append(fut) + # start_time = time.monotonic() + + # while True: + # try: + # timeout = max( + # 0, self.batch_timeout - (time.monotonic() - start_time) + # ) + # fut = await asyncio.wait_for(self.batch_queue.get(), timeout) + # batch_futs.append(fut) + # if len(batch_futs) >= self.batch_size: + # break + # except asyncio.TimeoutError: + # break + + # healthy_replicas = [r for r in self.service._replicas if r.healthy] + # replica = self.router.get_replica(healthy_replicas) + + # for fut in batch_futs: + # fut.set_result(replica) + + # async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: + # sess_id = kwargs.pop("sess_id", None) + # if sess_id: + # healthy_replicas = [r for r in self.service._replicas if r.healthy] + # replica = self.session_router.get_replica(healthy_replicas, sess_id) + # else: + # if self.batch_size > 1: + # fut = asyncio.Future() + # self.batch_queue.put_nowait(fut) + # replica = await fut + # else: + # healthy_replicas = [r for r in self.service._replicas if r.healthy] + # replica = self.router.get_replica(healthy_replicas) + + # request = ServiceRequest( + # session_id=sess_id, + # function=self.endpoint_name, + # args=args, + # kwargs=kwargs, + # future=asyncio.Future(), + # ) + # await replica.enqueue_request(request) + # return await request.future + + +class ServiceEndpointV2(Generic[P, R]): + """An endpoint object specific to services. + + This loosely mimics the Endpoint APIs exposed in Monarch, with + a few key differences: + - Only choose and call are retained (dropping stream and call_one) + - Call returns a list directly rather than a ValueMesh. + + These changes are made with Forge use cases in mind, but can + certainly be expanded/adapted in the future. + + """ + + def __init__(self, actor_mesh, endpoint_name: str): + self.actor_mesh = actor_mesh + self.endpoint_name = endpoint_name + + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Chooses a replica to call based on context and load balancing strategy.""" + # Extract sess_id from kwargs if present + sess_id = kwargs.pop("sess_id", None) + return await self.actor_mesh.call.call_one( + sess_id, self.endpoint_name, *args, **kwargs + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + """Broadcasts a request to all healthy replicas and returns the results as a list.""" + result = await self.actor_mesh.call_all.call_one( + self.endpoint_name, *args, **kwargs + ) + return result + + +def service_endpoint( + *, + router="round_robin", + session_router="leastloaded", + batch_size=1, + batch_timeout=0.01, + propagate=None, + explicit_response_port=False, +): + """ + Marks an actor method as a service endpoint with batching routing support. + + Example: + class MyForgeActor(ForgeActor): + @service_endpoint(router="round_robin", batch_size=16, batch_timeout=0.05) + async def predict(self, x): ... + """ + + def decorator(method): + # First wrap in EndpointProperty (so actor has a proper endpoint) + ep = endpoint( + method, propagate=propagate, explicit_response_port=explicit_response_port + ) + ep._service_endpoint_config = dict( + router=router, + session_router=session_router, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + return ep + + return decorator diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 5b7e2f884..25e400dae 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -12,20 +12,16 @@ import contextvars import logging -from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, Generic, List, ParamSpec, TypeVar from monarch._src.actor.endpoint import EndpointProperty -from .replica import Replica +from .endpoint import ServiceEndpoint, ServiceEndpointV2 + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -P = ParamSpec("P") -R = TypeVar("R") - @dataclass class Session: @@ -77,94 +73,6 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): self.session_id = None -class ServiceEndpoint(Generic[P, R]): - """ - This extends Monarch's actor APIs for service endpoints. - - `route(*args, **kwargs)`: Routes the request to a single replica. - - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. - - Monarch's native actor APIs do not apply for services. - """ - - def __init__(self, service, endpoint_name: str): - self.service = service - self.endpoint_name = endpoint_name - - async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) - - async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.service.call_all(self.endpoint_name, *args, **kwargs) - return result - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use choose() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use call() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: - raise NotImplementedError( - "You tried to use a call_one() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - raise NotImplementedError( - "You tried to use broadcast() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - async def generate(self, *args: P.args, **kwargs: P.kwargs): - raise NotImplementedError( - "You tried to use generate() on a service, not an actor. " - "Services only support route() and fanout()." - ) - - -class ServiceEndpointV2(Generic[P, R]): - """An endpoint object specific to services. - - This loosely mimics the Endpoint APIs exposed in Monarch, with - a few key differences: - - Only choose and call are retained (dropping stream and call_one) - - Call returns a list directly rather than a ValueMesh. - - These changes are made with Forge use cases in mind, but can - certainly be expanded/adapted in the future. - - """ - - def __init__(self, actor_mesh, endpoint_name: str): - self.actor_mesh = actor_mesh - self.endpoint_name = endpoint_name - - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" - # Extract sess_id from kwargs if present - sess_id = kwargs.pop("sess_id", None) - return await self.actor_mesh.call.call_one( - sess_id, self.endpoint_name, *args, **kwargs - ) - - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: - """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.actor_mesh.call_all.call_one( - self.endpoint_name, *args, **kwargs - ) - return result - - class ServiceInterface: """ A lightweight interface to the base Service class. @@ -183,9 +91,26 @@ def __init__(self, _service, actor_def): for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) if isinstance(attr_value, EndpointProperty): + # Decorated with @endpoint # Create a ServiceEndpoint that will route through the Service Actor endpoint = ServiceEndpoint(self._service, attr_name) - setattr(self, attr_name, endpoint) + elif hasattr(attr_value, "_service_endpoint_config"): + print("reached here") + # Decorated with @service_endpoint + # Create a ServiceEndpoint with batch routing config + cfg = attr_value._service_endpoint_config + endpoint = ServiceEndpoint( + self._service, + attr_name, + router=cfg["router"], + batch_size=cfg["batch_size"], + batch_timeout=cfg["batch_timeout"], + ) + print("reached here. cfg: ", cfg) + else: + # Not decorated with @endpoint or @service_endpoint + continue + setattr(self, attr_name, endpoint) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: @@ -306,17 +231,3 @@ def __getattr__(self, name: str): raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) - - -class Router(ABC): - """Abstract base class for routing logic.""" - - @abstractmethod - def get_replica( - self, - healthy_replicas: List[Replica], - sess_id: str | None = None, - session_map: Dict[str, int] | None = None, - ) -> Replica: - """Select a replica from the list based on routing logic.""" - pass diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 1dd0308e0..6aba73533 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -6,15 +6,29 @@ import logging +from abc import ABC, abstractmethod from typing import Dict, List -from .interface import Router from .replica import Replica logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +class Router(ABC): + """Abstract base class for routing logic.""" + + @abstractmethod + def get_replica( + self, + healthy_replicas: List[Replica], + sess_id: str | None = None, + session_map: Dict[str, int] | None = None, + ) -> Replica: + """Select a replica from the list based on routing logic.""" + pass + + class RoundRobinRouter(Router): """Round-robin router for stateless requests.""" From 64c70764d1431eb0a0c023e85be822c65a7f2037 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 17:55:05 -0700 Subject: [PATCH 05/32] buggy version --- src/forge/controller/service/endpoint.py | 107 +++++++++++- src/forge/controller/service/interface.py | 6 +- src/forge/controller/service/service.py | 169 ++++++++++--------- tests/unit_tests/test_router.py | 193 ++++++++++++++++++++++ tests/unit_tests/test_service.py | 149 +---------------- 5 files changed, 399 insertions(+), 225 deletions(-) create mode 100644 tests/unit_tests/test_router.py diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index a50ccf778..16d8b667b 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -9,12 +9,14 @@ """ import asyncio +import time from typing import Generic, List, TypeVar from monarch.actor import endpoint - from typing_extensions import ParamSpec +from .replica import Replica + from .router import LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter P = ParamSpec("P") @@ -44,6 +46,16 @@ def __init__( self.router = self._resolve_router(router) self.session_router = SessionRouter(fallback_router=self.router) + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self._running_batch_loop = False + self._batch_queue: asyncio.Queue = asyncio.Queue() + if self.batch_size > 1: + self._running_batch_loop = True + self.batch_task = asyncio.create_task(self._batch_loop()) + + self.max_attempts = 1 # number of tries for routing = initial + retries + def _resolve_router(self, router_name: str) -> Router: """Convert a router name into a router object. @@ -65,7 +77,94 @@ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: """Chooses a replica to call based on context and load balancing strategy.""" # Extract sess_id from kwargs if present sess_id = kwargs.pop("sess_id", None) - return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) + + for attempt in range(self.max_attempts): + replica = await self._choose_replica(sess_id) + + # Wait for the result + try: + return await self.service._call( + replica, sess_id, self.endpoint_name, *args, **kwargs + ) + except Exception as e: + # If the replica failed, try to retry + if not replica.healthy and attempt < self.max_attempts - 1: + # Clear sticky mapping before retry + if ( + sess_id is not None + and sess_id in self.service._session_replica_map + ): + del self.service._session_replica_map[sess_id] + continue # retry with a fresh replica + raise + + async def _choose_replica(self, sess_id: str | None) -> "Replica": + """Get a replica for the given session ID.""" + + # Stateful routing always uses session router + if sess_id: + healthy = self.service._get_healthy_replicas() + return self.session_router.get_replica( + healthy, sess_id, self.service._session_replica_map + ) + # Stateless: batching + if self.batch_size > 1: + fut = asyncio.Future() + self._batch_queue.put_nowait(fut) + return await fut + + # No batching, pick immediately + healthy = self.service._get_healthy_replicas() + return self.router.get_replica(healthy) + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_size or batch_timeout is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + """ + while self._running_batch_loop: + batch_futs = [] + + # Wait for first request + fut = await self._batch_queue.get() + batch_futs.append(fut) + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self.batch_timeout - (time.monotonic() - start_time) + ) + fut = await asyncio.wait_for( + self._batch_queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch_futs.append(fut) + + if len(batch_futs) >= self.batch_size: + break + except asyncio.TimeoutError: + break + + healthy_replicas = self.service._get_healthy_replicas() + + # One routing decision for the whole batch + replica = self.router.get_replica( + healthy_replicas, None, self.service._session_replica_map + ) + + # Fulfill all futures with the chosen replica + for fut in batch_futs: + fut.set_result(replica) async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" @@ -102,6 +201,10 @@ async def generate(self, *args: P.args, **kwargs: P.kwargs): "Services only support route() and fanout()." ) + async def stop(self): + """Stop the batching loop.""" + self._running_batch_loop = False + class BatchedServiceEndpoint(ServiceEndpoint[P, R]): """ diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 25e400dae..444e93f33 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -123,8 +123,12 @@ async def terminate_session(self, sess_id: str): async def shutdown(self) -> None: """ - Shut down the underlying Service. + Shut down the underlying Service and all endpoints. """ + for attr in dir(self): + ep = getattr(self, attr) + if isinstance(ep, ServiceEndpoint): + await ep.stop() await self._service.stop() def session(self) -> "SessionContext": diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 327ffda3d..65cd6f8c1 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -39,8 +39,6 @@ import uuid from typing import Dict, List -from monarch.actor import Actor, endpoint - from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics @@ -53,6 +51,8 @@ ) from forge.types import ServiceConfig +from monarch.actor import Actor, endpoint + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -200,39 +200,11 @@ async def _batch_loop(self): for fut in batch_futs: fut.set_result(replica) - async def _call(self, sess_id: str | None, function: str, *args, **kwargs): - """ - Routes a function call to the appropriate replica with load balancing and fault tolerance. - - This is the core routing method that handles: - - Session-based routing for stateful calls - - Round-robin load balancing for stateless calls - - Custom routing based on context hints - - Automatic retry on replica failures - - Request queuing and processing - - Args: - sess_id: Optional session ID for stateful routing - function: Name of the actor endpoint to call - *args: Positional arguments to pass to the endpoint - **kwargs: Keyword arguments to pass to the endpoint - - Returns: - The result from the actor endpoint execution - - Raises: - RuntimeError: If no healthy replicas are available - Exception: Any exception raised by the actor endpoint - """ - # Check context variables for session state if no explicit sess_id - if sess_id is None: - ctx = _session_context.get(None) - if ctx: - sess_id = ctx["session_id"] - - replica = await self._get_replica(sess_id) + async def _call( + self, replica: "Replica", sess_id: str | None, function: str, *args, **kwargs + ): + """Send request directly to a chosen replica and wait for result.""" - # Create a ServiceRequest object to queue request = ServiceRequest( session_id=sess_id, function=function, @@ -244,19 +216,65 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): # Queue the request using replica's method await replica.enqueue_request(request) - # Wait for the result - try: - return await request.future - except Exception as e: - # If the replica failed, try to retry once - if not replica.healthy: - logger.debug( - f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" - ) - return await self._retry_request_on_healthy_replica( - sess_id, function, *args, **kwargs - ) - raise + return await request.future + + # async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + # """ + # Routes a function call to the appropriate replica with load balancing and fault tolerance. + + # This is the core routing method that handles: + # - Session-based routing for stateful calls + # - Round-robin load balancing for stateless calls + # - Custom routing based on context hints + # - Automatic retry on replica failures + # - Request queuing and processing + + # Args: + # sess_id: Optional session ID for stateful routing + # function: Name of the actor endpoint to call + # *args: Positional arguments to pass to the endpoint + # **kwargs: Keyword arguments to pass to the endpoint + + # Returns: + # The result from the actor endpoint execution + + # Raises: + # RuntimeError: If no healthy replicas are available + # Exception: Any exception raised by the actor endpoint + # """ + # # Check context variables for session state if no explicit sess_id + # if sess_id is None: + # ctx = _session_context.get(None) + # if ctx: + # sess_id = ctx["session_id"] + + # replica = await self._get_replica(sess_id) + + # # Create a ServiceRequest object to queue + # request = ServiceRequest( + # session_id=sess_id, + # function=function, + # args=args, + # kwargs=kwargs, + # future=asyncio.Future(), + # ) + + # # Queue the request using replica's method + # await replica.enqueue_request(request) + + # # Wait for the result + # try: + # return await request.future + # except Exception as e: + # # If the replica failed, try to retry once + # if not replica.healthy: + # logger.debug( + # f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" + # ) + # return await self._retry_request_on_healthy_replica( + # sess_id, function, *args, **kwargs + # ) + # raise async def call_all(self, function: str, *args, **kwargs) -> List: """ @@ -309,16 +327,16 @@ async def call_all(self, function: str, *args, **kwargs) -> List: return results - async def _retry_request_on_healthy_replica( - self, sess_id: str | None, function: str, *args, **kwargs - ): - """Retries a failed request on a healthy replica.""" - # Force reassignment to a healthy replica (only for session-based calls) - if sess_id is not None and sess_id in self._session_replica_map: - del self._session_replica_map[sess_id] + # async def _retry_request_on_healthy_replica( + # self, sess_id: str | None, function: str, *args, **kwargs + # ): + # """Retries a failed request on a healthy replica.""" + # # Force reassignment to a healthy replica (only for session-based calls) + # if sess_id is not None and sess_id in self._session_replica_map: + # del self._session_replica_map[sess_id] - # Retry the call (this will assign to a new healthy replica) - return await self._call(sess_id, function, *args, **kwargs) + # # Retry the call (this will assign to a new healthy replica) + # return await self._call(sess_id, function, *args, **kwargs) async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" @@ -538,26 +556,25 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _get_replica(self, sess_id: str | None) -> "Replica": - """Get a replica for the given session ID.""" - - if sess_id: - # Stateful routing always uses session router - healthy_replicas = self._get_healthy_replicas() - return self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - - # Stateless: batching - if self._max_batch_size > 1: - fut = asyncio.Future() - healthy_replicas = self._get_healthy_replicas() - self._batch_queue.put_nowait(fut) - return await fut - else: - # No batching, pick immediately - healthy_replicas = self._get_healthy_replicas() - return self._default_router.get_replica(healthy_replicas) + # async def _get_replica(self, sess_id: str | None) -> "Replica": + # """Get a replica for the given session ID.""" + + # if sess_id: + # # Stateful routing always uses session router + # healthy_replicas = self._get_healthy_replicas() + # return self._session_router.get_replica( + # healthy_replicas, sess_id, self._session_replica_map + # ) + + # # Stateless: batching + # if self._max_batch_size > 1: + # fut = asyncio.Future() + # self._batch_queue.put_nowait(fut) + # return await fut + # else: + # # No batching, pick immediately + # healthy_replicas = self._get_healthy_replicas() + # return self._default_router.get_replica(healthy_replicas) async def stop(self): logger.debug("Stopping service...") diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py new file mode 100644 index 000000000..0463d9a12 --- /dev/null +++ b/tests/unit_tests/test_router.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +""" +Tests for router.py and batch routing in ServiceEndpoint +""" + +import asyncio +import contextlib +import logging + +import pytest +from forge.controller import ForgeActor +from forge.controller.service import ( + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + service_endpoint, + SessionRouter, +) + +from forge.types import ProcessConfig +from monarch.actor import endpoint + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Counter(ForgeActor): + """Test actor that maintains a counter with various endpoints.""" + + def __init__(self, v: int): + self.v = v + + @endpoint + async def incr(self): + """Increment the counter.""" + self.v += 1 + + @endpoint + async def value(self) -> int: + """Get the current counter value.""" + return self.v + + @endpoint + async def fail_me(self): + """Endpoint that always fails to test error handling.""" + raise RuntimeError("I was asked to fail") + + @endpoint + async def add_to_value(self, amount: int, multiplier: int = 1) -> int: + """Add an amount (optionally multiplied) to the current value.""" + logger.info(f"adding {amount} with {multiplier}") + self.v += amount * multiplier + return self.v + + @service_endpoint(router="round_robin", batch_size=3, batch_timeout=0.1) + async def rr_incr(self): + """Increment using RoundRobin router.""" + self.v += 1 + + +def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: + """Helper to build a replica with specified state and load.""" + replica = Replica( + idx=idx, + proc_config=ProcessConfig(), + actor_def=Counter, + actor_args=(), + actor_kwargs={}, + ) + replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY + replica.active_requests = load + return replica + + +# Router Tests + + +@pytest.mark.asyncio +async def test_session_router_fallback_rr_vs_ll(): + """Switch fallback router to round-robin and verify assignment order.""" + # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = RoundRobinRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx != r2.idx + assert set(session_map.values()) == {0, 1} + + # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas + replicas = [make_replica(0, load=0), make_replica(1, load=5)] + session_map = {} + fallback = LeastLoadedRouter() + router = SessionRouter(fallback) + + r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) + r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) + + assert r1.idx == r2.idx == 0 + + +# Router integeration tests + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using route() + results = [] + for _ in range(6): + await service.rr_incr.route() + values = await service.value.fanout() + print(values) + results.append(values) + print("results: ", results) + # Verify that requests were distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # 2 increments per replica (since 3 replicas, 6 calls) + final_values = results[-1] # last snapshot + assert sorted(final_values) == [2, 2, 2] + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_session_router_assigns_and_updates_session_map_in_service(): + """Integration: Service with SessionRouter preserves sticky sessions.""" + service = await Counter.options( + procs=1, + num_replicas=2, + ).as_service(v=0) + + try: + # First call with sess_id -> assign a replica + await service.incr.route(sess_id="sess1") + values1 = await service.value.fanout() + + # Second call with same sess_id -> must hit same replica + await service.incr.route(sess_id="sess1") + values2 = await service.value.fanout() + + # Difference should only be on one replica (sticky session) + diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] + assert ( + sum(diffs) == 1 + ), f"Expected exactly one replica to increment, got {diffs}" + assert max(diffs) == 1 and min(diffs) == 0 + + # Session map in service should reflect assigned replica + assigned_idx = service._session_replica_map["sess1"] + assert values2[assigned_idx] == values1[assigned_idx] + 1 + + finally: + await service.shutdown() + + +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_service_endpoint_batch_flush_max_size(): + """Ensure @service_endpoint batching flushes correctly when max batch size reached.""" + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # Make 3 concurrent requests (batch_size = 3) + tasks = [asyncio.create_task(service.rr_incr.route()) for _ in range(4)] + await asyncio.gather(*tasks) + + values = await service.value.fanout() + + # Expectation: + # - 3 increments batched together on one replica + # - 1 increment on the other replica (new batch after flush) + assert sum(values) == 3, f"Expected total=3, got {values}" + + # Exactly one replica should have count=3, and the other count=1 + assert sorted(values) == [1, 3], f"Expected [1, 2], got {values}" + + finally: + await service.shutdown() diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py index 7cb96667b..64882d243 100644 --- a/tests/unit_tests/test_service.py +++ b/tests/unit_tests/test_service.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Tests for service.py and router.py +Tests for service.py """ import asyncio @@ -13,15 +13,8 @@ import pytest from forge.controller import ForgeActor -from forge.controller.service import ( - LeastLoadedRouter, - Replica, - ReplicaState, - RoundRobinRouter, - ServiceConfig, - SessionRouter, -) -from forge.types import ProcessConfig +from forge.controller.service import ServiceConfig + from monarch.actor import Actor, endpoint logger = logging.getLogger(__name__) @@ -63,20 +56,6 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: return self.v -def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: - """Helper to build a replica with specified state and load.""" - replica = Replica( - idx=idx, - proc_config=ProcessConfig(), - actor_def=Counter, - actor_args=(), - actor_kwargs={}, - ) - replica.state = ReplicaState.HEALTHY if healthy else ReplicaState.UNHEALTHY - replica.active_requests = load - return replica - - # Actor Tests @@ -754,125 +733,3 @@ async def test_broadcast_fanout_vs_route(): finally: await service.shutdown() - - -# Router Tests - - -@pytest.mark.asyncio -async def test_session_router_with_round_robin_fallback(): - """Switch fallback router to round-robin and verify assignment order.""" - # Choose RoundRobinRouter as fallback, r1 and r2 should be assigned to different replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = RoundRobinRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx != r2.idx - assert set(session_map.values()) == {0, 1} - - # If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas - replicas = [make_replica(0, load=0), make_replica(1, load=5)] - session_map = {} - fallback = LeastLoadedRouter() - router = SessionRouter(fallback) - - r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map) - r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map) - - assert r1.idx == r2.idx == 0 - - -# Router integeration tests - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_round_robin_router_distribution(): - """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" - service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) - - try: - # Make multiple sessionless calls using route() - results = [] - for _ in range(6): - await service.incr.route() - values = await service.value.fanout() - print(values) - results.append(values) - print("results: ", results) - # Verify that requests were distributed round-robin - # Each call increments a single replica, so after 6 calls we expect: - # 2 increments per replica (since 3 replicas, 6 calls) - final_values = results[-1] # last snapshot - assert sorted(final_values) == [2, 2, 2] - - finally: - await service.shutdown() - - -@pytest.mark.timeout(10) -@pytest.mark.asyncio -async def test_session_router_assigns_and_updates_session_map_in_service(): - """Integration: Service with SessionRouter preserves sticky sessions.""" - # Use LeastLoaded as default, SessionRouter (with fallback) is always active - service = await Counter.options( - procs=1, - num_replicas=2, - ).as_service(v=0) - - try: - # First call with sess_id -> assign a replica - await service.incr.route(sess_id="sess1") - values1 = await service.value.fanout() - - # Second call with same sess_id -> must hit same replica - await service.incr.route(sess_id="sess1") - values2 = await service.value.fanout() - - # Difference should only be on one replica (sticky session) - diffs = [v2 - v1 for v1, v2 in zip(values1, values2)] - assert ( - sum(diffs) == 1 - ), f"Expected exactly one replica to increment, got {diffs}" - assert max(diffs) == 1 and min(diffs) == 0 - - # Session map in service should reflect assigned replica - assigned_idx = service._session_replica_map["sess1"] - assert values2[assigned_idx] == values1[assigned_idx] + 1 - - finally: - await service.shutdown() - - -@pytest.mark.asyncio -async def test_service_stateless_batch_flush_max_size(): - """Batch should flush when max batch size is reached using default router (RoundRobin).""" - # Create a service with 2 replicas and batching enabled - service = await Counter.options( - procs=1, num_replicas=2, max_batch_size=3, batch_max_wait_s=0.5 - ).as_service(2) - try: - # Enqueue batch_size + 1 requests to force a batch flush - tasks = [asyncio.create_task(service._get_replica(None)) for _ in range(4)] - results = await asyncio.gather(*tasks) - - # Check that results are only healthy replicas - assert all(r.healthy for r in results) - - replica_indices = {r.idx for r in results[:3]} # first batch of 3 - assert len(replica_indices) == 1 # all went to the same replica - - # Last request (4th) should be routed after timeout - last_request = results[3] - assert ( - last_request.idx not in replica_indices - ) # Last request (4th) should be assigned to a different replica (RoundRobin) - - # After flush, batch queue should be empty - assert service._batch_queue.qsize() == 0 - finally: - await service.stop() From c581a9a6ce026bb412c61941ad9dec97902288eb Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 18:30:43 -0700 Subject: [PATCH 06/32] finally working, todo: clean up and add docstr --- src/forge/controller/service/interface.py | 19 +++++++++----- tests/unit_tests/test_router.py | 32 ++++++++++++++++++----- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 444e93f33..372b0ad29 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -90,12 +90,10 @@ def __init__(self, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Decorated with @endpoint - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpoint(self._service, attr_name) - elif hasattr(attr_value, "_service_endpoint_config"): - print("reached here") + # logger.info(f"Endpoint name: {attr_name}") + + if hasattr(attr_value, "_service_endpoint_config"): + # logger.info("reached here") # Decorated with @service_endpoint # Create a ServiceEndpoint with batch routing config cfg = attr_value._service_endpoint_config @@ -106,7 +104,14 @@ def __init__(self, _service, actor_def): batch_size=cfg["batch_size"], batch_timeout=cfg["batch_timeout"], ) - print("reached here. cfg: ", cfg) + # logger.info("reached here. cfg: ", cfg) + + elif isinstance(attr_value, EndpointProperty): + # logger.info(f"EndpointProperty name: {attr_name}") + + # Decorated with @endpoint + # Create a ServiceEndpoint that will route through the Service Actor + endpoint = ServiceEndpoint(self._service, attr_name) else: # Not decorated with @endpoint or @service_endpoint continue diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 0463d9a12..c12707804 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -57,7 +57,7 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: self.v += amount * multiplier return self.v - @service_endpoint(router="round_robin", batch_size=3, batch_timeout=0.1) + @service_endpoint(router="round_robin", batch_size=3, batch_timeout=1) async def rr_incr(self): """Increment using RoundRobin router.""" self.v += 1 @@ -120,11 +120,9 @@ async def test_round_robin_router_distribution(): # Make multiple sessionless calls using route() results = [] for _ in range(6): - await service.rr_incr.route() + await service.incr.route() values = await service.value.fanout() - print(values) results.append(values) - print("results: ", results) # Verify that requests were distributed round-robin # Each call increments a single replica, so after 6 calls we expect: # 2 increments per replica (since 3 replicas, 6 calls) @@ -135,6 +133,28 @@ async def test_round_robin_router_distribution(): await service.shutdown() +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_round_robin_router_distribution2(): + # TODO: change name + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" + service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) + + try: + # Make multiple sessionless calls using route() + results = [] + tasks = [service.rr_incr.route() for _ in range(6)] + await asyncio.gather(*tasks) + # Verify that requests were distributed round-robin + # Each call increments a single replica, so after 6 calls we expect: + # 2 increments per replica (since 3 replicas, 6 calls) + values = await service.value.fanout() + assert sorted(values) == [0, 3, 3] + + finally: + await service.shutdown() + + @pytest.mark.timeout(10) @pytest.mark.asyncio async def test_session_router_assigns_and_updates_session_map_in_service(): @@ -184,10 +204,10 @@ async def test_service_endpoint_batch_flush_max_size(): # Expectation: # - 3 increments batched together on one replica # - 1 increment on the other replica (new batch after flush) - assert sum(values) == 3, f"Expected total=3, got {values}" + assert sum(values) == 4, f"Expected total=4, got {values}" # Exactly one replica should have count=3, and the other count=1 - assert sorted(values) == [1, 3], f"Expected [1, 2], got {values}" + assert sorted(values) == [1, 3], f"Expected [1, 3], got {values}" finally: await service.shutdown() From 2f87cb16a73ee537b57fc4d0f66b2319a783573e Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 20:23:45 -0700 Subject: [PATCH 07/32] fix lint and clean up --- src/forge/controller/service/endpoint.py | 92 ----------------------- src/forge/controller/service/interface.py | 6 -- src/forge/controller/service/service.py | 4 +- tests/unit_tests/test_router.py | 6 +- 4 files changed, 4 insertions(+), 104 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 16d8b667b..623e80cc2 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -206,98 +206,6 @@ async def stop(self): self._running_batch_loop = False -class BatchedServiceEndpoint(ServiceEndpoint[P, R]): - """ - A ServiceEndpoint that supports request batch routing. - - Args: - router: The underlying Router instance used to make routing decisions - session_router: The fallback Router for session-based routing. - batch_max_size: Maximum number of requests to collect in a single batch (default: 8) - batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01) - - Features: - - Maintains a batch queue - - Spawns a background task to group requests into batches - """ - - def __init__( - self, - service, - endpoint_name: str, - router: str = "round_robin", - session_router: str = "leastloaded", - batch_size: int = 1, - batch_timeout: float = 0.1, - ): - - super().__init__(service, endpoint_name) - - self.router = self._resolve_router(router) - self.session_router = SessionRouter( - fallback_router=self._resolve_router(session_router) - ) - - self.batch_size = batch_size - self.batch_timeout = batch_timeout - - self.batch_queue: asyncio.Queue = asyncio.Queue() - self.running_batch_loop = False - # if self.batch_size > 1: - # self.running_batch_loop = True - # self.batch_task = asyncio.create_task(self._batch_loop()) - - # async def _batch_loop(self): - # while self.running_batch_loop: - # batch_futs = [] - - # fut = await self.batch_queue.get() - # batch_futs.append(fut) - # start_time = time.monotonic() - - # while True: - # try: - # timeout = max( - # 0, self.batch_timeout - (time.monotonic() - start_time) - # ) - # fut = await asyncio.wait_for(self.batch_queue.get(), timeout) - # batch_futs.append(fut) - # if len(batch_futs) >= self.batch_size: - # break - # except asyncio.TimeoutError: - # break - - # healthy_replicas = [r for r in self.service._replicas if r.healthy] - # replica = self.router.get_replica(healthy_replicas) - - # for fut in batch_futs: - # fut.set_result(replica) - - # async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - # sess_id = kwargs.pop("sess_id", None) - # if sess_id: - # healthy_replicas = [r for r in self.service._replicas if r.healthy] - # replica = self.session_router.get_replica(healthy_replicas, sess_id) - # else: - # if self.batch_size > 1: - # fut = asyncio.Future() - # self.batch_queue.put_nowait(fut) - # replica = await fut - # else: - # healthy_replicas = [r for r in self.service._replicas if r.healthy] - # replica = self.router.get_replica(healthy_replicas) - - # request = ServiceRequest( - # session_id=sess_id, - # function=self.endpoint_name, - # args=args, - # kwargs=kwargs, - # future=asyncio.Future(), - # ) - # await replica.enqueue_request(request) - # return await request.future - - class ServiceEndpointV2(Generic[P, R]): """An endpoint object specific to services. diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 372b0ad29..df40bb803 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -90,10 +90,7 @@ def __init__(self, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - # logger.info(f"Endpoint name: {attr_name}") - if hasattr(attr_value, "_service_endpoint_config"): - # logger.info("reached here") # Decorated with @service_endpoint # Create a ServiceEndpoint with batch routing config cfg = attr_value._service_endpoint_config @@ -104,11 +101,8 @@ def __init__(self, _service, actor_def): batch_size=cfg["batch_size"], batch_timeout=cfg["batch_timeout"], ) - # logger.info("reached here. cfg: ", cfg) elif isinstance(attr_value, EndpointProperty): - # logger.info(f"EndpointProperty name: {attr_name}") - # Decorated with @endpoint # Create a ServiceEndpoint that will route through the Service Actor endpoint = ServiceEndpoint(self._service, attr_name) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 65cd6f8c1..44712c090 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -39,6 +39,8 @@ import uuid from typing import Dict, List +from monarch.actor import Actor, endpoint + from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics @@ -51,8 +53,6 @@ ) from forge.types import ServiceConfig -from monarch.actor import Actor, endpoint - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index c12707804..861bdf271 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -8,7 +8,6 @@ """ import asyncio -import contextlib import logging import pytest @@ -135,9 +134,8 @@ async def test_round_robin_router_distribution(): @pytest.mark.timeout(10) @pytest.mark.asyncio -async def test_round_robin_router_distribution2(): - # TODO: change name - """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas.""" +async def test_round_robin_router_distribution_with_batching(): + """Test that the RoundRobinRouter distributes sessionless calls evenly across replicas with batch routing.""" service = await Counter.options(procs=1, num_replicas=3).as_service(v=0) try: From 52796d1c89ab06e2f84027ffd85a3201e3a9b0e9 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 20:29:26 -0700 Subject: [PATCH 08/32] more clean up --- src/forge/controller/service/__init__.py | 2 +- src/forge/controller/service/endpoint.py | 28 ++++ src/forge/controller/service/interface.py | 9 +- src/forge/controller/service/service.py | 151 ---------------------- src/forge/types.py | 2 - 5 files changed, 32 insertions(+), 160 deletions(-) diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index ed54d0046..895437b19 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -25,6 +25,6 @@ "LeastLoadedRouter", "RoundRobinRouter", "SessionRouter", - "BatchedServiceEndpoint", "service_endpoint", + "BatchedServiceEndpoint", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 623e80cc2..b74a6af6c 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -206,6 +206,34 @@ async def stop(self): self._running_batch_loop = False +class BatchedServiceEndpoint(ServiceEndpoint[P, R]): + """ + A ServiceEndpoint that supports request batch routing. + + Args: + router: The underlying Router instance used to make routing decisions + session_router: The fallback Router for session-based routing. + batch_max_size: Maximum number of requests to collect in a single batch (default: 8) + batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01) + + Features: + - Maintains a batch queue + - Spawns a background task to group requests into batches + """ + + def __init__( + self, + service, + endpoint_name: str, + router: str = "round_robin", + session_router: str = "leastloaded", + batch_size: int = 1, + batch_timeout: float = 0.1, + ): + + super().__init__(service, endpoint_name) + + class ServiceEndpointV2(Generic[P, R]): """An endpoint object specific to services. diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index df40bb803..738882ef8 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -11,7 +11,6 @@ """ import contextvars -import logging from dataclasses import dataclass from monarch._src.actor.endpoint import EndpointProperty @@ -19,10 +18,6 @@ from .endpoint import ServiceEndpoint, ServiceEndpointV2 -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - @dataclass class Session: """Simple session data holder.""" @@ -108,7 +103,9 @@ def __init__(self, _service, actor_def): endpoint = ServiceEndpoint(self._service, attr_name) else: # Not decorated with @endpoint or @service_endpoint - continue + raise ValueError( + f"Attribute '{attr_name}' is not a valid service endpoint" + ) setattr(self, attr_name, endpoint) # Session management methods - handled by ServiceInterface diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 44712c090..be024ffad 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -35,7 +35,6 @@ import asyncio import logging import pprint -import time import uuid from typing import Dict, List @@ -111,13 +110,6 @@ async def __initialize__(self): self._default_router = RoundRobinRouter() self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) - # Batching - self._max_batch_size = self._cfg.max_batch_size - self._batch_max_wait_s = self._cfg.batch_max_wait_s - self._batch_task: asyncio.Task | None = None - self._running_batch_loop = False - self._batch_queue: asyncio.Queue = asyncio.Queue() - # Initialize all replicas replicas = [] num_replicas = self._cfg.num_replicas @@ -146,60 +138,6 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - # Start batch loop if batching enabled - if self._max_batch_size > 1: - self._running_batch_loop = True - self._batch_task = asyncio.create_task(self._batch_loop()) - - async def _batch_loop(self): - """Background task that continuously processes batches of routing requests. - - This is the core batching logic that runs in a separate asyncio task. - It collects requests from the queue and processes them in batches based - on size and time constraints. - - The loop follows these steps: - 1. Wait for the first request to start a new batch - 2. Collect additional requests until batch_max_size or batch_max_wait_s is reached - 3. Make a single routing decision for the entire batch - 4. Fulfill all futures with the selected replica - - This process repeats indefinitely until the task is cancelled. - """ - while self._running_batch_loop: - batch_futs = [] - - # Wait for first request - fut = await self._batch_queue.get() - batch_futs.append(fut) - start_time = time.monotonic() - - while True: - try: - timeout = max( - 0, self._batch_max_wait_s - (time.monotonic() - start_time) - ) - fut = await asyncio.wait_for( - self._batch_queue.get(), timeout - ) # wait for timeout or until self._queue.get() finishes - batch_futs.append(fut) - - if len(batch_futs) >= self._max_batch_size: - break - except asyncio.TimeoutError: - break - - healthy_replicas = self._get_healthy_replicas() - - # One routing decision for the whole batch - replica = self._default_router.get_replica( - healthy_replicas, None, self._session_replica_map - ) - - # Fulfill all futures with the chosen replica - for fut in batch_futs: - fut.set_result(replica) - async def _call( self, replica: "Replica", sess_id: str | None, function: str, *args, **kwargs ): @@ -218,64 +156,6 @@ async def _call( return await request.future - # async def _call(self, sess_id: str | None, function: str, *args, **kwargs): - # """ - # Routes a function call to the appropriate replica with load balancing and fault tolerance. - - # This is the core routing method that handles: - # - Session-based routing for stateful calls - # - Round-robin load balancing for stateless calls - # - Custom routing based on context hints - # - Automatic retry on replica failures - # - Request queuing and processing - - # Args: - # sess_id: Optional session ID for stateful routing - # function: Name of the actor endpoint to call - # *args: Positional arguments to pass to the endpoint - # **kwargs: Keyword arguments to pass to the endpoint - - # Returns: - # The result from the actor endpoint execution - - # Raises: - # RuntimeError: If no healthy replicas are available - # Exception: Any exception raised by the actor endpoint - # """ - # # Check context variables for session state if no explicit sess_id - # if sess_id is None: - # ctx = _session_context.get(None) - # if ctx: - # sess_id = ctx["session_id"] - - # replica = await self._get_replica(sess_id) - - # # Create a ServiceRequest object to queue - # request = ServiceRequest( - # session_id=sess_id, - # function=function, - # args=args, - # kwargs=kwargs, - # future=asyncio.Future(), - # ) - - # # Queue the request using replica's method - # await replica.enqueue_request(request) - - # # Wait for the result - # try: - # return await request.future - # except Exception as e: - # # If the replica failed, try to retry once - # if not replica.healthy: - # logger.debug( - # f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" - # ) - # return await self._retry_request_on_healthy_replica( - # sess_id, function, *args, **kwargs - # ) - # raise - async def call_all(self, function: str, *args, **kwargs) -> List: """ Broadcasts a function call to all healthy replicas and returns results as a list. @@ -327,17 +207,6 @@ async def call_all(self, function: str, *args, **kwargs) -> List: return results - # async def _retry_request_on_healthy_replica( - # self, sess_id: str | None, function: str, *args, **kwargs - # ): - # """Retries a failed request on a healthy replica.""" - # # Force reassignment to a healthy replica (only for session-based calls) - # if sess_id is not None and sess_id in self._session_replica_map: - # del self._session_replica_map[sess_id] - - # # Retry the call (this will assign to a new healthy replica) - # return await self._call(sess_id, function, *args, **kwargs) - async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" migrated_requests = [] @@ -556,26 +425,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - # async def _get_replica(self, sess_id: str | None) -> "Replica": - # """Get a replica for the given session ID.""" - - # if sess_id: - # # Stateful routing always uses session router - # healthy_replicas = self._get_healthy_replicas() - # return self._session_router.get_replica( - # healthy_replicas, sess_id, self._session_replica_map - # ) - - # # Stateless: batching - # if self._max_batch_size > 1: - # fut = asyncio.Future() - # self._batch_queue.put_nowait(fut) - # return await fut - # else: - # # No batching, pick immediately - # healthy_replicas = self._get_healthy_replicas() - # return self._default_router.get_replica(healthy_replicas) - async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop diff --git a/src/forge/types.py b/src/forge/types.py index adb49c364..cc41d2185 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -118,8 +118,6 @@ class ServiceConfig: health_poll_rate: float = 0.2 replica_max_concurrent_requests: int = 10 return_first_rank_result: bool = True - max_batch_size: int = 1 - batch_max_wait_s: float = 0.01 def to_process_config(self) -> ProcessConfig: """Extract ProcessConfig from this ServiceConfig. From 2464ca85c96b0beedc54c6492496487019c1ee39 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 20:49:59 -0700 Subject: [PATCH 09/32] separate batch routing logic to BatchedServiceEndpoint --- src/forge/controller/service/endpoint.py | 166 ++++++++++++---------- src/forge/controller/service/interface.py | 27 ++-- 2 files changed, 103 insertions(+), 90 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index b74a6af6c..4d97c2ab2 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -37,8 +37,6 @@ def __init__( service, endpoint_name: str, router: str = "round_robin", - batch_size: int = 1, - batch_timeout: float = 0.1, ): self.service = service self.endpoint_name = endpoint_name @@ -46,14 +44,6 @@ def __init__( self.router = self._resolve_router(router) self.session_router = SessionRouter(fallback_router=self.router) - self.batch_size = batch_size - self.batch_timeout = batch_timeout - self._running_batch_loop = False - self._batch_queue: asyncio.Queue = asyncio.Queue() - if self.batch_size > 1: - self._running_batch_loop = True - self.batch_task = asyncio.create_task(self._batch_loop()) - self.max_attempts = 1 # number of tries for routing = initial + retries def _resolve_router(self, router_name: str) -> Router: @@ -107,64 +97,12 @@ async def _choose_replica(self, sess_id: str | None) -> "Replica": return self.session_router.get_replica( healthy, sess_id, self.service._session_replica_map ) - # Stateless: batching - if self.batch_size > 1: - fut = asyncio.Future() - self._batch_queue.put_nowait(fut) - return await fut - - # No batching, pick immediately - healthy = self.service._get_healthy_replicas() - return self.router.get_replica(healthy) - - async def _batch_loop(self): - """Background task that continuously processes batches of routing requests. - - This is the core batching logic that runs in a separate asyncio task. - It collects requests from the queue and processes them in batches based - on size and time constraints. - - The loop follows these steps: - 1. Wait for the first request to start a new batch - 2. Collect additional requests until batch_size or batch_timeout is reached - 3. Make a single routing decision for the entire batch - 4. Fulfill all futures with the selected replica - - This process repeats indefinitely until the task is cancelled. - """ - while self._running_batch_loop: - batch_futs = [] - - # Wait for first request - fut = await self._batch_queue.get() - batch_futs.append(fut) - start_time = time.monotonic() - - while True: - try: - timeout = max( - 0, self.batch_timeout - (time.monotonic() - start_time) - ) - fut = await asyncio.wait_for( - self._batch_queue.get(), timeout - ) # wait for timeout or until self._queue.get() finishes - batch_futs.append(fut) - - if len(batch_futs) >= self.batch_size: - break - except asyncio.TimeoutError: - break - - healthy_replicas = self.service._get_healthy_replicas() - - # One routing decision for the whole batch - replica = self.router.get_replica( - healthy_replicas, None, self.service._session_replica_map - ) - # Fulfill all futures with the chosen replica - for fut in batch_futs: - fut.set_result(replica) + # Use router to choose a replica + healthy_replicas = self.service._get_healthy_replicas() + return self.router.get_replica( + healthy_replicas, None, self.service._session_replica_map + ) async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" @@ -202,20 +140,17 @@ async def generate(self, *args: P.args, **kwargs: P.kwargs): ) async def stop(self): - """Stop the batching loop.""" - self._running_batch_loop = False + """Stop the service endpoint. + + For plain ServiceEndpoint (non-batched), this is a no-op. + """ + return class BatchedServiceEndpoint(ServiceEndpoint[P, R]): """ A ServiceEndpoint that supports request batch routing. - Args: - router: The underlying Router instance used to make routing decisions - session_router: The fallback Router for session-based routing. - batch_max_size: Maximum number of requests to collect in a single batch (default: 8) - batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01) - Features: - Maintains a batch queue - Spawns a background task to group requests into batches @@ -226,12 +161,85 @@ def __init__( service, endpoint_name: str, router: str = "round_robin", - session_router: str = "leastloaded", - batch_size: int = 1, - batch_timeout: float = 0.1, + batch_size: int = 8, + batch_timeout: float = 0.01, ): + super().__init__(service, endpoint_name, router=router) + + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self._batch_queue: asyncio.Queue = asyncio.Queue() + self._running_batch_loop = True + self.batch_task = asyncio.create_task(self._batch_loop()) + + async def _choose_replica(self, sess_id: str | None) -> "Replica": + """Get a replica for the given session ID.""" + + # Stateful routing always uses session router + if sess_id: + healthy = self.service._get_healthy_replicas() + return self.session_router.get_replica( + healthy, sess_id, self.service._session_replica_map + ) + # Stateless: batching + fut = asyncio.Future() + self._batch_queue.put_nowait(fut) + return await fut + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_size or batch_timeout is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + """ + while self._running_batch_loop: + batch_futs = [] + + # Wait for first request + fut = await self._batch_queue.get() + batch_futs.append(fut) + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self.batch_timeout - (time.monotonic() - start_time) + ) + fut = await asyncio.wait_for( + self._batch_queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch_futs.append(fut) + + if len(batch_futs) >= self.batch_size: + break + except asyncio.TimeoutError: + break + + healthy_replicas = self.service._get_healthy_replicas() + + # One routing decision for the whole batch + replica = self.router.get_replica( + healthy_replicas, None, self.service._session_replica_map + ) + + # Fulfill all futures with the chosen replica + for fut in batch_futs: + fut.set_result(replica) - super().__init__(service, endpoint_name) + async def stop(self): + """Stop the batching loop.""" + self._running_batch_loop = False + if hasattr(self, "batch_task"): + self.batch_task.cancel() class ServiceEndpointV2(Generic[P, R]): diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 738882ef8..cbab50cdf 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -15,7 +15,7 @@ from monarch._src.actor.endpoint import EndpointProperty -from .endpoint import ServiceEndpoint, ServiceEndpointV2 +from .endpoint import BatchedServiceEndpoint, ServiceEndpoint, ServiceEndpointV2 @dataclass @@ -89,13 +89,20 @@ def __init__(self, _service, actor_def): # Decorated with @service_endpoint # Create a ServiceEndpoint with batch routing config cfg = attr_value._service_endpoint_config - endpoint = ServiceEndpoint( - self._service, - attr_name, - router=cfg["router"], - batch_size=cfg["batch_size"], - batch_timeout=cfg["batch_timeout"], - ) + if cfg["batch_size"] > 1: + endpoint = BatchedServiceEndpoint( + self._service, + attr_name, + router=cfg["router"], + batch_size=cfg["batch_size"], + batch_timeout=cfg["batch_timeout"], + ) + else: + endpoint = ServiceEndpoint( + self._service, + attr_name, + router=cfg["router"], + ) elif isinstance(attr_value, EndpointProperty): # Decorated with @endpoint @@ -103,9 +110,7 @@ def __init__(self, _service, actor_def): endpoint = ServiceEndpoint(self._service, attr_name) else: # Not decorated with @endpoint or @service_endpoint - raise ValueError( - f"Attribute '{attr_name}' is not a valid service endpoint" - ) + continue setattr(self, attr_name, endpoint) # Session management methods - handled by ServiceInterface From 926c6015b9f4d6b35bcaf4c5853420bae916a4e6 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 21:00:22 -0700 Subject: [PATCH 10/32] add docstring --- src/forge/controller/service/endpoint.py | 95 +++++++++++++++++++----- 1 file changed, 77 insertions(+), 18 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 4d97c2ab2..947b48926 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -26,10 +26,30 @@ class ServiceEndpoint(Generic[P, R]): """ This extends Monarch's actor APIs for service endpoints. - - `route(*args, **kwargs)`: Routes the request to a single replica. - - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. - - Monarch's native actor APIs do not apply for services. + ServiceEndpoint provides the basic, non-batched routing API for Forge services. + + Args: + service: The underlying service object that owns replicas. + endpoint_name (str): The name of the endpoint method. + router (str, optional): Routing strategy for stateless requests. + Supported values: + - "round_robin": cycle through replicas in order. + - "leastloaded": pick the replica with the lowest load. + Default: "round_robin". + + Supported methods: + - `route`: Send a request to a single replica, chosen by the configured router + (e.g. round-robin, least-loaded). + - `fanout`: Broadcasts the request to all healthy replicas. + + Notes: + - Support `@endpoint()` and `@service_endpoint(router='..')` decorators. + - To specify router, use `@service_endpoint(router='..')`. + - Retry logic: If `max_attempts > 1`, failed calls may be retried on a different replica + if the first one becomes unhealthy. + - Session-aware routing: If a `sess_id` is provided, requests are routed via + `SessionRouter` for sticky session behavior. + - Monarch's native actor APIs do not apply for services. """ def __init__( @@ -41,10 +61,14 @@ def __init__( self.service = service self.endpoint_name = endpoint_name + # Primary router (stateless routing) self.router = self._resolve_router(router) + + # Session-aware router for sticky sessions self.session_router = SessionRouter(fallback_router=self.router) - self.max_attempts = 1 # number of tries for routing = initial + retries + # Number of routing attempts (initial + retries) + self.max_attempts = 1 def _resolve_router(self, router_name: str) -> Router: """Convert a router name into a router object. @@ -64,7 +88,12 @@ def _resolve_router(self, router_name: str) -> Router: ) async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - """Chooses a replica to call based on context and load balancing strategy.""" + """ + Route a single request to one replica. + + Retries up to `self.max_attempts` times if the chosen replica fails + and is marked unhealthy. Sticky session mapping is cleared on retry. + """ # Extract sess_id from kwargs if present sess_id = kwargs.pop("sess_id", None) @@ -89,7 +118,12 @@ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: raise async def _choose_replica(self, sess_id: str | None) -> "Replica": - """Get a replica for the given session ID.""" + """ + Select a replica to handle the request. + + - If `sess_id` is provided, use the session router for sticky sessions. + - Otherwise, use the stateless router to pick among healthy replicas. + """ # Stateful routing always uses session router if sess_id: @@ -149,11 +183,32 @@ async def stop(self): class BatchedServiceEndpoint(ServiceEndpoint[P, R]): """ - A ServiceEndpoint that supports request batch routing. - - Features: - - Maintains a batch queue - - Spawns a background task to group requests into batches + A ServiceEndpoint variant that supports request batching. + + Args: + service: The underlying service object that owns replicas. + endpoint_name (str): The name of the endpoint method. + router (str, optional): Routing strategy for stateless requests. + Supported values: + - "round_robin": cycle through replicas in order. + - "leastloaded": pick the replica with the lowest load. + Default: "round_robin". + batch_size (int, optional): Maximum number of requests to group together + in a single batch before dispatching. Default: 8. + batch_timeout (float, optional): Maximum time (in seconds) to wait before + dispatching a batch. Default: 0.01. + + Key features: + - Collects requests into batches of up to `batch_size`. + - Uses a background asyncio task (`_batch_loop`) to manage the queue. + - Makes one routing decision per batch, and assigns the chosen replica + to all requests in that batch. + - Provides the same API (`route`, `fanout`, `stop`) as ServiceEndpoint. + + Usage: + class MyForgeActor(ForgeActor): + @service_endpoint(router="round_robin", batch_size=16, batch_timeout=0.05) + async def forward(self, x): ... """ def __init__( @@ -173,7 +228,13 @@ def __init__( self.batch_task = asyncio.create_task(self._batch_loop()) async def _choose_replica(self, sess_id: str | None) -> "Replica": - """Get a replica for the given session ID.""" + """ + Overridden to support batching. + + - Session requests bypass batching and use sticky session router. + - Stateless requests are enqueued; the batch loop will fulfill their + Future with a chosen replica. + """ # Stateful routing always uses session router if sess_id: @@ -204,11 +265,12 @@ async def _batch_loop(self): while self._running_batch_loop: batch_futs = [] - # Wait for first request + # Wait for the first request to start a batch fut = await self._batch_queue.get() batch_futs.append(fut) start_time = time.monotonic() + # Collect additional requests until batch size or timeout while True: try: timeout = max( @@ -224,9 +286,8 @@ async def _batch_loop(self): except asyncio.TimeoutError: break + # Make one routing decision for the batch healthy_replicas = self.service._get_healthy_replicas() - - # One routing decision for the whole batch replica = self.router.get_replica( healthy_replicas, None, self.service._session_replica_map ) @@ -278,7 +339,6 @@ async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: def service_endpoint( *, router="round_robin", - session_router="leastloaded", batch_size=1, batch_timeout=0.01, propagate=None, @@ -300,7 +360,6 @@ def decorator(method): ) ep._service_endpoint_config = dict( router=router, - session_router=session_router, batch_size=batch_size, batch_timeout=batch_timeout, ) From 4ca60ba2b45b62248af85551c6e6d3752f282dd2 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 21:31:38 -0700 Subject: [PATCH 11/32] add a test case --- tests/unit_tests/test_router.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 861bdf271..18489f333 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -76,6 +76,22 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: return replica +@pytest.mark.timeout(10) +@pytest.mark.asyncio +async def test_service_endpoint_retry_succeeds_on_second_attempt(): + """Ensure that retry logic executes when the first replica fails.""" + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # Allow retries + service.max_attempts = 2 + + result = await service.flaky.route() + assert result == "ok" # success after retry + finally: + await service.shutdown() + + # Router Tests From 0131e21524d7ca9a1196c7dd27f8e8ed118053fa Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 25 Sep 2025 21:38:41 -0700 Subject: [PATCH 12/32] correct test case --- tests/unit_tests/test_router.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 18489f333..98fa4a539 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -76,20 +76,21 @@ def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: return replica -@pytest.mark.timeout(10) @pytest.mark.asyncio -async def test_service_endpoint_retry_succeeds_on_second_attempt(): - """Ensure that retry logic executes when the first replica fails.""" - service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) +@pytest.mark.timeout(10) +async def test_service_as_actor_preserves_normal_usage(): + """Ensure that using `as_actor` does not break normal semantics.""" + service = await Counter.as_actor(5) try: - # Allow retries - service.max_attempts = 2 + assert await service.value.choose() == 5 + + # Test increment + await service.rr_incr.choose() + assert await service.value.choose() == 6 - result = await service.flaky.route() - assert result == "ok" # success after retry finally: - await service.shutdown() + await Counter.shutdown(service) # Router Tests From 93d8c9def39d7f2fa08cf2269b2860fd482a178a Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Fri, 26 Sep 2025 10:46:41 -0700 Subject: [PATCH 13/32] Update src/forge/controller/service/endpoint.py Co-authored-by: Allen Wang <9057208+allenwang28@users.noreply.github.com> --- src/forge/controller/service/endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 947b48926..e318baa80 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -43,7 +43,7 @@ class ServiceEndpoint(Generic[P, R]): - `fanout`: Broadcasts the request to all healthy replicas. Notes: - - Support `@endpoint()` and `@service_endpoint(router='..')` decorators. + - Supports Monarch's `@endpoint()` as well as service's `@service_endpoint(router='..')` decorators. - To specify router, use `@service_endpoint(router='..')`. - Retry logic: If `max_attempts > 1`, failed calls may be retried on a different replica if the first one becomes unhealthy. From 90e94b9cf22f3ade65d9fa311e558d2300399269 Mon Sep 17 00:00:00 2001 From: Danning XIE <24580222+DNXie@users.noreply.github.com> Date: Fri, 26 Sep 2025 10:46:51 -0700 Subject: [PATCH 14/32] Update src/forge/controller/service/interface.py Co-authored-by: Allen Wang <9057208+allenwang28@users.noreply.github.com> --- src/forge/controller/service/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index cbab50cdf..cc4cc68ba 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -105,7 +105,7 @@ def __init__(self, _service, actor_def): ) elif isinstance(attr_value, EndpointProperty): - # Decorated with @endpoint + # This was defined as a standard Monarch endpoint # Create a ServiceEndpoint that will route through the Service Actor endpoint = ServiceEndpoint(self._service, attr_name) else: From 4393a514ddead4e9a9855ed6d7f7fe7f93dc7de7 Mon Sep 17 00:00:00 2001 From: DNXie Date: Fri, 26 Sep 2025 10:54:10 -0700 Subject: [PATCH 15/32] resolve comments --- src/forge/controller/service/endpoint.py | 69 +++++++----------------- tests/unit_tests/test_router.py | 2 +- 2 files changed, 19 insertions(+), 52 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index e318baa80..9a503a89a 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -17,7 +17,7 @@ from .replica import Replica -from .router import LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter +from .router import RoundRobinRouter, Router, SessionRouter P = ParamSpec("P") R = TypeVar("R") @@ -25,44 +25,28 @@ class ServiceEndpoint(Generic[P, R]): """ - This extends Monarch's actor APIs for service endpoints. - ServiceEndpoint provides the basic, non-batched routing API for Forge services. + ServiceEndpoint extends Monarch's native actor APIs for service functionality. - Args: - service: The underlying service object that owns replicas. - endpoint_name (str): The name of the endpoint method. - router (str, optional): Routing strategy for stateless requests. - Supported values: - - "round_robin": cycle through replicas in order. - - "leastloaded": pick the replica with the lowest load. - Default: "round_robin". + Services provide fault tolerance and load balancing on top of Monarch actors, exposing the following endpoints: - Supported methods: - - `route`: Send a request to a single replica, chosen by the configured router + - `route`: Send a request to a single replica, chosen by the configured router (e.g. round-robin, least-loaded). - - `fanout`: Broadcasts the request to all healthy replicas. - - Notes: - - Supports Monarch's `@endpoint()` as well as service's `@service_endpoint(router='..')` decorators. - - To specify router, use `@service_endpoint(router='..')`. - - Retry logic: If `max_attempts > 1`, failed calls may be retried on a different replica - if the first one becomes unhealthy. - - Session-aware routing: If a `sess_id` is provided, requests are routed via - `SessionRouter` for sticky session behavior. - - Monarch's native actor APIs do not apply for services. + - `fanout`: Broadcasts the request to all healthy replicas. + + Note that Monarch's native actor APIs are not accessible through service endpoints. """ def __init__( self, service, endpoint_name: str, - router: str = "round_robin", + router: Router = RoundRobinRouter(), ): self.service = service self.endpoint_name = endpoint_name # Primary router (stateless routing) - self.router = self._resolve_router(router) + self.router = router # Session-aware router for sticky sessions self.session_router = SessionRouter(fallback_router=self.router) @@ -70,23 +54,6 @@ def __init__( # Number of routing attempts (initial + retries) self.max_attempts = 1 - def _resolve_router(self, router_name: str) -> Router: - """Convert a router name into a router object. - - Args: - router_name (str): a router name. Supported routers: "round_robin", "leastloaded". - - Returns: - Router: A Router object. - """ - if router_name == "round_robin": - return RoundRobinRouter() - if router_name == "leastloaded": - return LeastLoadedRouter() - raise ValueError( - f"Unknown router name: {router_name}. Supported routers: 'round_robin', 'leastloaded'." - ) - async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: """ Route a single request to one replica. @@ -188,11 +155,11 @@ class BatchedServiceEndpoint(ServiceEndpoint[P, R]): Args: service: The underlying service object that owns replicas. endpoint_name (str): The name of the endpoint method. - router (str, optional): Routing strategy for stateless requests. + router (Router, optional): A Router object specifing routing strategy for stateless requests. Supported values: - - "round_robin": cycle through replicas in order. - - "leastloaded": pick the replica with the lowest load. - Default: "round_robin". + - RoundRobinRouter(): cycle through replicas in order. + - LeastLoadedRouter(): pick the replica with the lowest load. + Default: RoundRobinRouter(). batch_size (int, optional): Maximum number of requests to group together in a single batch before dispatching. Default: 8. batch_timeout (float, optional): Maximum time (in seconds) to wait before @@ -215,7 +182,7 @@ def __init__( self, service, endpoint_name: str, - router: str = "round_robin", + router: Router = RoundRobinRouter(), batch_size: int = 8, batch_timeout: float = 0.01, ): @@ -338,9 +305,9 @@ async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: def service_endpoint( *, - router="round_robin", - batch_size=1, - batch_timeout=0.01, + router: Router = RoundRobinRouter(), + batch_size: int = 1, + batch_timeout: float = 0.01, propagate=None, explicit_response_port=False, ): @@ -349,7 +316,7 @@ def service_endpoint( Example: class MyForgeActor(ForgeActor): - @service_endpoint(router="round_robin", batch_size=16, batch_timeout=0.05) + @service_endpoint(router=RoundRobinRouter(), batch_size=16, batch_timeout=0.05) async def predict(self, x): ... """ diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 98fa4a539..6fee199a8 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -56,7 +56,7 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: self.v += amount * multiplier return self.v - @service_endpoint(router="round_robin", batch_size=3, batch_timeout=1) + @service_endpoint(router=RoundRobinRouter(), batch_size=3, batch_timeout=1) async def rr_incr(self): """Increment using RoundRobin router.""" self.v += 1 From 653001e3faa24bc3efd01e6c139f762bd5757130 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 12:56:29 -0700 Subject: [PATCH 16/32] move batching logic back to Batcher class, keep router for each endpoint in service. TODO: add more test cases --- src/forge/controller/service/__init__.py | 7 +- src/forge/controller/service/endpoint.py | 206 +--------------------- src/forge/controller/service/interface.py | 32 +--- src/forge/controller/service/router.py | 127 ++++++++++++- src/forge/controller/service/service.py | 147 ++++++++++++++- tests/unit_tests/test_router.py | 17 +- 6 files changed, 297 insertions(+), 239 deletions(-) diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index 895437b19..cda6263a8 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -4,11 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .endpoint import BatchedServiceEndpoint, service_endpoint +from .endpoint import service_endpoint from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState -from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter +from .router import Batcher, LeastLoadedRouter, RoundRobinRouter, Router, SessionRouter from .service import Service, ServiceActor, ServiceConfig __all__ = [ @@ -26,5 +26,6 @@ "RoundRobinRouter", "SessionRouter", "service_endpoint", - "BatchedServiceEndpoint", + "Router", + "Batcher", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 9a503a89a..ff2529109 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -8,16 +8,12 @@ Service endpoint management for the Forge framework. """ -import asyncio -import time from typing import Generic, List, TypeVar from monarch.actor import endpoint from typing_extensions import ParamSpec -from .replica import Replica - -from .router import RoundRobinRouter, Router, SessionRouter +from .router import RoundRobinRouter, Router P = ParamSpec("P") R = TypeVar("R") @@ -25,85 +21,26 @@ class ServiceEndpoint(Generic[P, R]): """ - ServiceEndpoint extends Monarch's native actor APIs for service functionality. - - Services provide fault tolerance and load balancing on top of Monarch actors, exposing the following endpoints: - - - `route`: Send a request to a single replica, chosen by the configured router - (e.g. round-robin, least-loaded). - - `fanout`: Broadcasts the request to all healthy replicas. + This extends Monarch's actor APIs for service endpoints. + - `route(*args, **kwargs)`: Routes the request to a single replica. + - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas. - Note that Monarch's native actor APIs are not accessible through service endpoints. + Monarch's native actor APIs do not apply for services. """ def __init__( self, service, endpoint_name: str, - router: Router = RoundRobinRouter(), ): self.service = service self.endpoint_name = endpoint_name - # Primary router (stateless routing) - self.router = router - - # Session-aware router for sticky sessions - self.session_router = SessionRouter(fallback_router=self.router) - - # Number of routing attempts (initial + retries) - self.max_attempts = 1 - async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: - """ - Route a single request to one replica. - - Retries up to `self.max_attempts` times if the chosen replica fails - and is marked unhealthy. Sticky session mapping is cleared on retry. - """ + """Chooses a replica to call based on context and load balancing strategy.""" # Extract sess_id from kwargs if present sess_id = kwargs.pop("sess_id", None) - - for attempt in range(self.max_attempts): - replica = await self._choose_replica(sess_id) - - # Wait for the result - try: - return await self.service._call( - replica, sess_id, self.endpoint_name, *args, **kwargs - ) - except Exception as e: - # If the replica failed, try to retry - if not replica.healthy and attempt < self.max_attempts - 1: - # Clear sticky mapping before retry - if ( - sess_id is not None - and sess_id in self.service._session_replica_map - ): - del self.service._session_replica_map[sess_id] - continue # retry with a fresh replica - raise - - async def _choose_replica(self, sess_id: str | None) -> "Replica": - """ - Select a replica to handle the request. - - - If `sess_id` is provided, use the session router for sticky sessions. - - Otherwise, use the stateless router to pick among healthy replicas. - """ - - # Stateful routing always uses session router - if sess_id: - healthy = self.service._get_healthy_replicas() - return self.session_router.get_replica( - healthy, sess_id, self.service._session_replica_map - ) - - # Use router to choose a replica - healthy_replicas = self.service._get_healthy_replicas() - return self.router.get_replica( - healthy_replicas, None, self.service._session_replica_map - ) + return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" @@ -140,135 +77,6 @@ async def generate(self, *args: P.args, **kwargs: P.kwargs): "Services only support route() and fanout()." ) - async def stop(self): - """Stop the service endpoint. - - For plain ServiceEndpoint (non-batched), this is a no-op. - """ - return - - -class BatchedServiceEndpoint(ServiceEndpoint[P, R]): - """ - A ServiceEndpoint variant that supports request batching. - - Args: - service: The underlying service object that owns replicas. - endpoint_name (str): The name of the endpoint method. - router (Router, optional): A Router object specifing routing strategy for stateless requests. - Supported values: - - RoundRobinRouter(): cycle through replicas in order. - - LeastLoadedRouter(): pick the replica with the lowest load. - Default: RoundRobinRouter(). - batch_size (int, optional): Maximum number of requests to group together - in a single batch before dispatching. Default: 8. - batch_timeout (float, optional): Maximum time (in seconds) to wait before - dispatching a batch. Default: 0.01. - - Key features: - - Collects requests into batches of up to `batch_size`. - - Uses a background asyncio task (`_batch_loop`) to manage the queue. - - Makes one routing decision per batch, and assigns the chosen replica - to all requests in that batch. - - Provides the same API (`route`, `fanout`, `stop`) as ServiceEndpoint. - - Usage: - class MyForgeActor(ForgeActor): - @service_endpoint(router="round_robin", batch_size=16, batch_timeout=0.05) - async def forward(self, x): ... - """ - - def __init__( - self, - service, - endpoint_name: str, - router: Router = RoundRobinRouter(), - batch_size: int = 8, - batch_timeout: float = 0.01, - ): - super().__init__(service, endpoint_name, router=router) - - self.batch_size = batch_size - self.batch_timeout = batch_timeout - self._batch_queue: asyncio.Queue = asyncio.Queue() - self._running_batch_loop = True - self.batch_task = asyncio.create_task(self._batch_loop()) - - async def _choose_replica(self, sess_id: str | None) -> "Replica": - """ - Overridden to support batching. - - - Session requests bypass batching and use sticky session router. - - Stateless requests are enqueued; the batch loop will fulfill their - Future with a chosen replica. - """ - - # Stateful routing always uses session router - if sess_id: - healthy = self.service._get_healthy_replicas() - return self.session_router.get_replica( - healthy, sess_id, self.service._session_replica_map - ) - # Stateless: batching - fut = asyncio.Future() - self._batch_queue.put_nowait(fut) - return await fut - - async def _batch_loop(self): - """Background task that continuously processes batches of routing requests. - - This is the core batching logic that runs in a separate asyncio task. - It collects requests from the queue and processes them in batches based - on size and time constraints. - - The loop follows these steps: - 1. Wait for the first request to start a new batch - 2. Collect additional requests until batch_size or batch_timeout is reached - 3. Make a single routing decision for the entire batch - 4. Fulfill all futures with the selected replica - - This process repeats indefinitely until the task is cancelled. - """ - while self._running_batch_loop: - batch_futs = [] - - # Wait for the first request to start a batch - fut = await self._batch_queue.get() - batch_futs.append(fut) - start_time = time.monotonic() - - # Collect additional requests until batch size or timeout - while True: - try: - timeout = max( - 0, self.batch_timeout - (time.monotonic() - start_time) - ) - fut = await asyncio.wait_for( - self._batch_queue.get(), timeout - ) # wait for timeout or until self._queue.get() finishes - batch_futs.append(fut) - - if len(batch_futs) >= self.batch_size: - break - except asyncio.TimeoutError: - break - - # Make one routing decision for the batch - healthy_replicas = self.service._get_healthy_replicas() - replica = self.router.get_replica( - healthy_replicas, None, self.service._session_replica_map - ) - - # Fulfill all futures with the chosen replica - for fut in batch_futs: - fut.set_result(replica) - - async def stop(self): - """Stop the batching loop.""" - self._running_batch_loop = False - if hasattr(self, "batch_task"): - self.batch_task.cancel() - class ServiceEndpointV2(Generic[P, R]): """An endpoint object specific to services. diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index cc4cc68ba..8ef768dae 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -15,7 +15,7 @@ from monarch._src.actor.endpoint import EndpointProperty -from .endpoint import BatchedServiceEndpoint, ServiceEndpoint, ServiceEndpointV2 +from .endpoint import ServiceEndpoint, ServiceEndpointV2 @dataclass @@ -87,30 +87,20 @@ def __init__(self, _service, actor_def): attr_value = getattr(actor_def, attr_name) if hasattr(attr_value, "_service_endpoint_config"): # Decorated with @service_endpoint - # Create a ServiceEndpoint with batch routing config cfg = attr_value._service_endpoint_config - if cfg["batch_size"] > 1: - endpoint = BatchedServiceEndpoint( - self._service, - attr_name, - router=cfg["router"], - batch_size=cfg["batch_size"], - batch_timeout=cfg["batch_timeout"], - ) - else: - endpoint = ServiceEndpoint( - self._service, - attr_name, - router=cfg["router"], - ) + # Service manages router creation + self._service._set_router(attr_name, cfg) + endpoint = ServiceEndpoint(self._service, attr_name) elif isinstance(attr_value, EndpointProperty): - # This was defined as a standard Monarch endpoint - # Create a ServiceEndpoint that will route through the Service Actor + # Decorated with Monarch @endpoint + self._service._set_router(attr_name) endpoint = ServiceEndpoint(self._service, attr_name) else: - # Not decorated with @endpoint or @service_endpoint + # Not an endpoint continue + + # Attach to interface setattr(self, attr_name, endpoint) # Session management methods - handled by ServiceInterface @@ -126,10 +116,6 @@ async def shutdown(self) -> None: """ Shut down the underlying Service and all endpoints. """ - for attr in dir(self): - ep = getattr(self, attr) - if isinstance(ep, ServiceEndpoint): - await ep.stop() await self._service.stop() def session(self) -> "SessionContext": diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index 6aba73533..c63880a78 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -5,9 +5,11 @@ # LICENSE file in the root directory of this source tree. +import asyncio import logging +import time from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Callable, Dict, List from .replica import Replica @@ -103,3 +105,126 @@ def get_replica( replica.idx, ) return replica + + +class Batcher: + """ + Asynchronous batching wrapper around a Router. + + Instead of selecting a replica immediately, incoming requests are enqueued + and grouped into batches. Once a batch is ready (either reaching the maximum + size or exceeding the maximum wait time), the batcher makes a single routing + decision using the inner router. All requests in that batch are then resolved + with the same replica. + + This reduces router overhead by amortizing multiple requests into one decision. + + Args: + inner_router: The underlying Router used to pick a replica. + get_healthy_replicas: Callable that returns the current list of healthy replicas. + get_session_map: Callable that returns the session-to-replica mapping. + batch_size: Maximum number of requests to collect in a single batch + before routing (default: 8). + batch_timeout: Maximum time to wait (in seconds) before routing a batch, + even if batch_size is not reached (default: 0.01). + + Example: + rr_router = RoundRobinRouter() + batcher = Batcher( + rr_router, + get_healthy_replicas=service._get_healthy_replicas, + get_session_map=service._get_session_map, + batch_size=16, + batch_timeout=0.01, + ) + + # Enqueue a request and await the chosen replica + replica = await batcher.route() + """ + + def __init__( + self, + inner_router: Router, + get_healthy_replicas: Callable[[], List["Replica"]], + get_session_map: Callable[[], Dict[str, int]], + batch_size: int = 16, + batch_timeout: float = 0.01, + ): + + self.inner_router = inner_router + self.batch_size = batch_size + self.batch_timeout = batch_timeout + self.get_healthy_replicas = get_healthy_replicas + self.get_session_map = get_session_map + + # Internal queue for batching routing requests + self._queue: asyncio.Queue = asyncio.Queue() + self._running = True # flag to control loop + # Background task that processes batches continuously + self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop()) + + async def _batch_loop(self): + """Background task that continuously processes batches of routing requests. + + This is the core batching logic that runs in a separate asyncio task. + It collects requests from the queue and processes them in batches based + on size and time constraints. + + The loop follows these steps: + 1. Wait for the first request to start a new batch + 2. Collect additional requests until batch_size or batch_timeout is reached + 3. Make a single routing decision for the entire batch + 4. Fulfill all futures with the selected replica + + This process repeats indefinitely until the task is cancelled. + """ + while self._running: + batch_futs = [] + + # Wait for first request + fut = await self._queue.get() + batch_futs.append(fut) + start_time = time.monotonic() + + while True: + try: + timeout = max( + 0, self.batch_timeout - (time.monotonic() - start_time) + ) + fut = await asyncio.wait_for( + self._queue.get(), timeout + ) # wait for timeout or until self._queue.get() finishes + batch_futs.append(fut) + + if len(batch_futs) >= self.batch_size: + break + except asyncio.TimeoutError: + break + + session_map = self.get_session_map() + healthy_replicas = self.get_healthy_replicas() + + # One routing decision for the whole batch + replica = self.inner_router.get_replica(healthy_replicas, None, session_map) + + # Fulfill all futures with the chosen replica + for fut in batch_futs: + fut.set_result(replica) + + async def route(self) -> Replica: + """Enqueue request and wait until batch assigns a replica.""" + fut = asyncio.Future() + # Queue the request for batching - this is non-blocking + self._queue.put_nowait(fut) + + # Wait for the batch processor to resolve our future + return await fut + + async def stop(self): + """Stop the batch loop gracefully.""" + self._running = False + self._batch_task.cancel() + try: + await self._batch_task + except asyncio.CancelledError: + pass diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index be024ffad..853e04320 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -46,8 +46,9 @@ from forge.controller.service.replica import Replica, ServiceRequest from forge.controller.service.router import ( - LeastLoadedRouter, + Batcher, RoundRobinRouter, + Router, SessionRouter, ) from forge.types import ServiceConfig @@ -108,7 +109,8 @@ async def __initialize__(self): # Initialize the routers self._default_router = RoundRobinRouter() - self._session_router = SessionRouter(fallback_router=LeastLoadedRouter()) + self._session_router = SessionRouter(fallback_router=self._default_router) + self.routers: dict[str, Router | Batcher] = {} # Initialize all replicas replicas = [] @@ -138,11 +140,88 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - async def _call( - self, replica: "Replica", sess_id: str | None, function: str, *args, **kwargs - ): - """Send request directly to a chosen replica and wait for result.""" + def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: + """ + Ensure a router exists for the given endpoint. + + - If a router is already set, leave it unchanged. + - If cfg is provided, use its router/batching options. + - If cfg is missing or incomplete, fall back to defaults: + use a round robin router without batching. + + Args: + endpoint_name: Name of the endpoint (string). + cfg: Optional service_endpoint_config dict, may include: + { + "router": Router, + "batch_size": int, + "batch_timeout": float + } + Returns: + Router | Batcher instance stored in self.routers + """ + + # If router already exists, ignore + if endpoint_name in self.routers: + return + + # Resolve base router + if cfg and "router" in cfg: + if not isinstance(cfg.get("router"), Router): + raise ValueError(f"Unknown router type: {cfg['router']}") + else: + base_router = cfg["router"] + else: + base_router = RoundRobinRouter() + + # Wrap in Batcher if batching requested + if cfg and cfg.get("batch_size", 1) > 1: + router = Batcher( + base_router, + get_healthy_replicas=self._get_healthy_replicas, + get_session_map=self._get_session_map, + batch_size=cfg.get("batch_size", 16), + batch_timeout=cfg.get("batch_timeout", 0.01), + ) + else: + router = base_router + + # Store and return + self.routers[endpoint_name] = router + + async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + """ + Routes a function call to the appropriate replica with load balancing and fault tolerance. + + This is the core routing method that handles: + - Session-based routing for stateful calls + - Round-robin load balancing for stateless calls + - Custom routing based on context hints + - Automatic retry on replica failures + - Request queuing and processing + + Args: + sess_id: Optional session ID for stateful routing + function: Name of the actor endpoint to call + *args: Positional arguments to pass to the endpoint + **kwargs: Keyword arguments to pass to the endpoint + + Returns: + The result from the actor endpoint execution + + Raises: + RuntimeError: If no healthy replicas are available + Exception: Any exception raised by the actor endpoint + """ + # Check context variables for session state if no explicit sess_id + if sess_id is None: + ctx = _session_context.get(None) + if ctx: + sess_id = ctx["session_id"] + replica = await self._get_replica(sess_id=sess_id, endpoint_name=function) + + # Create a ServiceRequest object to queue request = ServiceRequest( session_id=sess_id, function=function, @@ -154,7 +233,19 @@ async def _call( # Queue the request using replica's method await replica.enqueue_request(request) - return await request.future + # Wait for the result + try: + return await request.future + except Exception as e: + # If the replica failed, try to retry once + if not replica.healthy: + logger.debug( + f"Replica {replica.idx} failed during request, retrying on healthy replica. Exception: {e}" + ) + return await self._retry_request_on_healthy_replica( + sess_id, function, *args, **kwargs + ) + raise async def call_all(self, function: str, *args, **kwargs) -> List: """ @@ -207,6 +298,17 @@ async def call_all(self, function: str, *args, **kwargs) -> List: return results + async def _retry_request_on_healthy_replica( + self, sess_id: str | None, function: str, *args, **kwargs + ): + """Retries a failed request on a healthy replica.""" + # Force reassignment to a healthy replica (only for session-based calls) + if sess_id is not None and sess_id in self._session_replica_map: + del self._session_replica_map[sess_id] + + # Retry the call (this will assign to a new healthy replica) + return await self._call(sess_id, function, *args, **kwargs) + async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" migrated_requests = [] @@ -397,6 +499,9 @@ def _get_healthy_replicas(self) -> list[Replica]: """Returns a list of healthy replicas.""" return [r for r in self._replicas if r.healthy] + def _get_session_map(self) -> Dict[str, int]: + return self._session_replica_map + async def _health_loop(self, poll_rate_s: float): """Runs the health loop to monitor and recover replicas. @@ -425,6 +530,24 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) + async def _get_replica(self, sess_id: str | None, endpoint_name: str) -> "Replica": + """Get a replica for the given session ID.""" + healthy_replicas = [r for r in self._replicas if r.healthy] + router = self.routers.get(endpoint_name, self._default_router) + + # Case 1: sticky sessions + if sess_id is not None: + return self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + + # Case 2: batching + if isinstance(router, Batcher): + return await router.route() + + # Case 3: stateless routing + return self._default_router.get_replica(healthy_replicas) + async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop @@ -443,6 +566,16 @@ async def stop(self): except asyncio.CancelledError: logger.info("Health loop task cancelled.") + # Stop all batchers in routers + await asyncio.gather( + *( + router.stop() + for router in self.routers.values() + if isinstance(router, Batcher) + ), + return_exceptions=True, + ) + # Stop all replicas using their stop method await asyncio.gather( *[replica.stop() for replica in self._replicas], diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 6fee199a8..c4222269e 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ -Tests for router.py and batch routing in ServiceEndpoint +Tests for router.py """ import asyncio @@ -56,11 +56,16 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: self.v += amount * multiplier return self.v - @service_endpoint(router=RoundRobinRouter(), batch_size=3, batch_timeout=1) + @service_endpoint(router=RoundRobinRouter()) async def rr_incr(self): """Increment using RoundRobin router.""" self.v += 1 + @service_endpoint(router=RoundRobinRouter(), batch_size=3, batch_timeout=1) + async def rr_batch_incr(self): + """Increment using RoundRobin router.""" + self.v += 1 + def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: """Helper to build a replica with specified state and load.""" @@ -86,7 +91,7 @@ async def test_service_as_actor_preserves_normal_usage(): assert await service.value.choose() == 5 # Test increment - await service.rr_incr.choose() + await service.rr_batch_incr.choose() assert await service.value.choose() == 6 finally: @@ -136,7 +141,7 @@ async def test_round_robin_router_distribution(): # Make multiple sessionless calls using route() results = [] for _ in range(6): - await service.incr.route() + await service.rr_incr.route() values = await service.value.fanout() results.append(values) # Verify that requests were distributed round-robin @@ -158,7 +163,7 @@ async def test_round_robin_router_distribution_with_batching(): try: # Make multiple sessionless calls using route() results = [] - tasks = [service.rr_incr.route() for _ in range(6)] + tasks = [service.rr_batch_incr.route() for _ in range(6)] await asyncio.gather(*tasks) # Verify that requests were distributed round-robin # Each call increments a single replica, so after 6 calls we expect: @@ -211,7 +216,7 @@ async def test_service_endpoint_batch_flush_max_size(): try: # Make 3 concurrent requests (batch_size = 3) - tasks = [asyncio.create_task(service.rr_incr.route()) for _ in range(4)] + tasks = [asyncio.create_task(service.rr_batch_incr.route()) for _ in range(4)] await asyncio.gather(*tasks) values = await service.value.fanout() From 47e7f825259f9af51c83ce6935a91d9d7ebe1f6d Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 13:47:20 -0700 Subject: [PATCH 17/32] minor --- src/forge/controller/service/interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 8ef768dae..29a7fb05b 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -114,7 +114,7 @@ async def terminate_session(self, sess_id: str): async def shutdown(self) -> None: """ - Shut down the underlying Service and all endpoints. + Shut down the underlying Service. """ await self._service.stop() From baf2ef68d11916b5edbd2639a79a75b4bb097179 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 14:05:34 -0700 Subject: [PATCH 18/32] @service_endpoint returns ServiceEndpointProperty --- src/forge/controller/service/__init__.py | 3 +- src/forge/controller/service/endpoint.py | 42 ++++++++++++++++++----- src/forge/controller/service/interface.py | 30 +++++++--------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/forge/controller/service/__init__.py b/src/forge/controller/service/__init__.py index cda6263a8..d745f1556 100644 --- a/src/forge/controller/service/__init__.py +++ b/src/forge/controller/service/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .endpoint import service_endpoint +from .endpoint import service_endpoint, ServiceEndpointProperty from .interface import ServiceInterface, Session, SessionContext from .metrics import ServiceMetrics from .replica import Replica, ReplicaMetrics, ReplicaState @@ -26,6 +26,7 @@ "RoundRobinRouter", "SessionRouter", "service_endpoint", + "ServiceEndpointProperty", "Router", "Batcher", ] diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index ff2529109..a018222c9 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -8,15 +8,17 @@ Service endpoint management for the Forge framework. """ -from typing import Generic, List, TypeVar +from typing import Any, Generic, List, TypeVar + +from monarch._src.actor.endpoint import EndpointProperty -from monarch.actor import endpoint from typing_extensions import ParamSpec from .router import RoundRobinRouter, Router P = ParamSpec("P") R = TypeVar("R") +Propagator = Any class ServiceEndpoint(Generic[P, R]): @@ -111,6 +113,30 @@ async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: return result +class ServiceEndpointProperty(EndpointProperty, Generic[P, R]): + """ + Extension of EndpointProperty that carries service-specific + routing and batching configuration. + """ + + def __init__( + self, + method: Any, + propagator: Propagator, + explicit_response_port: bool, + *, + router: Router = RoundRobinRouter(), + batch_size: int = 1, + batch_timeout: float = 0.01, + ) -> None: + super().__init__(method, propagator, explicit_response_port) + self._service_endpoint_config = dict( + router=router, + batch_size=batch_size, + batch_timeout=batch_timeout, + ) + + def service_endpoint( *, router: Router = RoundRobinRouter(), @@ -128,16 +154,14 @@ class MyForgeActor(ForgeActor): async def predict(self, x): ... """ - def decorator(method): - # First wrap in EndpointProperty (so actor has a proper endpoint) - ep = endpoint( - method, propagate=propagate, explicit_response_port=explicit_response_port - ) - ep._service_endpoint_config = dict( + def decorator(method) -> ServiceEndpointProperty: + return ServiceEndpointProperty( + method, + propagator=propagate, + explicit_response_port=explicit_response_port, router=router, batch_size=batch_size, batch_timeout=batch_timeout, ) - return ep return decorator diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 29a7fb05b..4c4f5f93a 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -15,7 +15,7 @@ from monarch._src.actor.endpoint import EndpointProperty -from .endpoint import ServiceEndpoint, ServiceEndpointV2 +from .endpoint import ServiceEndpoint, ServiceEndpointProperty, ServiceEndpointV2 @dataclass @@ -85,23 +85,19 @@ def __init__(self, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if hasattr(attr_value, "_service_endpoint_config"): - # Decorated with @service_endpoint - cfg = attr_value._service_endpoint_config - # Service manages router creation + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + cfg = ( + attr_value._service_endpoint_config + if isinstance(attr_value, ServiceEndpointProperty) + else None + ) + + # Register router and attach endpoint self._service._set_router(attr_name, cfg) - endpoint = ServiceEndpoint(self._service, attr_name) - - elif isinstance(attr_value, EndpointProperty): - # Decorated with Monarch @endpoint - self._service._set_router(attr_name) - endpoint = ServiceEndpoint(self._service, attr_name) - else: - # Not an endpoint - continue - - # Attach to interface - setattr(self, attr_name, endpoint) + setattr(self, attr_name, ServiceEndpoint(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: From 595751e75b550f09a843668f75170bad815524b0 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 14:13:20 -0700 Subject: [PATCH 19/32] simplify _set_router --- src/forge/controller/service/service.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 853e04320..b3858d5e1 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -146,12 +146,12 @@ def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: - If a router is already set, leave it unchanged. - If cfg is provided, use its router/batching options. - - If cfg is missing or incomplete, fall back to defaults: + - If cfg is missing or incomplete, ignore and fall back to defaults: use a round robin router without batching. Args: endpoint_name: Name of the endpoint (string). - cfg: Optional service_endpoint_config dict, may include: + cfg: Optional service_endpoint_config dict, include: { "router": Router, "batch_size": int, @@ -165,17 +165,18 @@ def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: if endpoint_name in self.routers: return + # If config is missing or incomplete, use default router + if cfg is None or "router" not in cfg: + return + # Resolve base router - if cfg and "router" in cfg: - if not isinstance(cfg.get("router"), Router): - raise ValueError(f"Unknown router type: {cfg['router']}") - else: - base_router = cfg["router"] + if not isinstance(cfg.get("router"), Router): + raise ValueError(f"Unknown router type: {cfg['router']}") else: - base_router = RoundRobinRouter() + base_router = cfg["router"] # Wrap in Batcher if batching requested - if cfg and cfg.get("batch_size", 1) > 1: + if cfg.get("batch_size", 1) > 1: router = Batcher( base_router, get_healthy_replicas=self._get_healthy_replicas, From 0085972fddccbff92751464f5812de30386520ad Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 15:15:57 -0700 Subject: [PATCH 20/32] update docstring and test cases --- src/forge/controller/service/service.py | 15 +++--- tests/unit_tests/test_router.py | 63 ++++++++++++++++--------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index b3858d5e1..d3b5da86e 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -568,14 +568,13 @@ async def stop(self): logger.info("Health loop task cancelled.") # Stop all batchers in routers - await asyncio.gather( - *( - router.stop() - for router in self.routers.values() - if isinstance(router, Batcher) - ), - return_exceptions=True, - ) + # Stop all batchers + batchers = [ + router for router in self.routers.values() if isinstance(router, Batcher) + ] + if batchers: + await asyncio.gather(*(b.stop() for b in batchers), return_exceptions=True) + logger.info("All batcher loop(s) stopped gracefully.") # Stop all replicas using their stop method await asyncio.gather( diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index c4222269e..3bf63b932 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -34,11 +34,6 @@ class Counter(ForgeActor): def __init__(self, v: int): self.v = v - @endpoint - async def incr(self): - """Increment the counter.""" - self.v += 1 - @endpoint async def value(self) -> int: """Get the current counter value.""" @@ -56,14 +51,19 @@ async def add_to_value(self, amount: int, multiplier: int = 1) -> int: self.v += amount * multiplier return self.v - @service_endpoint(router=RoundRobinRouter()) - async def rr_incr(self): - """Increment using RoundRobin router.""" + @endpoint + async def incr(self): + """Increment the counter.""" self.v += 1 @service_endpoint(router=RoundRobinRouter(), batch_size=3, batch_timeout=1) - async def rr_batch_incr(self): - """Increment using RoundRobin router.""" + async def rr_batch_incr_bsize3(self): + """Increment the round-robin counter with batching (batch size = 3).""" + self.v += 1 + + @service_endpoint(router=RoundRobinRouter(), batch_size=5, batch_timeout=0.05) + async def rr_batch_incr_bsize5(self): + """Increment the round-robin counter with batching (batch size = 5).""" self.v += 1 @@ -91,7 +91,7 @@ async def test_service_as_actor_preserves_normal_usage(): assert await service.value.choose() == 5 # Test increment - await service.rr_batch_incr.choose() + await service.rr_batch_incr_bsize3.choose() assert await service.value.choose() == 6 finally: @@ -141,7 +141,7 @@ async def test_round_robin_router_distribution(): # Make multiple sessionless calls using route() results = [] for _ in range(6): - await service.rr_incr.route() + await service.incr.route() values = await service.value.fanout() results.append(values) # Verify that requests were distributed round-robin @@ -163,7 +163,7 @@ async def test_round_robin_router_distribution_with_batching(): try: # Make multiple sessionless calls using route() results = [] - tasks = [service.rr_batch_incr.route() for _ in range(6)] + tasks = [service.rr_batch_incr_bsize3.route() for _ in range(6)] await asyncio.gather(*tasks) # Verify that requests were distributed round-robin # Each call increments a single replica, so after 6 calls we expect: @@ -210,24 +210,45 @@ async def test_session_router_assigns_and_updates_session_map_in_service(): @pytest.mark.timeout(10) @pytest.mark.asyncio -async def test_service_endpoint_batch_flush_max_size(): - """Ensure @service_endpoint batching flushes correctly when max batch size reached.""" +async def test_independent_batchers_and_routers_per_endpoint(): + """Ensure multiple @service_endpoint endpoints coexist with independent routers/batchers.""" service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) try: - # Make 3 concurrent requests (batch_size = 3) - tasks = [asyncio.create_task(service.rr_batch_incr.route()) for _ in range(4)] + # --- First batch: rr_batch_incr_bsize3 (batch_size = 3) --- + tasks = [ + asyncio.create_task(service.rr_batch_incr_bsize3.route()) for _ in range(4) + ] await asyncio.gather(*tasks) values = await service.value.fanout() # Expectation: - # - 3 increments batched together on one replica - # - 1 increment on the other replica (new batch after flush) + # - First 3 requests form one batch → sent to replica R1 (+3). + # - Remaining 1 request forms its own batch → goes to replica R2 (+1). + # So totals should be [3, 1] (order depends on round robin). assert sum(values) == 4, f"Expected total=4, got {values}" - - # Exactly one replica should have count=3, and the other count=1 assert sorted(values) == [1, 3], f"Expected [1, 3], got {values}" + # --- Second batch: rr_batch_incr_bsize5 (batch_size = 5) --- + tasks = [ + asyncio.create_task(service.rr_batch_incr_bsize5.route()) for _ in range(7) + ] + await asyncio.gather(*tasks) + + values = await service.value.fanout() + + # Expectation (RoundRobin between replicas): + # Starting from previous state (R1=3, R2=1): + # - Next 5 requests form one batch → go to R1 (+5). + # - Remaining 2 requests form their own batch → go to R2 (+2). + # + # Final totals: + # R1 = 3 (previous) + 5 = 8 + # R2 = 1 (previous) + 2 = 3 + # So distribution should be [3, 8]. + assert sum(values) == 11, f"Expected total=11, got {values}" + assert sorted(values) == [3, 8], f"Expected [4, 8], got {values}" + finally: await service.shutdown() From 1bd0f912ba5c2d56aa73b57b6d61cb23b4246f28 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 17:24:37 -0700 Subject: [PATCH 21/32] add call/choose/call_one/... to ServiceEndpointV2 --- src/forge/controller/service/endpoint.py | 34 ++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index a018222c9..df5f6b146 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -97,7 +97,7 @@ def __init__(self, actor_mesh, endpoint_name: str): self.actor_mesh = actor_mesh self.endpoint_name = endpoint_name - async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: """Chooses a replica to call based on context and load balancing strategy.""" # Extract sess_id from kwargs if present sess_id = kwargs.pop("sess_id", None) @@ -105,13 +105,43 @@ async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: sess_id, self.endpoint_name, *args, **kwargs ) - async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" result = await self.actor_mesh.call_all.call_one( self.endpoint_name, *args, **kwargs ) return result + async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use choose() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use call() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def call_one(self, *args: P.args, **kwargs: P.kwargs) -> R: + raise NotImplementedError( + "You tried to use a call_one() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: + raise NotImplementedError( + "You tried to use broadcast() on a service, not an actor. " + "Services only support route() and fanout()." + ) + + async def generate(self, *args: P.args, **kwargs: P.kwargs): + raise NotImplementedError( + "You tried to use generate() on a service, not an actor. " + "Services only support route() and fanout()." + ) + class ServiceEndpointProperty(EndpointProperty, Generic[P, R]): """ From 8b61802b07b9700c0c996bf97c0c75750b9d58aa Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 17:26:51 -0700 Subject: [PATCH 22/32] raise error if endpoint already exist in self.routers --- src/forge/controller/service/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index d3b5da86e..4a2b9e888 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -110,6 +110,8 @@ async def __initialize__(self): # Initialize the routers self._default_router = RoundRobinRouter() self._session_router = SessionRouter(fallback_router=self._default_router) + + # This keeps the map between the registered endpoints and the routers self.routers: dict[str, Router | Batcher] = {} # Initialize all replicas @@ -161,9 +163,9 @@ def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: Router | Batcher instance stored in self.routers """ - # If router already exists, ignore + # If router already exists, raise an exception if endpoint_name in self.routers: - return + raise ValueError(f"Router already exists for endpoint: {endpoint_name}") # If config is missing or incomplete, use default router if cfg is None or "router" not in cfg: From f4a60d80b6d02822b3093647d0b98bceb24cbbba Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 17:29:55 -0700 Subject: [PATCH 23/32] call->route; call_all -> fanout --- src/forge/controller/service/endpoint.py | 4 +- src/forge/controller/service/service.py | 6 +-- tests/unit_tests/tmp.py | 68 ++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 tests/unit_tests/tmp.py diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index df5f6b146..4a84bd8ed 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -42,11 +42,11 @@ async def route(self, *args: P.args, **kwargs: P.kwargs) -> R: """Chooses a replica to call based on context and load balancing strategy.""" # Extract sess_id from kwargs if present sess_id = kwargs.pop("sess_id", None) - return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs) + return await self.service._route(sess_id, self.endpoint_name, *args, **kwargs) async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]: """Broadcasts a request to all healthy replicas and returns the results as a list.""" - result = await self.service.call_all(self.endpoint_name, *args, **kwargs) + result = await self.service._fanout(self.endpoint_name, *args, **kwargs) return result async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R: diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 4a2b9e888..8d0d546e5 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -192,7 +192,7 @@ def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: # Store and return self.routers[endpoint_name] = router - async def _call(self, sess_id: str | None, function: str, *args, **kwargs): + async def _route(self, sess_id: str | None, function: str, *args, **kwargs): """ Routes a function call to the appropriate replica with load balancing and fault tolerance. @@ -250,7 +250,7 @@ async def _call(self, sess_id: str | None, function: str, *args, **kwargs): ) raise - async def call_all(self, function: str, *args, **kwargs) -> List: + async def _fanout(self, function: str, *args, **kwargs) -> List: """ Broadcasts a function call to all healthy replicas and returns results as a list. @@ -310,7 +310,7 @@ async def _retry_request_on_healthy_replica( del self._session_replica_map[sess_id] # Retry the call (this will assign to a new healthy replica) - return await self._call(sess_id, function, *args, **kwargs) + return await self._route(sess_id, function, *args, **kwargs) async def _migrate_remaining_requests(self, failed_replica: Replica): """Migrates remaining requests from a failed replica to healthy replicas.""" diff --git a/tests/unit_tests/tmp.py b/tests/unit_tests/tmp.py new file mode 100644 index 000000000..cf28ec669 --- /dev/null +++ b/tests/unit_tests/tmp.py @@ -0,0 +1,68 @@ +import asyncio +import logging + +import pytest +from forge.controller import ForgeActor +from forge.controller.service import ( + LeastLoadedRouter, + Replica, + ReplicaState, + RoundRobinRouter, + ServiceConfig, + SessionRouter, +) +from forge.types import ProcessConfig +from monarch.actor import Actor, endpoint + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Counter(ForgeActor): + """Test actor that maintains a counter with various endpoints.""" + + def __init__(self, v: int): + self.v = v + + @endpoint + async def incr(self): + """Increment the counter.""" + self.v += 1 + + @endpoint + async def value(self) -> int: + """Get the current counter value.""" + return self.v + + @endpoint + async def fail_me(self): + """Endpoint that always fails to test error handling.""" + raise RuntimeError("I was asked to fail") + + @endpoint + async def slow_incr(self): + """Slow increment to test queueing.""" + await asyncio.sleep(1.0) + self.v += 1 + + @endpoint + async def add_to_value(self, amount: int, multiplier: int = 1) -> int: + """Add an amount (optionally multiplied) to the current value.""" + logger.info(f"adding {amount} with {multiplier}") + self.v += amount * multiplier + return self.v + + +async def test(): + service = await Counter.options(procs=1, num_replicas=1).as_service(1) + # async with service.session() as session: + # # All calls within this block use the same replica + # result1 = await service.incr.route() + # result2 = await service.value.fanout() + + session_id = await service.start_session() + result = await service.incr.route(sess_id=session_id) + await service.terminate_session(session_id) + + +asyncio.run(test()) From e9bd7c71be17db35a57a51bc2ac8de1c8066df97 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 17:33:45 -0700 Subject: [PATCH 24/32] move get_replica to route --- src/forge/controller/service/service.py | 39 ++++++++++++------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 8d0d546e5..dd40bec43 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -198,8 +198,7 @@ async def _route(self, sess_id: str | None, function: str, *args, **kwargs): This is the core routing method that handles: - Session-based routing for stateful calls - - Round-robin load balancing for stateless calls - - Custom routing based on context hints + - Per-endpoint router selection (round-robin, least-loaded, batching, etc.) - Automatic retry on replica failures - Request queuing and processing @@ -222,7 +221,23 @@ async def _route(self, sess_id: str | None, function: str, *args, **kwargs): if ctx: sess_id = ctx["session_id"] - replica = await self._get_replica(sess_id=sess_id, endpoint_name=function) + healthy_replicas = [r for r in self._replicas if r.healthy] + router = self.routers.get(function, self._default_router) + + # Select replica: + if sess_id is not None: + # Case 1: sticky sessions + replica = self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + + elif isinstance(router, Batcher): + # Case 2: batching + replica = await router.route() + + else: + # Case 3: stateless routing + replica = self._default_router.get_replica(healthy_replicas) # Create a ServiceRequest object to queue request = ServiceRequest( @@ -533,24 +548,6 @@ async def _health_loop(self, poll_rate_s: float): await asyncio.sleep(poll_rate_s) - async def _get_replica(self, sess_id: str | None, endpoint_name: str) -> "Replica": - """Get a replica for the given session ID.""" - healthy_replicas = [r for r in self._replicas if r.healthy] - router = self.routers.get(endpoint_name, self._default_router) - - # Case 1: sticky sessions - if sess_id is not None: - return self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - - # Case 2: batching - if isinstance(router, Batcher): - return await router.route() - - # Case 3: stateless routing - return self._default_router.get_replica(healthy_replicas) - async def stop(self): logger.debug("Stopping service...") # Signal shutdown to health loop From 8f1600632a73f9d9b9be60cb8cf34f98d019781e Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 29 Sep 2025 17:35:22 -0700 Subject: [PATCH 25/32] remove a tmp file (committed by mistake --- tests/unit_tests/tmp.py | 68 ----------------------------------------- 1 file changed, 68 deletions(-) delete mode 100644 tests/unit_tests/tmp.py diff --git a/tests/unit_tests/tmp.py b/tests/unit_tests/tmp.py deleted file mode 100644 index cf28ec669..000000000 --- a/tests/unit_tests/tmp.py +++ /dev/null @@ -1,68 +0,0 @@ -import asyncio -import logging - -import pytest -from forge.controller import ForgeActor -from forge.controller.service import ( - LeastLoadedRouter, - Replica, - ReplicaState, - RoundRobinRouter, - ServiceConfig, - SessionRouter, -) -from forge.types import ProcessConfig -from monarch.actor import Actor, endpoint - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class Counter(ForgeActor): - """Test actor that maintains a counter with various endpoints.""" - - def __init__(self, v: int): - self.v = v - - @endpoint - async def incr(self): - """Increment the counter.""" - self.v += 1 - - @endpoint - async def value(self) -> int: - """Get the current counter value.""" - return self.v - - @endpoint - async def fail_me(self): - """Endpoint that always fails to test error handling.""" - raise RuntimeError("I was asked to fail") - - @endpoint - async def slow_incr(self): - """Slow increment to test queueing.""" - await asyncio.sleep(1.0) - self.v += 1 - - @endpoint - async def add_to_value(self, amount: int, multiplier: int = 1) -> int: - """Add an amount (optionally multiplied) to the current value.""" - logger.info(f"adding {amount} with {multiplier}") - self.v += amount * multiplier - return self.v - - -async def test(): - service = await Counter.options(procs=1, num_replicas=1).as_service(1) - # async with service.session() as session: - # # All calls within this block use the same replica - # result1 = await service.incr.route() - # result2 = await service.value.fanout() - - session_id = await service.start_session() - result = await service.incr.route(sess_id=session_id) - await service.terminate_session(session_id) - - -asyncio.run(test()) From 03ff0c21bfcc5a14f0d0d4bce502de2f2fbdfa7b Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 12:16:30 -0700 Subject: [PATCH 26/32] remove dict for batcher config; add one more test for config --- src/forge/controller/service/endpoint.py | 8 ++--- src/forge/controller/service/interface.py | 12 +++---- src/forge/controller/service/service.py | 43 +++++++++++------------ tests/unit_tests/test_router.py | 43 +++++++++++++++++++++++ 4 files changed, 71 insertions(+), 35 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 4a84bd8ed..fd77c41fd 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -160,11 +160,9 @@ def __init__( batch_timeout: float = 0.01, ) -> None: super().__init__(method, propagator, explicit_response_port) - self._service_endpoint_config = dict( - router=router, - batch_size=batch_size, - batch_timeout=batch_timeout, - ) + self.router = router + self.batch_size = batch_size + self.batch_timeout = batch_timeout def service_endpoint( diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 4c4f5f93a..0d29ef425 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -89,14 +89,10 @@ def __init__(self, _service, actor_def): # ServiceEndpointProperty: created by @service_endpoint # EndpointProperty: created by @endpoint if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): - cfg = ( - attr_value._service_endpoint_config - if isinstance(attr_value, ServiceEndpointProperty) - else None - ) - - # Register router and attach endpoint - self._service._set_router(attr_name, cfg) + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + setattr(self, attr_name, ServiceEndpoint(self._service, attr_name)) # Session management methods - handled by ServiceInterface diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index dd40bec43..50a1ff408 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -40,11 +40,12 @@ from monarch.actor import Actor, endpoint +from forge.controller.service.endpoint import ServiceEndpointProperty + from forge.controller.service.interface import _session_context, Session from forge.controller.service.metrics import ServiceMetrics from forge.controller.service.replica import Replica, ServiceRequest - from forge.controller.service.router import ( Batcher, RoundRobinRouter, @@ -142,25 +143,21 @@ async def __initialize__(self): self._health_loop(poll_rate_s=self._cfg.health_poll_rate) ) - def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: + def _set_router( + self, endpoint_name: str, prop: ServiceEndpointProperty | None = None + ) -> None: """ Ensure a router exists for the given endpoint. - - If a router is already set, leave it unchanged. - - If cfg is provided, use its router/batching options. - - If cfg is missing or incomplete, ignore and fall back to defaults: - use a round robin router without batching. + - If a router is already set, raise an error. + - If a ServiceEndpointProperty is provided, construct its router/batcher + using the specified configuration. + - If not provided, fall back to a default round-robin router. Args: - endpoint_name: Name of the endpoint (string). - cfg: Optional service_endpoint_config dict, include: - { - "router": Router, - "batch_size": int, - "batch_timeout": float - } - Returns: - Router | Batcher instance stored in self.routers + endpoint_name: Name of the endpoint. + prop: Optional ServiceEndpointProperty object with router, batch_size, + and batch_timeout attributes. """ # If router already exists, raise an exception @@ -168,23 +165,25 @@ def _set_router(self, endpoint_name: str, cfg: dict | None = None) -> None: raise ValueError(f"Router already exists for endpoint: {endpoint_name}") # If config is missing or incomplete, use default router - if cfg is None or "router" not in cfg: + if prop is None or not isinstance(prop, ServiceEndpointProperty): return # Resolve base router - if not isinstance(cfg.get("router"), Router): - raise ValueError(f"Unknown router type: {cfg['router']}") + if not isinstance(prop.router, Router): + raise ValueError(f"Unknown router type: {prop.router}") else: - base_router = cfg["router"] + base_router = prop.router + batch_size = prop.batch_size + batch_timeout = prop.batch_timeout # Wrap in Batcher if batching requested - if cfg.get("batch_size", 1) > 1: + if batch_size > 1: router = Batcher( base_router, get_healthy_replicas=self._get_healthy_replicas, get_session_map=self._get_session_map, - batch_size=cfg.get("batch_size", 16), - batch_timeout=cfg.get("batch_timeout", 0.01), + batch_size=batch_size, + batch_timeout=batch_timeout, ) else: router = base_router diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 3bf63b932..1b2bbd523 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -13,6 +13,7 @@ import pytest from forge.controller import ForgeActor from forge.controller.service import ( + Batcher, LeastLoadedRouter, Replica, ReplicaState, @@ -66,6 +67,11 @@ async def rr_batch_incr_bsize5(self): """Increment the round-robin counter with batching (batch size = 5).""" self.v += 1 + @service_endpoint(router=RoundRobinRouter()) + async def rr_batch_incr_bsize1(self): + """Increment the round-robin counter with default batch_size=1.""" + self.v += 1 + def make_replica(idx: int, healthy: bool = True, load: int = 0) -> Replica: """Helper to build a replica with specified state and load.""" @@ -98,6 +104,43 @@ async def test_service_as_actor_preserves_normal_usage(): await Counter.shutdown(service) +@pytest.mark.asyncio +async def test_service_endpoint_router_and_configurations(): + """ + Verify service endpoints are registered with correct router/batching configuration: + - rr_batch_incr_bsize1: plain RoundRobinRouter, no batching (batch_size=1, timeout=0.01) + - rr_batch_incr_bsize3: Batcher wrapping RoundRobinRouter (batch_size=3, timeout=1) + - incr: plain @endpoint, should not appear in service.routers + """ + service = await Counter.options(procs=1, num_replicas=2).as_service(v=0) + + try: + # --- rr_batch_incr_bsize1 --- + router1 = service.routers.get("rr_batch_incr_bsize1") + assert isinstance( + router1, RoundRobinRouter + ), f"Expected RoundRobinRouter, got {type(router1)}" + + prop1 = getattr(Counter, "rr_batch_incr_bsize1") + assert prop1.batch_size == 1 + assert prop1.batch_timeout == 0.01 + + # --- rr_batch_incr_bsize3 --- + router3 = service.routers.get("rr_batch_incr_bsize3") + + assert isinstance(router3, Batcher), f"Expected Batcher, got {type(router3)}" + assert router3.batch_size == 3 + assert router3.batch_timeout == 1 + + # --- incr --- + assert ( + "incr" not in service.routers + ), "Plain @endpoint should not be in service.routers" + + finally: + await service.shutdown() + + # Router Tests From 1faf6a674247116bcf977e0476342e979d14c450 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 12:20:00 -0700 Subject: [PATCH 27/32] add docstring to explain why ServiceEndpointProperty inherits EndpointProperty --- src/forge/controller/service/endpoint.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index fd77c41fd..4ba3cfb39 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -147,6 +147,10 @@ class ServiceEndpointProperty(EndpointProperty, Generic[P, R]): """ Extension of EndpointProperty that carries service-specific routing and batching configuration. + + Inherits from EndpointProperty so the method is still registered as + a valid actor endpoint, while also attaching service-specific options + (router, batch_size, batch_timeout). """ def __init__( From bcc35bbc90ee758d65dd285a41370a72b50b5fd4 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 12:32:32 -0700 Subject: [PATCH 28/32] fix lint --- tests/unit_tests/test_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 1b2bbd523..be33d30a5 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -121,7 +121,7 @@ async def test_service_endpoint_router_and_configurations(): router1, RoundRobinRouter ), f"Expected RoundRobinRouter, got {type(router1)}" - prop1 = getattr(Counter, "rr_batch_incr_bsize1") + prop1 = Counter.rr_batch_incr_bsize1 assert prop1.batch_size == 1 assert prop1.batch_timeout == 0.01 From ef1faa40e260319e47258b8d183f17ab60870778 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 13:53:14 -0700 Subject: [PATCH 29/32] change router from router obj to callable --- src/forge/controller/service/endpoint.py | 6 ++-- src/forge/controller/service/service.py | 15 ++++++---- tests/unit_tests/test_router.py | 36 ++++++++++++++++++++++-- 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/forge/controller/service/endpoint.py b/src/forge/controller/service/endpoint.py index 4ba3cfb39..c731017c6 100644 --- a/src/forge/controller/service/endpoint.py +++ b/src/forge/controller/service/endpoint.py @@ -8,7 +8,7 @@ Service endpoint management for the Forge framework. """ -from typing import Any, Generic, List, TypeVar +from typing import Any, Callable, Generic, List, TypeVar from monarch._src.actor.endpoint import EndpointProperty @@ -159,7 +159,7 @@ def __init__( propagator: Propagator, explicit_response_port: bool, *, - router: Router = RoundRobinRouter(), + router: Callable[[], Router] = RoundRobinRouter, batch_size: int = 1, batch_timeout: float = 0.01, ) -> None: @@ -171,7 +171,7 @@ def __init__( def service_endpoint( *, - router: Router = RoundRobinRouter(), + router: Callable[[], Router] = RoundRobinRouter, batch_size: int = 1, batch_timeout: float = 0.01, propagate=None, diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index 50a1ff408..f72d30471 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -36,7 +36,7 @@ import logging import pprint import uuid -from typing import Dict, List +from typing import Callable, Dict, List from monarch.actor import Actor, endpoint @@ -162,20 +162,25 @@ def _set_router( # If router already exists, raise an exception if endpoint_name in self.routers: - raise ValueError(f"Router already exists for endpoint: {endpoint_name}") + raise AssertionError(f"Router already exists for endpoint: {endpoint_name}") # If config is missing or incomplete, use default router if prop is None or not isinstance(prop, ServiceEndpointProperty): return # Resolve base router - if not isinstance(prop.router, Router): - raise ValueError(f"Unknown router type: {prop.router}") + if not callable(prop.router): + raise ValueError(f"Router must be callable, got: {prop.router}") else: - base_router = prop.router + base_router = prop.router() # Call the router constructor batch_size = prop.batch_size batch_timeout = prop.batch_timeout + if not isinstance(base_router, Router): + raise ValueError( + f"Router must be a Router instance, got: {type(base_router)}" + ) + # Wrap in Batcher if batching requested if batch_size > 1: router = Batcher( diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index be33d30a5..4ce497e77 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -57,17 +57,17 @@ async def incr(self): """Increment the counter.""" self.v += 1 - @service_endpoint(router=RoundRobinRouter(), batch_size=3, batch_timeout=1) + @service_endpoint(router=RoundRobinRouter, batch_size=3, batch_timeout=1) async def rr_batch_incr_bsize3(self): """Increment the round-robin counter with batching (batch size = 3).""" self.v += 1 - @service_endpoint(router=RoundRobinRouter(), batch_size=5, batch_timeout=0.05) + @service_endpoint(router=RoundRobinRouter, batch_size=5, batch_timeout=0.05) async def rr_batch_incr_bsize5(self): """Increment the round-robin counter with batching (batch size = 5).""" self.v += 1 - @service_endpoint(router=RoundRobinRouter()) + @service_endpoint(router=RoundRobinRouter) async def rr_batch_incr_bsize1(self): """Increment the round-robin counter with default batch_size=1.""" self.v += 1 @@ -141,6 +141,36 @@ async def test_service_endpoint_router_and_configurations(): await service.shutdown() +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_noncallable(): + """@service_endpoint with non-callable router should raise ValueError.""" + + class BadActor(ForgeActor): + @service_endpoint(router="roundrobin") # string, not callable + async def bad_endpoint(self): + return 42 + + with pytest.raises(ValueError, match="Router must be callable"): + # Triggers ServiceInterface._set_router during construction + await BadActor.options(num_replicas=1).as_service() + + +@pytest.mark.asyncio +async def test_service_endpoint_with_invalid_router_wrong_return_type(): + """@service_endpoint with callable that doesn't return Router should raise ValueError.""" + + class NotARouter: + """Dummy class that is not a Router.""" + + class BadActor(ForgeActor): + @service_endpoint(router=NotARouter) # returns NotARouter + async def bad_endpoint(self): + return 123 + + with pytest.raises(ValueError, match="Router must be a Router instance"): + await BadActor.options(num_replicas=1).as_service() + + # Router Tests From e2aee83e97fa8ca4b1add3e2bb1407aff3ea5e1c Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 14:28:49 -0700 Subject: [PATCH 30/32] batch requests --- src/forge/controller/service/replica.py | 5 +++ src/forge/controller/service/router.py | 28 ++++++++--------- src/forge/controller/service/service.py | 41 +++++++++++++------------ 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index ae69d3df5..81c6de0e1 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -220,6 +220,11 @@ async def enqueue_request(self, request: ServiceRequest): # Accept requests in all other states - let the processing loop handle the rest await self.request_queue.put(request) + async def enqueue_batch(self, requests: list[ServiceRequest]): + """Enqueues a batch of requests for processing by this replica.""" + for req in requests: + await self.enqueue_request(req) + async def _process_single_request(self, request: ServiceRequest) -> bool: """Processes a single request and returns success status. diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index c63880a78..f063c8662 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from typing import Callable, Dict, List -from .replica import Replica +from forge.controller.service.replica import Replica, ServiceRequest logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -179,11 +179,9 @@ async def _batch_loop(self): This process repeats indefinitely until the task is cancelled. """ while self._running: - batch_futs = [] # Wait for first request - fut = await self._queue.get() - batch_futs.append(fut) + batch = [await self._queue.get()] start_time = time.monotonic() while True: @@ -191,12 +189,12 @@ async def _batch_loop(self): timeout = max( 0, self.batch_timeout - (time.monotonic() - start_time) ) - fut = await asyncio.wait_for( + req = await asyncio.wait_for( self._queue.get(), timeout ) # wait for timeout or until self._queue.get() finishes - batch_futs.append(fut) + batch.append(req) - if len(batch_futs) >= self.batch_size: + if len(batch) >= self.batch_size: break except asyncio.TimeoutError: break @@ -207,18 +205,20 @@ async def _batch_loop(self): # One routing decision for the whole batch replica = self.inner_router.get_replica(healthy_replicas, None, session_map) - # Fulfill all futures with the chosen replica - for fut in batch_futs: - fut.set_result(replica) + # Send whole batch to replica + try: + await replica.enqueue_batch(batch) + except Exception as e: + for req in batch: + req.future.set_exception(e) - async def route(self) -> Replica: + async def enqueue(self, request: ServiceRequest) -> Any: """Enqueue request and wait until batch assigns a replica.""" - fut = asyncio.Future() # Queue the request for batching - this is non-blocking - self._queue.put_nowait(fut) + self._queue.put_nowait(request) # Wait for the batch processor to resolve our future - return await fut + return await request.future async def stop(self): """Stop the batch loop gracefully.""" diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py index f72d30471..6c182b5ae 100644 --- a/src/forge/controller/service/service.py +++ b/src/forge/controller/service/service.py @@ -36,7 +36,7 @@ import logging import pprint import uuid -from typing import Callable, Dict, List +from typing import Dict, List from monarch.actor import Actor, endpoint @@ -225,24 +225,8 @@ async def _route(self, sess_id: str | None, function: str, *args, **kwargs): if ctx: sess_id = ctx["session_id"] - healthy_replicas = [r for r in self._replicas if r.healthy] router = self.routers.get(function, self._default_router) - # Select replica: - if sess_id is not None: - # Case 1: sticky sessions - replica = self._session_router.get_replica( - healthy_replicas, sess_id, self._session_replica_map - ) - - elif isinstance(router, Batcher): - # Case 2: batching - replica = await router.route() - - else: - # Case 3: stateless routing - replica = self._default_router.get_replica(healthy_replicas) - # Create a ServiceRequest object to queue request = ServiceRequest( session_id=sess_id, @@ -252,8 +236,27 @@ async def _route(self, sess_id: str | None, function: str, *args, **kwargs): future=asyncio.Future(), ) - # Queue the request using replica's method - await replica.enqueue_request(request) + # Case: batching is enabled and no session ID (stateless calls only) + # Forward the request into the Batcher queue. The batcher will send + # the batched requests to the selected replica. + if sess_id is None and isinstance(router, Batcher): + await router.enqueue(request) + + else: + healthy_replicas = [r for r in self._replicas if r.healthy] + + # Select replica: + if sess_id is not None: + # Case 1: sticky sessions + replica = self._session_router.get_replica( + healthy_replicas, sess_id, self._session_replica_map + ) + else: + # Case 2: stateless routing + replica = self._default_router.get_replica(healthy_replicas) + + # Queue the request using replica's method + await replica.enqueue_request(request) # Wait for the result try: From 9060ec90c7c727b16020f11f7dde0b8c8a5803c1 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 14:30:17 -0700 Subject: [PATCH 31/32] add changes to ServiceInterfaceV2 for future adaption --- src/forge/controller/service/interface.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 0d29ef425..8e4a0c461 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -168,10 +168,15 @@ def __init__(self, _proc_mesh, _service, actor_def): # Inspect the actor_def directly to find endpoints for attr_name in dir(actor_def): attr_value = getattr(actor_def, attr_name) - if isinstance(attr_value, EndpointProperty): - # Create a ServiceEndpoint that will route through the Service Actor - endpoint = ServiceEndpointV2(self._service, attr_name) - setattr(self, attr_name, endpoint) + + # ServiceEndpointProperty: created by @service_endpoint + # EndpointProperty: created by @endpoint + if isinstance(attr_value, (EndpointProperty, ServiceEndpointProperty)): + if isinstance(attr_value, ServiceEndpointProperty): + # Register router with service-specific config + self._service._set_router(attr_name, attr_value) + + setattr(self, attr_name, ServiceEndpointV2(self._service, attr_name)) # Session management methods - handled by ServiceInterface async def start_session(self) -> str: From 57a1abe1162b2034e18685653dcf9d0df3e3a197 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 30 Sep 2025 16:40:25 -0700 Subject: [PATCH 32/32] fix missing import --- src/forge/controller/service/router.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/forge/controller/service/router.py b/src/forge/controller/service/router.py index f063c8662..11311f93f 100644 --- a/src/forge/controller/service/router.py +++ b/src/forge/controller/service/router.py @@ -9,7 +9,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List from forge.controller.service.replica import Replica, ServiceRequest @@ -138,8 +138,10 @@ class Batcher: batch_timeout=0.01, ) - # Enqueue a request and await the chosen replica - replica = await batcher.route() + request = ServiceRequest(...) + + # Enqueue a request to be sent to a replica + await batcher.enqueue(request) """ def __init__(