Skip to content

Commit 8bb3bc3

Browse files
committed
add batchrouter
1 parent 35d9a12 commit 8bb3bc3

File tree

5 files changed

+241
-16
lines changed

5 files changed

+241
-16
lines changed

src/forge/controller/service/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .interface import ServiceInterface, Session, SessionContext
88
from .metrics import ServiceMetrics
99
from .replica import Replica, ReplicaMetrics, ReplicaState
10-
from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter
10+
from .router import BatchRouter, LeastLoadedRouter, RoundRobinRouter, SessionRouter
1111
from .service import Service, ServiceActor, ServiceConfig
1212

1313
__all__ = [
@@ -24,4 +24,5 @@
2424
"LeastLoadedRouter",
2525
"RoundRobinRouter",
2626
"SessionRouter",
27+
"BatchRouter",
2728
]

src/forge/controller/service/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class Router(ABC):
286286
"""Abstract base class for routing logic."""
287287

288288
@abstractmethod
289-
def get_replica(
289+
async def get_replica(
290290
self,
291291
healthy_replicas: List[Replica],
292292
sess_id: str | None = None,

src/forge/controller/service/router.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import asyncio
78
import logging
8-
from typing import Dict, List
9+
import time
10+
from typing import Dict, List, Optional
911

1012
from .interface import Router
1113
from .replica import Replica
@@ -20,7 +22,7 @@ class RoundRobinRouter(Router):
2022
def __init__(self):
2123
self._next_idx = 0
2224

23-
def get_replica(
25+
async def get_replica(
2426
self,
2527
healthy_replicas: List[Replica],
2628
sess_id: str | None = None,
@@ -38,7 +40,7 @@ def get_replica(
3840
class LeastLoadedRouter(Router):
3941
"""Always routes to the replica with the lowest current load."""
4042

41-
def get_replica(
43+
async def get_replica(
4244
self,
4345
healthy_replicas: List[Replica],
4446
sess_id: str | None = None,
@@ -55,7 +57,7 @@ class SessionRouter(Router):
5557
def __init__(self, fallback_router: Router):
5658
self.fallback_router = fallback_router
5759

58-
def get_replica(
60+
async def get_replica(
5961
self,
6062
healthy_replicas: List[Replica],
6163
sess_id: str | None = None,
@@ -78,7 +80,7 @@ def get_replica(
7880
del session_map[sess_id]
7981

8082
# Use fallback router to assign a new replica
81-
replica = self.fallback_router.get_replica(
83+
replica = await self.fallback_router.get_replica(
8284
healthy_replicas, sess_id, session_map
8385
)
8486
session_map[sess_id] = replica.idx
@@ -88,3 +90,112 @@ def get_replica(
8890
replica.idx,
8991
)
9092
return replica
93+
94+
95+
class BatchRouter:
96+
"""
97+
Router wrapper that batches routing decisions.
98+
Uses an inner router to pick the replica for each batch.
99+
100+
Args:
101+
inner_router: The underlying Router instance used to make routing decisions
102+
batch_max_size: Maximum number of requests to collect in a single batch (default: 8)
103+
batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01)
104+
105+
Example:
106+
rr_router = RoundRobinRouter()
107+
batch_router = BatchRouter(rr_router, batch_max_size=16, batch_max_wait_s=0.02)
108+
109+
replica = await batch_router.get_replica(healthy_replicas, sess_id, session_map)
110+
"""
111+
112+
def __init__(
113+
self,
114+
inner_router: Router,
115+
batch_max_size: int = 8,
116+
batch_max_wait_s: float = 0.01,
117+
):
118+
119+
self.inner_router = inner_router
120+
self.batch_max_size = batch_max_size
121+
self.batch_max_wait_s = batch_max_wait_s
122+
123+
# Internal queue for batching routing requests
124+
self._queue: asyncio.Queue = asyncio.Queue()
125+
# Background task that processes batches continuously
126+
self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop())
127+
128+
async def _batch_loop(self):
129+
"""Background task that continuously processes batches of routing requests.
130+
131+
This is the core batching logic that runs in a separate asyncio task.
132+
It collects requests from the queue and processes them in batches based
133+
on size and time constraints.
134+
135+
The loop follows these steps:
136+
1. Wait for the first request to start a new batch
137+
2. Collect additional requests until batch_max_size or batch_max_wait_s is reached
138+
3. Make a single routing decision for the entire batch
139+
4. Fulfill all futures with the selected replica
140+
141+
This process repeats indefinitely until the task is cancelled.
142+
"""
143+
while True:
144+
batch = []
145+
futs = []
146+
sess_ids = []
147+
start_time = time.time()
148+
149+
# Wait for first request
150+
fut, healthy_replicas, sess_id, session_map = await self._queue.get()
151+
batch.append((healthy_replicas, sess_id, session_map))
152+
futs.append(fut)
153+
sess_ids.append(sess_id)
154+
155+
while True:
156+
try:
157+
timeout = max(0, self.batch_max_wait_s - (time.time() - start_time))
158+
(
159+
fut,
160+
healthy_replicas,
161+
sess_id,
162+
session_map,
163+
) = await asyncio.wait_for(self._queue.get(), timeout)
164+
batch.append((healthy_replicas, sess_id, session_map))
165+
futs.append(fut)
166+
sess_ids.append(sess_id)
167+
168+
if len(batch) >= self.batch_max_size:
169+
break
170+
except asyncio.TimeoutError:
171+
break
172+
173+
# One routing decision for the whole batch
174+
healthy_replicas = batch[-1][0] # use most recent replica state
175+
session_map = batch[-1][2] # use most recent session map
176+
177+
# Check if any replicas have become unhealthy
178+
healthy_replicas = [r for r in healthy_replicas if r.healthy]
179+
replica = await self.inner_router.get_replica(
180+
healthy_replicas, None, session_map
181+
)
182+
183+
# Fulfill all futures with the chosen replica
184+
for fut in futs:
185+
fut.set_result(replica)
186+
187+
async def get_replica(
188+
self,
189+
healthy_replicas: List[Replica],
190+
sess_id: Optional[str] = None,
191+
session_map: Optional[Dict[str, int]] = None,
192+
) -> Replica:
193+
"""Enqueue request and wait until batch assigns a replica."""
194+
loop = asyncio.get_event_loop()
195+
fut = loop.create_future()
196+
197+
# Queue the request for batching - this is non-blocking
198+
self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map))
199+
200+
# Wait for the batch processor to resolve our future
201+
return await fut

src/forge/controller/service/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,9 @@ async def _get_replica(self, sess_id: str | None) -> "Replica":
477477
healthy_replicas = [r for r in self._replicas if r.healthy]
478478
if sess_id is None:
479479
# No session, use the default router
480-
return self._default_router.get_replica(healthy_replicas)
480+
return await self._default_router.get_replica(healthy_replicas)
481481

482-
return self._session_router.get_replica(
482+
return await self._session_router.get_replica(
483483
healthy_replicas, sess_id, self._session_replica_map
484484
)
485485

tests/unit_tests/test_service.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pytest
1515
from forge.controller import ForgeActor
1616
from forge.controller.service import (
17+
BatchRouter,
1718
LeastLoadedRouter,
1819
Replica,
1920
ReplicaState,
@@ -668,7 +669,7 @@ async def test_least_loaded_router_basic():
668669
make_replica(2, load=3),
669670
]
670671
router = LeastLoadedRouter()
671-
chosen = router.get_replica(replicas)
672+
chosen = await router.get_replica(replicas)
672673
assert chosen.idx == 1 # lowest load
673674

674675

@@ -681,11 +682,11 @@ async def test_session_router_assigns_and_updates_session_map():
681682
router = SessionRouter(fallback)
682683

683684
# First request assigns via fallback
684-
r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
685+
r1 = await router.get_replica(replicas, sess_id="sess1", session_map=session_map)
685686
assert session_map["sess1"] == r1.idx
686687

687688
# Second request should stick
688-
r2 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
689+
r2 = await router.get_replica(replicas, sess_id="sess1", session_map=session_map)
689690
assert r1.idx == r2.idx
690691

691692

@@ -698,8 +699,8 @@ async def test_session_router_with_round_robin_fallback():
698699
fallback = RoundRobinRouter()
699700
router = SessionRouter(fallback)
700701

701-
r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
702-
r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map)
702+
r1 = await router.get_replica(replicas, sess_id="sess1", session_map=session_map)
703+
r2 = await router.get_replica(replicas, sess_id="sess2", session_map=session_map)
703704

704705
assert r1.idx != r2.idx
705706
assert set(session_map.values()) == {0, 1}
@@ -710,12 +711,124 @@ async def test_session_router_with_round_robin_fallback():
710711
fallback = LeastLoadedRouter()
711712
router = SessionRouter(fallback)
712713

713-
r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
714-
r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map)
714+
r1 = await router.get_replica(replicas, sess_id="sess1", session_map=session_map)
715+
r2 = await router.get_replica(replicas, sess_id="sess2", session_map=session_map)
715716

716717
assert r1.idx == r2.idx == 0
717718

718719

720+
@pytest.mark.asyncio
721+
async def test_batching_router_batchsize_with_roundrobin():
722+
"""Batch should flush when max batch size is reached using RoundRobinRouter."""
723+
replicas = [make_replica(0), make_replica(1)]
724+
batch_size = 3
725+
726+
router = BatchRouter(
727+
RoundRobinRouter(),
728+
batch_max_size=batch_size,
729+
batch_max_wait_s=0.5, # long enough to not trigger timeout
730+
)
731+
732+
# Enqueue `batch_size + 1` requests to force batch flush
733+
tasks = [
734+
asyncio.create_task(router.get_replica(replicas)) for _ in range(batch_size + 1)
735+
]
736+
results = await asyncio.gather(*tasks)
737+
738+
# Check all results are healthy replicas
739+
assert all(r.state == ReplicaState.HEALTHY for r in results)
740+
741+
# Check results only use existing replica indices
742+
indices = {r.idx for r in results}
743+
assert indices.issubset({0, 1})
744+
745+
# Ensure batch queue is empty after flush
746+
assert router._queue.qsize() == 0
747+
748+
749+
@pytest.mark.asyncio
750+
async def test_batching_router_skips_unhealthy_replicas():
751+
"""If a replica becomes unhealthy before batch dispatch, it should be skipped."""
752+
replicas = [make_replica(0, load=0), make_replica(1, load=10)]
753+
754+
router = BatchRouter(
755+
LeastLoadedRouter(),
756+
batch_max_size=4,
757+
batch_max_wait_s=0.5,
758+
)
759+
760+
# Start two requests that will form a batch
761+
tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(2)]
762+
763+
# While they are waiting, mark replica 0 (least loaded) as unhealthy
764+
await asyncio.sleep(0.01)
765+
replicas[0].state = ReplicaState.UNHEALTHY
766+
767+
results = await asyncio.gather(*tasks)
768+
769+
# All results must be the *healthy* replica (idx=1)
770+
assert all(r.idx == 1 for r in results)
771+
assert results[0].state == ReplicaState.HEALTHY
772+
773+
774+
@pytest.mark.asyncio
775+
async def test_batching_router_two_batches_timing():
776+
"""Test that two sequential batches are processed independently with proper timing."""
777+
import time
778+
779+
replicas = [make_replica(0, load=5), make_replica(1, load=10)]
780+
batch_wait_time = 0.05 # 50ms timeout
781+
782+
router = BatchRouter(
783+
LeastLoadedRouter(),
784+
batch_max_size=3,
785+
batch_max_wait_s=batch_wait_time,
786+
)
787+
788+
# First batch: 2 requests that will timeout
789+
start_time = time.time()
790+
791+
# Create first batch tasks
792+
first_batch_tasks = [
793+
asyncio.create_task(router.get_replica(replicas)) for _ in range(2)
794+
]
795+
796+
# Wait for first batch to complete (should timeout after batch_wait_time)
797+
first_results = await asyncio.gather(*first_batch_tasks)
798+
first_batch_duration = time.time() - start_time
799+
800+
# Verify first batch took approximately the timeout duration (tighter tolerance)
801+
assert (
802+
batch_wait_time <= first_batch_duration < batch_wait_time + 0.01
803+
) # 10ms tolerance on 50ms timeout
804+
805+
# Verify first batch results (should pick lowest load replica)
806+
assert all(r.idx == 0 for r in first_results) # replica 0 has lower load
807+
assert all(r.state == ReplicaState.HEALTHY for r in first_results)
808+
809+
# Second batch: 2 more requests (new timing cycle should start)
810+
second_batch_start = time.time()
811+
812+
# Create second batch tasks
813+
second_batch_tasks = [
814+
asyncio.create_task(router.get_replica(replicas)) for _ in range(2)
815+
]
816+
817+
# Wait for second batch to complete
818+
second_results = await asyncio.gather(*second_batch_tasks)
819+
second_batch_duration = time.time() - second_batch_start
820+
821+
# Verify second batch also took approximately the timeout duration (tighter tolerance)
822+
assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01
823+
824+
# Verify second batch results
825+
assert all(r.idx == 0 for r in second_results) # should still pick lowest load
826+
assert all(r.state == ReplicaState.HEALTHY for r in second_results)
827+
828+
# Ensure batch queue is empty after both batches
829+
assert router._queue.qsize() == 0
830+
831+
719832
# Router integeration tests
720833

721834

0 commit comments

Comments
 (0)