Skip to content

Commit c783a80

Browse files
committed
resolve comments
1 parent 432af71 commit c783a80

File tree

2 files changed

+86
-63
lines changed

2 files changed

+86
-63
lines changed

src/forge/controller/service/router.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ async def get_replica(
9191
)
9292
return replica
9393

94+
9495
class BatchRouter(Router):
9596
"""
9697
Router wrapper that batches routing decisions.
@@ -121,6 +122,7 @@ def __init__(
121122

122123
# Internal queue for batching routing requests
123124
self._queue: asyncio.Queue = asyncio.Queue()
125+
self._running = True # flag to control loop
124126
# Background task that processes batches continuously
125127
self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop())
126128

@@ -139,27 +141,31 @@ async def _batch_loop(self):
139141
140142
This process repeats indefinitely until the task is cancelled.
141143
"""
142-
while True:
144+
while self._running:
143145
batch = []
144146
futs = []
145147
sess_ids = []
146-
start_time = time.time()
147148

148149
# Wait for first request
149150
fut, healthy_replicas, sess_id, session_map = await self._queue.get()
150151
batch.append((healthy_replicas, sess_id, session_map))
151152
futs.append(fut)
152153
sess_ids.append(sess_id)
154+
start_time = time.monotonic()
153155

154156
while True:
155157
try:
156-
timeout = max(0, self.batch_max_wait_s - (time.time() - start_time))
158+
timeout = max(
159+
0, self.batch_max_wait_s - (time.monotonic() - start_time)
160+
)
157161
(
158162
fut,
159163
healthy_replicas,
160164
sess_id,
161165
session_map,
162-
) = await asyncio.wait_for(self._queue.get(), timeout)
166+
) = await asyncio.wait_for(
167+
self._queue.get(), timeout
168+
) # wait for timeout or until self._queue.get() finishes
163169
batch.append((healthy_replicas, sess_id, session_map))
164170
futs.append(fut)
165171
sess_ids.append(sess_id)
@@ -190,11 +196,18 @@ async def get_replica(
190196
session_map: Optional[Dict[str, int]] = None,
191197
) -> Replica:
192198
"""Enqueue request and wait until batch assigns a replica."""
193-
loop = asyncio.get_event_loop()
194-
fut = loop.create_future()
195-
199+
fut = asyncio.Future()
196200
# Queue the request for batching - this is non-blocking
197201
self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map))
198202

199203
# Wait for the batch processor to resolve our future
200-
return await fut
204+
return await fut
205+
206+
async def shutdown(self):
207+
"""Stop the batch loop gracefully."""
208+
self._running = False
209+
self._batch_task.cancel()
210+
try:
211+
await self._batch_task
212+
except asyncio.CancelledError:
213+
pass

tests/unit_tests/test_service.py

Lines changed: 65 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ async def test_session_router_with_round_robin_fallback():
686686

687687
assert r1.idx == r2.idx == 0
688688

689+
689690
@pytest.mark.asyncio
690691
async def test_batching_router_batchsize_with_roundrobin():
691692
"""Batch should flush when max batch size is reached using RoundRobinRouter."""
@@ -698,21 +699,25 @@ async def test_batching_router_batchsize_with_roundrobin():
698699
batch_max_wait_s=0.5, # long enough to not trigger timeout
699700
)
700701

701-
# Enqueue `batch_size + 1` requests to force batch flush
702-
tasks = [
703-
asyncio.create_task(router.get_replica(replicas)) for _ in range(batch_size + 1)
704-
]
705-
results = await asyncio.gather(*tasks)
702+
try:
703+
# Enqueue `batch_size + 1` requests to force batch flush
704+
tasks = [
705+
asyncio.create_task(router.get_replica(replicas))
706+
for _ in range(batch_size + 1)
707+
]
708+
results = await asyncio.gather(*tasks)
706709

707-
# Check all results are healthy replicas
708-
assert all(r.state == ReplicaState.HEALTHY for r in results)
710+
# Check all results are healthy replicas
711+
assert all(r.state == ReplicaState.HEALTHY for r in results)
709712

710-
# Check results only use existing replica indices
711-
indices = {r.idx for r in results}
712-
assert indices.issubset({0, 1})
713+
# Check results only use existing replica indices
714+
indices = {r.idx for r in results}
715+
assert indices.issubset({0, 1})
713716

714-
# Ensure batch queue is empty after flush
715-
assert router._queue.qsize() == 0
717+
# Ensure batch queue is empty after flush
718+
assert router._queue.qsize() == 0
719+
finally:
720+
router.shutdown()
716721

717722

718723
@pytest.mark.asyncio
@@ -725,19 +730,21 @@ async def test_batching_router_skips_unhealthy_replicas():
725730
batch_max_size=4,
726731
batch_max_wait_s=0.5,
727732
)
733+
try:
734+
# Start two requests that will form a batch
735+
tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(2)]
728736

729-
# Start two requests that will form a batch
730-
tasks = [asyncio.create_task(router.get_replica(replicas)) for _ in range(2)]
731-
732-
# While they are waiting, mark replica 0 (least loaded) as unhealthy
733-
await asyncio.sleep(0.01)
734-
replicas[0].state = ReplicaState.UNHEALTHY
737+
# While they are waiting, mark replica 0 (least loaded) as unhealthy
738+
await asyncio.sleep(0.01)
739+
replicas[0].state = ReplicaState.UNHEALTHY
735740

736-
results = await asyncio.gather(*tasks)
741+
results = await asyncio.gather(*tasks)
737742

738-
# All results must be the *healthy* replica (idx=1)
739-
assert all(r.idx == 1 for r in results)
740-
assert results[0].state == ReplicaState.HEALTHY
743+
# All results must be the *healthy* replica (idx=1)
744+
assert all(r.idx == 1 for r in results)
745+
assert results[0].state == ReplicaState.HEALTHY
746+
finally:
747+
router.shutdown()
741748

742749

743750
@pytest.mark.asyncio
@@ -753,49 +760,52 @@ async def test_batching_router_two_batches_timing():
753760
batch_max_size=3,
754761
batch_max_wait_s=batch_wait_time,
755762
)
763+
try:
764+
# First batch: 2 requests that will timeout
765+
start_time = time.time()
756766

757-
# First batch: 2 requests that will timeout
758-
start_time = time.time()
767+
# Create first batch tasks
768+
first_batch_tasks = [
769+
asyncio.create_task(router.get_replica(replicas)) for _ in range(2)
770+
]
759771

760-
# Create first batch tasks
761-
first_batch_tasks = [
762-
asyncio.create_task(router.get_replica(replicas)) for _ in range(2)
763-
]
772+
# Wait for first batch to complete (should timeout after batch_wait_time)
773+
first_results = await asyncio.gather(*first_batch_tasks)
774+
first_batch_duration = time.time() - start_time
764775

765-
# Wait for first batch to complete (should timeout after batch_wait_time)
766-
first_results = await asyncio.gather(*first_batch_tasks)
767-
first_batch_duration = time.time() - start_time
776+
# Verify first batch took approximately the timeout duration (tighter tolerance)
777+
assert (
778+
batch_wait_time <= first_batch_duration < batch_wait_time + 0.01
779+
) # 10ms tolerance on 50ms timeout
768780

769-
# Verify first batch took approximately the timeout duration (tighter tolerance)
770-
assert (
771-
batch_wait_time <= first_batch_duration < batch_wait_time + 0.01
772-
) # 10ms tolerance on 50ms timeout
781+
# Verify first batch results (should pick lowest load replica)
782+
assert all(r.idx == 0 for r in first_results) # replica 0 has lower load
783+
assert all(r.state == ReplicaState.HEALTHY for r in first_results)
773784

774-
# Verify first batch results (should pick lowest load replica)
775-
assert all(r.idx == 0 for r in first_results) # replica 0 has lower load
776-
assert all(r.state == ReplicaState.HEALTHY for r in first_results)
785+
# Second batch: 2 more requests (new timing cycle should start)
786+
second_batch_start = time.time()
777787

778-
# Second batch: 2 more requests (new timing cycle should start)
779-
second_batch_start = time.time()
788+
# Create second batch tasks
789+
second_batch_tasks = [
790+
asyncio.create_task(router.get_replica(replicas)) for _ in range(2)
791+
]
780792

781-
# Create second batch tasks
782-
second_batch_tasks = [
783-
asyncio.create_task(router.get_replica(replicas)) for _ in range(2)
784-
]
793+
# Wait for second batch to complete
794+
second_results = await asyncio.gather(*second_batch_tasks)
795+
second_batch_duration = time.time() - second_batch_start
785796

786-
# Wait for second batch to complete
787-
second_results = await asyncio.gather(*second_batch_tasks)
788-
second_batch_duration = time.time() - second_batch_start
797+
# Verify second batch also took approximately the timeout duration (tighter tolerance)
798+
assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01
789799

790-
# Verify second batch also took approximately the timeout duration (tighter tolerance)
791-
assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01
800+
# Verify second batch results
801+
assert all(r.idx == 0 for r in second_results) # should still pick lowest load
802+
assert all(r.state == ReplicaState.HEALTHY for r in second_results)
792803

793-
# Verify second batch results
794-
assert all(r.idx == 0 for r in second_results) # should still pick lowest load
795-
assert all(r.state == ReplicaState.HEALTHY for r in second_results)
804+
# Ensure batch queue is empty after both batches
805+
assert router._queue.qsize() == 0
806+
finally:
807+
router.shutdown()
796808

797-
# Ensure batch queue is empty after both batches
798-
assert router._queue.qsize() == 0
799809

800810
# Router integeration tests
801811

@@ -856,4 +866,4 @@ async def test_session_router_assigns_and_updates_session_map_in_service():
856866
assert values2[assigned_idx] == values1[assigned_idx] + 1
857867

858868
finally:
859-
await service.shutdown()
869+
await service.shutdown()

0 commit comments

Comments
 (0)