44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import asyncio
78import logging
8- from typing import Dict , List
9+ import time
10+ from typing import Dict , List , Optional
911
1012from .interface import Router
1113from .replica import Replica
@@ -20,7 +22,7 @@ class RoundRobinRouter(Router):
2022 def __init__ (self ):
2123 self ._next_idx = 0
2224
23- def get_replica (
25+ async def get_replica (
2426 self ,
2527 healthy_replicas : List [Replica ],
2628 sess_id : str | None = None ,
@@ -38,7 +40,7 @@ def get_replica(
3840class LeastLoadedRouter (Router ):
3941 """Always routes to the replica with the lowest current load."""
4042
41- def get_replica (
43+ async def get_replica (
4244 self ,
4345 healthy_replicas : List [Replica ],
4446 sess_id : str | None = None ,
@@ -55,7 +57,7 @@ class SessionRouter(Router):
5557 def __init__ (self , fallback_router : Router ):
5658 self .fallback_router = fallback_router
5759
58- def get_replica (
60+ async def get_replica (
5961 self ,
6062 healthy_replicas : List [Replica ],
6163 sess_id : str | None = None ,
@@ -78,7 +80,7 @@ def get_replica(
7880 del session_map [sess_id ]
7981
8082 # Use fallback router to assign a new replica
81- replica = self .fallback_router .get_replica (
83+ replica = await self .fallback_router .get_replica (
8284 healthy_replicas , sess_id , session_map
8385 )
8486 session_map [sess_id ] = replica .idx
@@ -88,3 +90,111 @@ def get_replica(
8890 replica .idx ,
8991 )
9092 return replica
93+
class BatchRouter(Router):
    """
    Router wrapper that batches routing decisions.
    Uses an inner router to pick the replica for each batch.

    Requests passed to ``get_replica`` are queued; a background task groups
    them into batches and asks ``inner_router`` for ONE replica per batch.
    Every request in the batch receives that same replica. The inner router
    is called with the ``healthy_replicas`` and ``session_map`` of the most
    recent request in the batch and with ``sess_id=None``.

    The background task is created in ``__init__`` via
    ``asyncio.create_task``, so instances must be constructed while an event
    loop is running.

    Args:
        inner_router: The underlying Router instance used to make routing decisions
        batch_max_size: Maximum number of requests to collect in a single batch (default: 8)
        batch_max_wait_s: Maximum time to wait, measured from the arrival of the
            first request of a batch, before processing it, in seconds (default: 0.01)

    Example:
        rr_router = RoundRobinRouter()
        batch_router = BatchRouter(rr_router, batch_max_size=16, batch_max_wait_s=0.02)

        replica = await batch_router.get_replica(healthy_replicas, sess_id, session_map)
    """

    def __init__(
        self,
        inner_router: Router,
        batch_max_size: int = 8,
        batch_max_wait_s: float = 0.01,
    ):
        self.inner_router = inner_router
        self.batch_max_size = batch_max_size
        self.batch_max_wait_s = batch_max_wait_s

        # Internal queue for batching routing requests. Each item is a tuple
        # (future, healthy_replicas, sess_id, session_map).
        self._queue: asyncio.Queue = asyncio.Queue()
        # Background task that processes batches continuously
        self._batch_task: asyncio.Task = asyncio.create_task(self._batch_loop())

    async def _batch_loop(self):
        """Background task that continuously processes batches of routing requests.

        The loop follows these steps:
        1. Wait for the first request to start a new batch
        2. Collect additional requests until batch_max_size or batch_max_wait_s is reached
        3. Make a single routing decision for the entire batch
        4. Fulfill all futures with the selected replica (or the routing error)

        This process repeats indefinitely until the task is cancelled.
        """
        loop = asyncio.get_running_loop()
        while True:
            # Block (possibly for a long time) until a request opens a batch.
            batch = [await self._queue.get()]

            # Start the wait window only once the first request has arrived;
            # otherwise idle time would eat the whole budget and the batch
            # would never collect more than what happens to be queued.
            # loop.time() is monotonic, so wall-clock jumps cannot distort it.
            deadline = loop.time() + self.batch_max_wait_s

            # Check the size limit BEFORE waiting for another request so the
            # batch never exceeds batch_max_size (including batch_max_size=1).
            while len(batch) < self.batch_max_size:
                # Requests that are already queued are taken without waiting.
                try:
                    batch.append(self._queue.get_nowait())
                    continue
                except asyncio.QueueEmpty:
                    pass

                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(
                        await asyncio.wait_for(self._queue.get(), remaining)
                    )
                except asyncio.TimeoutError:
                    break

            await self._resolve_batch(batch)

    async def _resolve_batch(self, batch: List[tuple]) -> None:
        """Make one routing decision for ``batch`` and complete its futures."""
        # One routing decision for the whole batch, based on the most recent
        # request's view of replica state and session map.
        _, healthy_replicas, _, session_map = batch[-1]

        try:
            # Check if any replicas have become unhealthy
            healthy_replicas = [r for r in healthy_replicas if r.healthy]
            replica = await self.inner_router.get_replica(
                healthy_replicas, None, session_map
            )
        except Exception as exc:
            # Hand the failure to the waiting callers instead of letting it
            # kill the background task, which would leave them pending forever.
            for fut, *_ in batch:
                if not fut.done():
                    fut.set_exception(exc)
            return

        # Fulfill all futures with the chosen replica. Futures whose waiter
        # was cancelled are already done; setting a result on them would
        # raise InvalidStateError, so skip them.
        for fut, *_ in batch:
            if not fut.done():
                fut.set_result(replica)

    async def get_replica(
        self,
        healthy_replicas: List[Replica],
        sess_id: Optional[str] = None,
        session_map: Optional[Dict[str, int]] = None,
    ) -> Replica:
        """Enqueue request and wait until batch assigns a replica.

        Returns:
            The replica chosen by the inner router for the batch this
            request was placed in.

        Raises:
            Exception: whatever the inner router raised while routing the
                batch this request was placed in.
        """
        fut = asyncio.get_running_loop().create_future()

        # Queue the request for batching - this is non-blocking
        self._queue.put_nowait((fut, healthy_replicas, sess_id, session_map))

        # Wait for the batch processor to resolve our future
        return await fut
0 commit comments