Add Batch routing support via @service_endpoint with configurable batch size and timeout #304
Conversation
Instead of selecting a replica immediately, incoming requests are enqueued
and grouped into batches. Once a batch is ready (either reaching the maximum
size or exceeding the maximum wait time), the batcher makes a single routing
Is there a way to base the wait time on the status of the replica instead of a fixed time? For example, if the replica is still busy, we can let the batch grow larger, but if the replica is free for some minimum time interval, then we can send the batch.
Definitely. We could make the batch timeout adaptive based on replica load (e.g., wait longer when replicas are busy and flush earlier when they’re idle). I’d prefer to land this current version first, then explore that as a follow-up improvement once the base batching logic is stable. Just added a TODO in the while loop.
# TODO: make timeout adaptive based on replica load.
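For reference, a minimal sketch of what an adaptive flush condition along these lines could look like. Everything here is illustrative, not this PR's code: replica_is_idle is a hypothetical probe, and the thresholds are made up.

import asyncio
import time

async def collect_batch(queue, replica_is_idle, max_batch_size=16,
                        max_wait=0.05, idle_flush_after=0.005):
    # Hypothetical sketch: grow the batch while replicas are busy,
    # flush early once a replica has been idle for idle_flush_after
    # seconds, and cap total wait at max_wait either way.
    batch = []
    start = time.monotonic()
    idle_since = None
    while len(batch) < max_batch_size:
        if time.monotonic() - start >= max_wait:
            break  # hard latency cap regardless of replica state
        if replica_is_idle():
            idle_since = idle_since or time.monotonic()
            if batch and time.monotonic() - idle_since >= idle_flush_after:
                break  # a replica has been free long enough: flush now
        else:
            idle_since = None  # replica busy again, keep accumulating
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=0.001))
        except asyncio.TimeoutError:
            pass
    return batch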
session_id=None,
function=self.function,
args=args,
kwargs={},
@allenwang28 Do we want to support kwargs here?
results = [results] * len(batch)
else:
    # scalar result, broadcast to batch size
    results = [results] * len(batch)
@allenwang28 Do we want to handle the case where the returned results have a different length than the batch, or are a scalar?
I think both the batching logic and the unbatching logic should be factored out.
The batcher should accept two functions:
batch_fn :: [Request] => BatchRequest | Request
unbatch_fn :: BatchResponse | Response => [Response]
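Concretely, the proposed split could look something like this minimal sketch (all names are illustrative, not an API from this PR):

from typing import Any, List, Sequence

Request = Any
Response = Any

def batch_fn(requests: List[Request]) -> Request:
    # Illustrative default: the batched request is just the list itself.
    return requests

def unbatch_fn(result: Response, n: int) -> List[Response]:
    # Split a batched response back into one response per request:
    # validate list lengths, broadcast scalars.
    if isinstance(result, Sequence) and not isinstance(result, (str, bytes)):
        if len(result) != n:
            raise ValueError(f"expected {n} results, got {len(result)}")
        return list(result)
    return [result] * n  # scalar result, broadcast to batch size

With that factoring, the batcher only owns queueing and routing, and the length/scalar handling lives in one place.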
self.service = service
self.endpoint_name = endpoint_name

async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
Do we want to schedule the task to start immediately by making this return a future? In particular, we can fire and forget without awaiting it - this is actually what happens with the Monarch native actor API, IIUC.
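As an illustration, returning an eagerly-scheduled task would look roughly like this (EndpointHandle and _route_impl are stand-ins, not this PR's actual signatures):

import asyncio

class EndpointHandle:
    async def _route_impl(self, *args, **kwargs):
        ...  # actual routing logic would live here

    def route(self, *args, **kwargs) -> asyncio.Task:
        # The task starts running on the next event-loop iteration,
        # whether or not the caller ever awaits it.
        return asyncio.create_task(self._route_impl(*args, **kwargs))

# handle.route(x)                 -> fire and forget
# result = await handle.route(x)  -> or await the result when needed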
"Services only support route() and fanout()." | ||
) | ||
|
||
async def generate(self, *args: P.args, **kwargs: P.kwargs): |
actor also doesn't have generate?
def __init__(
    self,
    function: str,
    inner_router: Router,
    get_healthy_replicas: Callable[[], List["Replica"]],
    get_session_map: Callable[[], Dict[str, int]],
    batch_size: int = 16,
    batch_timeout: float = 0.01,
):
Suggested change (insert * to make the remaining parameters keyword-only):
def __init__(
    self,
    function: str,
    inner_router: Router,
    *,
    get_healthy_replicas: Callable[[], List["Replica"]],
    get_session_map: Callable[[], Dict[str, int]],
    batch_size: int = 16,
    batch_timeout: float = 0.01,
):
def get_replica(
    self,
    healthy_replicas: List[Replica],
    sess_id: str | None = None,
    session_map: Dict[str, int] | None = None,
Suggested change (insert * to make sess_id and session_map keyword-only):
def get_replica(
    self,
    healthy_replicas: List[Replica],
    *,
    sess_id: str | None = None,
    session_map: Dict[str, int] | None = None,
"""Add (args, kwargs) pair to queue, return a Future resolved when batch completes.""" | ||
# Queue the request for batching | ||
fut = asyncio.Future() | ||
self._queue.put_nowait((function, args, kwargs, fut)) |
Why put_nowait? If there is no capacity limit on the queue, there isn't much of a difference; if there is, this will throw an error?
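For context, the distinction being raised, assuming a standard asyncio.Queue: on an unbounded queue the two calls are equivalent, but on a bounded queue put_nowait raises asyncio.QueueFull where an awaited put() would apply backpressure. The fallback below is just one possible policy:

import asyncio

async def enqueue(q: asyncio.Queue, item) -> None:
    try:
        q.put_nowait(item)  # raises asyncio.QueueFull on a bounded, full queue
    except asyncio.QueueFull:
        await q.put(item)   # awaiting instead waits for a free slot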
Args:
    inner_router: The underlying Router used to pick a replica.
    get_healthy_replicas: Callable that returns the current list of healthy replicas.
    get_session_map: Callable that returns the session-to-replica mapping.
@allenwang28 this is a question for you: why doesn't the SessionRouter own the session_map in the first place, instead of us having to pass it down layer by layer?
# One routing decision for the whole batch
replica = self.inner_router.get_replica(
    healthy_replicas, None, session_map
)
Since you only have one (batched) request in the first place, why do you need to keep track of the replica?
self._session_router = SessionRouter(fallback_router=self._default_router)

# This keeps the map between the registered endpoints and the routers
self.routers: dict[str, Router | Batcher] = {}
I don't think Batcher should be considered a router. See above.
Suggested change:
self._routers: dict[str, Router] = {}
self._batchers: dict[str, Batcher] = {}
assert set(session_map.values()) == {0, 1}

# If LeastLoadedRouter as fallback, r1 and r2 should be assigned to same replicas
replicas = [make_replica(0, load=0), make_replica(1, load=5)]
Please test for behaviors. E.g., if there is a bug in the code so that self.active_requests never gets incremented, this test won't catch it. Also, if the implementation of LeastLoadedRouter changes, this test will unnecessarily break.
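A behavior-level version of the test might look like this sketch (route_session, make_replica, and active_requests are stand-ins for this PR's test helpers, not its actual names):

def test_session_affinity_and_load_accounting(service):
    # Assert the behaviors we care about rather than which concrete
    # replica a particular session happens to land on.
    r1 = service.route_session("s1")
    r2 = service.route_session("s2")

    # A session keeps hitting the replica it was first assigned to.
    assert service.route_session("s1") is r1
    assert service.route_session("s2") is r2

    # Routing is reflected in load accounting, so a bug where
    # active_requests never gets incremented fails this test.
    assert r1.active_requests > 0
    assert r2.active_requests > 0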
Migrated from #177
Context: #160

Add batch routing to Service to improve request throughput and maintain session-aware routing.

- Added new @service_endpoint decorator that supports routing configuration (router, batch_size, batch_timeout).
- Introduced ServiceEndpointProperty to distinguish between @endpoint and @service_endpoint.
- Centralized endpoint-to-router mapping in Service (self.routers) with support for both plain routers and batchers.
- Updated ServiceInterface to register endpoints through _set_router, ensuring consistent setup for both standard and service endpoints.
- Extended _call and _get_replica to handle batch routing, session routing, and fallback routing in a unified way.
- Enhanced Service.stop to gracefully shut down any active batchers in addition to replicas.
- Added integration tests to validate behavior when batch_size is reached.
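Based on the description above, usage would presumably look something like this sketch (the import path, decorator signature, and endpoint name are assumptions, not verified against the diff):

from mylib.service import service_endpoint  # hypothetical import path

class MyService:
    @service_endpoint(batch_size=16, batch_timeout=0.01)
    async def generate(self, prompt: str) -> str:
        # Requests are queued and routed as a batch once 16 requests
        # arrive or 10 ms elapse, whichever comes first.
        ...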