Skip to content

Commit 1503fc5

Browse files
committed
Merge branch 'main' into batch_router
2 parents e3aabbd + 30945ff commit 1503fc5

File tree

5 files changed

+236
-14
lines changed

5 files changed

+236
-14
lines changed

src/forge/controller/service/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .interface import ServiceInterface, Session, SessionContext
88
from .metrics import ServiceMetrics
99
from .replica import Replica, ReplicaMetrics, ReplicaState
10-
from .router import LeastLoadedRouter, RoundRobinRouter, SessionRouter
10+
from .router import BatchRouter, LeastLoadedRouter, RoundRobinRouter, SessionRouter
1111
from .service import Service, ServiceActor, ServiceConfig
1212

1313
__all__ = [
@@ -24,4 +24,5 @@
2424
"LeastLoadedRouter",
2525
"RoundRobinRouter",
2626
"SessionRouter",
27+
"BatchRouter",
2728
]

src/forge/controller/service/interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ class Router(ABC):
286286
"""Abstract base class for routing logic."""
287287

288288
@abstractmethod
289-
def get_replica(
289+
async def get_replica(
290290
self,
291291
healthy_replicas: List[Replica],
292292
sess_id: str | None = None,

src/forge/controller/service/router.py

Lines changed: 115 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import asyncio
78
import logging
8-
from typing import Dict, List
9+
import time
10+
from typing import Dict, List, Optional
911

1012
from .interface import Router
1113
from .replica import Replica
@@ -20,7 +22,7 @@ class RoundRobinRouter(Router):
2022
def __init__(self):
2123
self._next_idx = 0
2224

23-
def get_replica(
25+
async def get_replica(
2426
self,
2527
healthy_replicas: List[Replica],
2628
sess_id: str | None = None,
@@ -38,7 +40,7 @@ def get_replica(
3840
class LeastLoadedRouter(Router):
3941
"""Always routes to the replica with the lowest current load."""
4042

41-
def get_replica(
43+
async def get_replica(
4244
self,
4345
healthy_replicas: List[Replica],
4446
sess_id: str | None = None,
@@ -55,7 +57,7 @@ class SessionRouter(Router):
5557
def __init__(self, fallback_router: Router):
5658
self.fallback_router = fallback_router
5759

58-
def get_replica(
60+
async def get_replica(
5961
self,
6062
healthy_replicas: List[Replica],
6163
sess_id: str | None = None,
@@ -78,7 +80,7 @@ def get_replica(
7880
del session_map[sess_id]
7981

8082
# Use fallback router to assign a new replica
81-
replica = self.fallback_router.get_replica(
83+
replica = await self.fallback_router.get_replica(
8284
healthy_replicas, sess_id, session_map
8385
)
8486
session_map[sess_id] = replica.idx
@@ -88,3 +90,111 @@ def get_replica(
8890
replica.idx,
8991
)
9092
return replica
93+
94+
class BatchRouter(Router):
    """
    Router wrapper that batches routing decisions.
    Uses an inner router to pick the replica for each batch.

    Requests are queued by ``get_replica`` and drained by a background task.
    A batch is closed as soon as it holds ``batch_max_size`` requests or
    ``batch_max_wait_s`` seconds have passed since its first request was
    dequeued, whichever comes first. One routing decision is then made with
    the inner router and shared by every request in the batch.

    The constructor calls ``asyncio.create_task``, so the router must be
    created while an event loop is running.

    Args:
        inner_router: The underlying Router instance used to make routing decisions
        batch_max_size: Maximum number of requests to collect in a single batch (default: 8)
        batch_max_wait_s: Maximum time to wait before processing a batch in seconds (default: 0.01)

    Example:
        rr_router = RoundRobinRouter()
        batch_router = BatchRouter(rr_router, batch_max_size=16, batch_max_wait_s=0.02)

        replica = await batch_router.get_replica(healthy_replicas, sess_id, session_map)
    """

    def __init__(
        self,
        inner_router: Router,
        batch_max_size: int = 8,
        batch_max_wait_s: float = 0.01,
    ):
        self.inner_router = inner_router
        self.batch_max_size = batch_max_size
        self.batch_max_wait_s = batch_max_wait_s

        # Internal queue for batching routing requests
        self._queue: asyncio.Queue = asyncio.Queue()
        # Background task that processes batches continuously
        self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop())

    async def _collect_batch(self):
        """Wait for one request, then gather more until the batch is full or expires.

        Returns:
            A tuple ``(futs, healthy_replicas, session_map)`` where ``futs`` holds
            the future of every request in the batch, and the other two values
            come from the most recently dequeued request.
        """
        fut, healthy_replicas, _sess_id, session_map = await self._queue.get()
        futs = [fut]

        # The wait window starts only once the first request has arrived, so an
        # idle router does not burn the window before there is anything to batch.
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        deadline = time.monotonic() + self.batch_max_wait_s

        # Checking the size *before* waiting keeps the batch from ever exceeding
        # batch_max_size (including batch_max_size == 1).
        while len(futs) < self.batch_max_size:
            remaining = max(0.0, deadline - time.monotonic())
            try:
                # On timeout the assignment is skipped, so the previous
                # (most recent) replica list and session map are retained.
                fut, healthy_replicas, _sess_id, session_map = await asyncio.wait_for(
                    self._queue.get(), remaining
                )
            except asyncio.TimeoutError:
                break
            futs.append(fut)

        return futs, healthy_replicas, session_map

    async def _batch_loop(self):
        """Background task that continuously processes batches of routing requests.

        The loop follows these steps:
        1. Wait for the first request to start a new batch
        2. Collect additional requests until batch_max_size or batch_max_wait_s is reached
        3. Make a single routing decision for the entire batch
        4. Fulfill all futures with the selected replica (or with the routing error)

        This process repeats indefinitely until the task is cancelled.
        """
        while True:
            futs, healthy_replicas, session_map = await self._collect_batch()

            try:
                # Re-check health: a replica may have become unhealthy while
                # the batch was being collected.
                candidates = [r for r in healthy_replicas if r.healthy]
                # One routing decision for the whole batch; no session id is
                # passed because the decision is shared by all requests.
                replica = await self.inner_router.get_replica(
                    candidates, None, session_map
                )
            except Exception as exc:
                # Hand the failure to the waiting callers instead of letting it
                # kill this task, which would leave every caller hanging.
                for fut in futs:
                    if not fut.done():
                        fut.set_exception(exc)
                continue

            # Fulfill all futures with the chosen replica. Futures whose caller
            # was cancelled are already done and must be skipped.
            for fut in futs:
                if not fut.done():
                    fut.set_result(replica)

    async def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: Optional[str] = None,
        session_map: Optional[Dict[str, int]] = None,
    ) -> Replica:
        """Enqueue request and wait until batch assigns a replica.

        Args:
            healthy_replicas: Candidate replicas for this request.
            sess_id: Session id of the request; queued with the request but not
                forwarded to the inner router.
            session_map: Session-to-replica-index map forwarded to the inner router.

        Returns:
            The replica chosen for the batch this request ended up in.

        Raises:
            Exception: Whatever the inner router raised while routing the batch.
        """
        fut = asyncio.get_running_loop().create_future()

        # Queue the request for batching - this is non-blocking
        self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map))

        # Wait for the batch processor to resolve our future
        return await fut

src/forge/controller/service/service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,9 @@ async def _get_replica(self, sess_id: str | None) -> "Replica":
477477
healthy_replicas = [r for r in self._replicas if r.healthy]
478478
if sess_id is None:
479479
# No session, use the default router
480-
return self._default_router.get_replica(healthy_replicas)
480+
return await self._default_router.get_replica(healthy_replicas)
481481

482-
return self._session_router.get_replica(
482+
return await self._session_router.get_replica(
483483
healthy_replicas, sess_id, self._session_replica_map
484484
)
485485

tests/unit_tests/test_service.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pytest
1515
from forge.controller import ForgeActor
1616
from forge.controller.service import (
17+
BatchRouter,
1718
LeastLoadedRouter,
1819
Replica,
1920
ReplicaState,
@@ -666,8 +667,8 @@ async def test_session_router_with_round_robin_fallback():
666667
fallback = RoundRobinRouter()
667668
router = SessionRouter(fallback)
668669

669-
r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
670-
r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map)
670+
r1 = await router.get_replica(replicas, sess_id="sess1", session_map=session_map)
671+
r2 = await router.get_replica(replicas, sess_id="sess2", session_map=session_map)
671672

672673
assert r1.idx != r2.idx
673674
assert set(session_map.values()) == {0, 1}
@@ -678,11 +679,121 @@ async def test_session_router_with_round_robin_fallback():
678679
fallback = LeastLoadedRouter()
679680
router = SessionRouter(fallback)
680681

681-
r1 = router.get_replica(replicas, sess_id="sess1", session_map=session_map)
682-
r2 = router.get_replica(replicas, sess_id="sess2", session_map=session_map)
682+
r1 = await router.get_replica(replicas, sess_id="sess1", session_map=session_map)
683+
r2 = await router.get_replica(replicas, sess_id="sess2", session_map=session_map)
683684

684685
assert r1.idx == r2.idx == 0
685686

687+
@pytest.mark.asyncio
async def test_batching_router_batchsize_with_roundrobin():
    """Batch should flush when max batch size is reached using RoundRobinRouter."""
    max_size = 3
    replicas = [make_replica(0), make_replica(1)]

    # The wait is long enough that the size limit, not the timeout,
    # is what flushes the first batch.
    router = BatchRouter(
        RoundRobinRouter(),
        batch_max_size=max_size,
        batch_max_wait_s=0.5,
    )

    # One request more than a full batch, so a flush must occur
    pending = []
    for _ in range(max_size + 1):
        pending.append(asyncio.create_task(router.get_replica(replicas)))
    chosen = await asyncio.gather(*pending)

    # Every request was answered with a healthy replica
    for replica in chosen:
        assert replica.state == ReplicaState.HEALTHY

    # Only the replicas that exist were handed out
    assert {replica.idx for replica in chosen} <= {0, 1}

    # Nothing is left waiting in the router's queue
    assert router._queue.qsize() == 0
714+
715+
716+
@pytest.mark.asyncio
async def test_batching_router_skips_unhealthy_replicas():
    """If a replica becomes unhealthy before batch dispatch, it should be skipped."""
    least_loaded = make_replica(0, load=0)
    busier = make_replica(1, load=10)
    replicas = [least_loaded, busier]

    router = BatchRouter(
        LeastLoadedRouter(),
        batch_max_size=4,
        batch_max_wait_s=0.5,
    )

    # Two requests are fewer than batch_max_size, so the batch stays open
    first = asyncio.create_task(router.get_replica(replicas))
    second = asyncio.create_task(router.get_replica(replicas))

    # While the batch is still collecting, take the least loaded replica out
    await asyncio.sleep(0.01)
    replicas[0].state = ReplicaState.UNHEALTHY

    results = await asyncio.gather(first, second)

    # Both requests must land on the replica that stayed healthy (idx=1)
    for replica in results:
        assert replica.idx == 1
    assert results[0].state == ReplicaState.HEALTHY
739+
740+
741+
@pytest.mark.asyncio
async def test_batching_router_two_batches_timing():
    """Test that two sequential batches are processed independently with proper timing.

    Each batch holds fewer requests than batch_max_size, so it can only be
    flushed by the batch timeout. Both batches are therefore expected to take
    roughly batch_wait_time, which shows the timer restarts for the second one.
    """
    replicas = [make_replica(0, load=5), make_replica(1, load=10)]
    batch_wait_time = 0.05  # 50ms timeout
    tolerance = 0.01  # 10ms tolerance on 50ms timeout

    router = BatchRouter(
        LeastLoadedRouter(),
        batch_max_size=3,
        batch_max_wait_s=batch_wait_time,
    )

    # Durations are measured on the event loop's own monotonic clock, the same
    # clock asyncio uses to schedule timeouts, rather than the wall clock,
    # which may be adjusted while the test is running.
    loop = asyncio.get_running_loop()

    async def run_partial_batch():
        """Send 2 requests (below batch_max_size) and time how long they take."""
        started = loop.time()
        tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(2)]
        results = await asyncio.gather(*tasks)
        return results, loop.time() - started

    for label in ("first", "second"):
        results, duration = await run_partial_batch()

        # The batch should flush on timeout: not earlier, and not much later
        assert (
            batch_wait_time <= duration < batch_wait_time + tolerance
        ), f"{label} batch took {duration:.4f}s"

        # LeastLoadedRouter should pick replica 0, which has the lower load
        assert all(r.idx == 0 for r in results)
        assert all(r.state == ReplicaState.HEALTHY for r in results)

    # Ensure batch queue is empty after both batches
    assert router._queue.qsize() == 0
686797

687798
# Router integration tests
688799

@@ -743,4 +854,4 @@ async def test_session_router_assigns_and_updates_session_map_in_service():
743854
assert values2[assigned_idx] == values1[assigned_idx] + 1
744855

745856
finally:
746-
await service.shutdown()
857+
await service.shutdown()

0 commit comments

Comments
 (0)