1414import pytest
1515from forge .controller import ForgeActor
1616from forge .controller .service import (
17+ BatchRouter ,
1718 LeastLoadedRouter ,
1819 Replica ,
1920 ReplicaState ,
@@ -668,7 +669,7 @@ async def test_least_loaded_router_basic():
668669 make_replica (2 , load = 3 ),
669670 ]
670671 router = LeastLoadedRouter ()
671- chosen = router .get_replica (replicas )
672+ chosen = await router .get_replica (replicas )
672673 assert chosen .idx == 1 # lowest load
673674
674675
@@ -681,11 +682,11 @@ async def test_session_router_assigns_and_updates_session_map():
681682 router = SessionRouter (fallback )
682683
683684 # First request assigns via fallback
684- r1 = router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
685+ r1 = await router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
685686 assert session_map ["sess1" ] == r1 .idx
686687
687688 # Second request should stick
688- r2 = router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
689+ r2 = await router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
689690 assert r1 .idx == r2 .idx
690691
691692
@@ -698,8 +699,8 @@ async def test_session_router_with_round_robin_fallback():
698699 fallback = RoundRobinRouter ()
699700 router = SessionRouter (fallback )
700701
701- r1 = router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
702- r2 = router .get_replica (replicas , sess_id = "sess2" , session_map = session_map )
702+ r1 = await router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
703+ r2 = await router .get_replica (replicas , sess_id = "sess2" , session_map = session_map )
703704
704705 assert r1 .idx != r2 .idx
705706 assert set (session_map .values ()) == {0 , 1 }
@@ -710,12 +711,124 @@ async def test_session_router_with_round_robin_fallback():
710711 fallback = LeastLoadedRouter ()
711712 router = SessionRouter (fallback )
712713
713- r1 = router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
714- r2 = router .get_replica (replicas , sess_id = "sess2" , session_map = session_map )
714+ r1 = await router .get_replica (replicas , sess_id = "sess1" , session_map = session_map )
715+ r2 = await router .get_replica (replicas , sess_id = "sess2" , session_map = session_map )
715716
716717 assert r1 .idx == r2 .idx == 0
717718
718719
720+ @pytest .mark .asyncio
721+ async def test_batching_router_batchsize_with_roundrobin ():
722+ """Batch should flush when max batch size is reached using RoundRobinRouter."""
723+ replicas = [make_replica (0 ), make_replica (1 )]
724+ batch_size = 3
725+
726+ router = BatchRouter (
727+ RoundRobinRouter (),
728+ batch_max_size = batch_size ,
729+ batch_max_wait_s = 0.5 , # long enough to not trigger timeout
730+ )
731+
732+ # Enqueue `batch_size + 1` requests to force batch flush
733+ tasks = [
734+ asyncio .create_task (router .get_replica (replicas )) for _ in range (batch_size + 1 )
735+ ]
736+ results = await asyncio .gather (* tasks )
737+
738+ # Check all results are healthy replicas
739+ assert all (r .state == ReplicaState .HEALTHY for r in results )
740+
741+ # Check results only use existing replica indices
742+ indices = {r .idx for r in results }
743+ assert indices .issubset ({0 , 1 })
744+
745+ # Ensure batch queue is empty after flush
746+ assert router ._queue .qsize () == 0
747+
748+
749+ @pytest .mark .asyncio
750+ async def test_batching_router_skips_unhealthy_replicas ():
751+ """If a replica becomes unhealthy before batch dispatch, it should be skipped."""
752+ replicas = [make_replica (0 , load = 0 ), make_replica (1 , load = 10 )]
753+
754+ router = BatchRouter (
755+ LeastLoadedRouter (),
756+ batch_max_size = 4 ,
757+ batch_max_wait_s = 0.5 ,
758+ )
759+
760+ # Start two requests that will form a batch
761+ tasks = [asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )]
762+
763+ # While they are waiting, mark replica 0 (least loaded) as unhealthy
764+ await asyncio .sleep (0.01 )
765+ replicas [0 ].state = ReplicaState .UNHEALTHY
766+
767+ results = await asyncio .gather (* tasks )
768+
769+ # All results must be the *healthy* replica (idx=1)
770+ assert all (r .idx == 1 for r in results )
771+ assert results [0 ].state == ReplicaState .HEALTHY
772+
773+
774+ @pytest .mark .asyncio
775+ async def test_batching_router_two_batches_timing ():
776+ """Test that two sequential batches are processed independently with proper timing."""
777+ import time
778+
779+ replicas = [make_replica (0 , load = 5 ), make_replica (1 , load = 10 )]
780+ batch_wait_time = 0.05 # 50ms timeout
781+
782+ router = BatchRouter (
783+ LeastLoadedRouter (),
784+ batch_max_size = 3 ,
785+ batch_max_wait_s = batch_wait_time ,
786+ )
787+
788+ # First batch: 2 requests that will timeout
789+ start_time = time .time ()
790+
791+ # Create first batch tasks
792+ first_batch_tasks = [
793+ asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )
794+ ]
795+
796+ # Wait for first batch to complete (should timeout after batch_wait_time)
797+ first_results = await asyncio .gather (* first_batch_tasks )
798+ first_batch_duration = time .time () - start_time
799+
800+ # Verify first batch took approximately the timeout duration (tighter tolerance)
801+ assert (
802+ batch_wait_time <= first_batch_duration < batch_wait_time + 0.01
803+ ) # 10ms tolerance on 50ms timeout
804+
805+ # Verify first batch results (should pick lowest load replica)
806+ assert all (r .idx == 0 for r in first_results ) # replica 0 has lower load
807+ assert all (r .state == ReplicaState .HEALTHY for r in first_results )
808+
809+ # Second batch: 2 more requests (new timing cycle should start)
810+ second_batch_start = time .time ()
811+
812+ # Create second batch tasks
813+ second_batch_tasks = [
814+ asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )
815+ ]
816+
817+ # Wait for second batch to complete
818+ second_results = await asyncio .gather (* second_batch_tasks )
819+ second_batch_duration = time .time () - second_batch_start
820+
821+ # Verify second batch also took approximately the timeout duration (tighter tolerance)
822+ assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01
823+
824+ # Verify second batch results
825+ assert all (r .idx == 0 for r in second_results ) # should still pick lowest load
826+ assert all (r .state == ReplicaState .HEALTHY for r in second_results )
827+
828+ # Ensure batch queue is empty after both batches
829+ assert router ._queue .qsize () == 0
830+
831+
719832# Router integeration tests
720833
721834
0 commit comments