@@ -686,6 +686,7 @@ async def test_session_router_with_round_robin_fallback():
686686
687687 assert r1 .idx == r2 .idx == 0
688688
689+
689690@pytest .mark .asyncio
690691async def test_batching_router_batchsize_with_roundrobin ():
691692 """Batch should flush when max batch size is reached using RoundRobinRouter."""
@@ -698,21 +699,25 @@ async def test_batching_router_batchsize_with_roundrobin():
698699 batch_max_wait_s = 0.5 , # long enough to not trigger timeout
699700 )
700701
701- # Enqueue `batch_size + 1` requests to force batch flush
702- tasks = [
703- asyncio .create_task (router .get_replica (replicas )) for _ in range (batch_size + 1 )
704- ]
705- results = await asyncio .gather (* tasks )
702+ try :
703+ # Enqueue `batch_size + 1` requests to force batch flush
704+ tasks = [
705+ asyncio .create_task (router .get_replica (replicas ))
706+ for _ in range (batch_size + 1 )
707+ ]
708+ results = await asyncio .gather (* tasks )
706709
707- # Check all results are healthy replicas
708- assert all (r .state == ReplicaState .HEALTHY for r in results )
710+ # Check all results are healthy replicas
711+ assert all (r .state == ReplicaState .HEALTHY for r in results )
709712
710- # Check results only use existing replica indices
711- indices = {r .idx for r in results }
712- assert indices .issubset ({0 , 1 })
713+ # Check results only use existing replica indices
714+ indices = {r .idx for r in results }
715+ assert indices .issubset ({0 , 1 })
713716
714- # Ensure batch queue is empty after flush
715- assert router ._queue .qsize () == 0
717+ # Ensure batch queue is empty after flush
718+ assert router ._queue .qsize () == 0
719+ finally :
720+ router .shutdown ()
716721
717722
718723@pytest .mark .asyncio
@@ -725,19 +730,21 @@ async def test_batching_router_skips_unhealthy_replicas():
725730 batch_max_size = 4 ,
726731 batch_max_wait_s = 0.5 ,
727732 )
733+ try :
734+ # Start two requests that will form a batch
735+ tasks = [asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )]
728736
729- # Start two requests that will form a batch
730- tasks = [asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )]
731-
732- # While they are waiting, mark replica 0 (least loaded) as unhealthy
733- await asyncio .sleep (0.01 )
734- replicas [0 ].state = ReplicaState .UNHEALTHY
737+ # While they are waiting, mark replica 0 (least loaded) as unhealthy
738+ await asyncio .sleep (0.01 )
739+ replicas [0 ].state = ReplicaState .UNHEALTHY
735740
736- results = await asyncio .gather (* tasks )
741+ results = await asyncio .gather (* tasks )
737742
738- # All results must be the *healthy* replica (idx=1)
739- assert all (r .idx == 1 for r in results )
740- assert results [0 ].state == ReplicaState .HEALTHY
743+ # All results must be the *healthy* replica (idx=1)
744+ assert all (r .idx == 1 for r in results )
745+ assert results [0 ].state == ReplicaState .HEALTHY
746+ finally :
747+ router .shutdown ()
741748
742749
743750@pytest .mark .asyncio
@@ -753,49 +760,52 @@ async def test_batching_router_two_batches_timing():
753760 batch_max_size = 3 ,
754761 batch_max_wait_s = batch_wait_time ,
755762 )
763+ try :
764+ # First batch: 2 requests that will timeout
765+ start_time = time .time ()
756766
757- # First batch: 2 requests that will timeout
758- start_time = time .time ()
767+ # Create first batch tasks
768+ first_batch_tasks = [
769+ asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )
770+ ]
759771
760- # Create first batch tasks
761- first_batch_tasks = [
762- asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )
763- ]
772+ # Wait for first batch to complete (should timeout after batch_wait_time)
773+ first_results = await asyncio .gather (* first_batch_tasks )
774+ first_batch_duration = time .time () - start_time
764775
765- # Wait for first batch to complete (should timeout after batch_wait_time)
766- first_results = await asyncio .gather (* first_batch_tasks )
767- first_batch_duration = time .time () - start_time
776+ # Verify first batch took approximately the timeout duration (tighter tolerance)
777+ assert (
778+ batch_wait_time <= first_batch_duration < batch_wait_time + 0.01
779+ ) # 10ms tolerance on 50ms timeout
768780
769- # Verify first batch took approximately the timeout duration (tighter tolerance)
770- assert (
771- batch_wait_time <= first_batch_duration < batch_wait_time + 0.01
772- ) # 10ms tolerance on 50ms timeout
781+ # Verify first batch results (should pick lowest load replica)
782+ assert all (r .idx == 0 for r in first_results ) # replica 0 has lower load
783+ assert all (r .state == ReplicaState .HEALTHY for r in first_results )
773784
774- # Verify first batch results (should pick lowest load replica)
775- assert all (r .idx == 0 for r in first_results ) # replica 0 has lower load
776- assert all (r .state == ReplicaState .HEALTHY for r in first_results )
785+ # Second batch: 2 more requests (new timing cycle should start)
786+ second_batch_start = time .time ()
777787
778- # Second batch: 2 more requests (new timing cycle should start)
779- second_batch_start = time .time ()
788+ # Create second batch tasks
789+ second_batch_tasks = [
790+ asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )
791+ ]
780792
781- # Create second batch tasks
782- second_batch_tasks = [
783- asyncio .create_task (router .get_replica (replicas )) for _ in range (2 )
784- ]
793+ # Wait for second batch to complete
794+ second_results = await asyncio .gather (* second_batch_tasks )
795+ second_batch_duration = time .time () - second_batch_start
785796
786- # Wait for second batch to complete
787- second_results = await asyncio .gather (* second_batch_tasks )
788- second_batch_duration = time .time () - second_batch_start
797+ # Verify second batch also took approximately the timeout duration (tighter tolerance)
798+ assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01
789799
790- # Verify second batch also took approximately the timeout duration (tighter tolerance)
791- assert batch_wait_time <= second_batch_duration < batch_wait_time + 0.01
800+ # Verify second batch results
801+ assert all (r .idx == 0 for r in second_results ) # should still pick lowest load
802+ assert all (r .state == ReplicaState .HEALTHY for r in second_results )
792803
793- # Verify second batch results
794- assert all (r .idx == 0 for r in second_results ) # should still pick lowest load
795- assert all (r .state == ReplicaState .HEALTHY for r in second_results )
804+ # Ensure batch queue is empty after both batches
805+ assert router ._queue .qsize () == 0
806+ finally :
807+ router .shutdown ()
796808
797- # Ensure batch queue is empty after both batches
798- assert router ._queue .qsize () == 0
799809
800810# Router integeration tests
801811
@@ -856,4 +866,4 @@ async def test_session_router_assigns_and_updates_session_map_in_service():
856866 assert values2 [assigned_idx ] == values1 [assigned_idx ] + 1
857867
858868 finally :
859- await service .shutdown ()
869+ await service .shutdown ()
0 commit comments