
Conversation

Member
@DNXie DNXie commented Oct 3, 2025

Migrated from #177
Context: #160

Add batch routing to Service to improve request throughput while preserving session-aware routing.

  • Added a new @service_endpoint decorator that supports routing configuration (router, batch_size, batch_timeout); see the usage sketch after this list.

  • Introduced ServiceEndpointProperty to distinguish between @endpoint and @service_endpoint.

  • Centralized endpoint-to-router mapping in Service (self.routers) with support for both plain routers and batchers.

  • Updated ServiceInterface to register endpoints through _set_router, ensuring consistent setup for both standard and service endpoints.

  • Extended _call and _get_replica to handle batch routing, session routing, and fallback routing in a unified way.

  • Enhanced Service.stop to gracefully shut down any active batchers in addition to replicas.

  • Added integration tests to validate:

    • Round-robin distribution with and without batching
    • Correct batch flushing when batch_size is reached
    • Independent coexistence of multiple endpoints with different batch sizes/routers
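
For illustration, a usage sketch of the decorator described above. The import path, and whether router takes a string key or a Router instance, are assumptions here, not confirmed by this PR:

```python
# Hypothetical actor using the new decorator; module path is illustrative.
from forge.service import endpoint, service_endpoint  # assumed import path


class Generator:
    @endpoint  # plain endpoint: routed per request, no batching
    async def ping(self) -> str:
        return "pong"

    @service_endpoint(router="round_robin", batch_size=16, batch_timeout=0.01)
    async def generate(self, prompt: str) -> str:
        # Calls are enqueued; the batcher flushes once 16 requests arrive
        # or 10 ms elapse, then makes one routing decision for the batch.
        ...
```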

Test

pytest tests/unit_tests/test_service.py
pytest tests/unit_tests/test_router.py
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

@meta-cla meta-cla bot added the CLA Signed label Oct 3, 2025
Instead of selecting a replica immediately, incoming requests are enqueued
and grouped into batches. Once a batch is ready (either reaching the maximum
size or exceeding the maximum wait time), the batcher makes a single routing
decision for the whole batch.
Contributor
Is there a way to base the wait time on the status of the replica instead of a fixed time? For example, if the replica is still busy, we can let the batch grow larger, but if the replica is free for some minimum time interval, then we can send the batch.

Member Author
@DNXie DNXie Oct 7, 2025

Definitely. We could make the batch timeout adaptive based on replica load (e.g., wait longer when replicas are busy and flush earlier when they’re idle). I’d prefer to land this current version first, then explore that as a follow-up improvement once the base batching logic is stable. Just added a TODO in the while loop.

# TODO: make timeout adaptive based on replica load.
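
For context on where that TODO sits, a minimal sketch of the flush-loop shape being discussed (names are illustrative, not the PR's exact code):

```python
import asyncio
import time


class Batcher:
    def __init__(self, batch_size: int = 16, batch_timeout: float = 0.01):
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout
        self._queue: asyncio.Queue = asyncio.Queue()
        self._running = True

    async def _flush(self, batch) -> None:
        ...  # one routing decision for the whole batch (elided)

    async def _batch_loop(self) -> None:
        """Collect queued requests; flush on batch_size or batch_timeout."""
        while self._running:
            batch = [await self._queue.get()]  # block until the first request
            deadline = time.monotonic() + self.batch_timeout
            # TODO: make timeout adaptive based on replica load.
            while len(batch) < self.batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    item = await asyncio.wait_for(
                        self._queue.get(), timeout=remaining
                    )
                except asyncio.TimeoutError:
                    break
                batch.append(item)
            await self._flush(batch)
```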

@DNXie DNXie changed the title from "[WIP] Add Batch routing support via @service_endpoint with configurable batch size and timeout" to "Add Batch routing support via @service_endpoint with configurable batch size and timeout" Oct 8, 2025
session_id=None,
function=self.function,
args=args,
kwargs={},
Member Author
@allenwang28 Do we want to support kwargs here?

results = [results] * len(batch)
else:
# scalar result, broadcast to batch size
results = [results] * len(batch)
Member Author
@allenwang28 Do we want to handle the case where the returned results have a different length than the batch, or are a scalar?

Contributor
I think both the batching logic and the unbatching logic should be factored out.
The batcher should accept two functions:

batch_fn :: [Request] => BatchRequest | Request
unbatch_fn :: BatchResponse | Response => [Response]
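
In Python terms, that factoring might look like this (type names and the replica.call entry point are illustrative, not this PR's API):

```python
from dataclasses import dataclass
from typing import Callable, Generic, List, TypeVar

Req = TypeVar("Req")
Resp = TypeVar("Resp")
BatchReq = TypeVar("BatchReq")
BatchResp = TypeVar("BatchResp")


@dataclass
class Batcher(Generic[Req, BatchReq, BatchResp, Resp]):
    batch_fn: Callable[[List[Req]], BatchReq]      # [Request] => BatchRequest
    unbatch_fn: Callable[[BatchResp], List[Resp]]  # BatchResponse => [Response]

    async def dispatch(self, requests: List[Req], replica) -> List[Resp]:
        batched = self.batch_fn(requests)
        response = await replica.call(batched)  # hypothetical single routed call
        results = self.unbatch_fn(response)
        # The arity contract lives in one place instead of ad hoc broadcasting.
        assert len(results) == len(requests), "unbatch_fn must preserve arity"
        return results
```

This would also answer the scalar-vs-list question above: unbatch_fn owns that policy.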

@DNXie DNXie marked this pull request as ready for review October 8, 2025 19:33
@DNXie DNXie requested a review from allenwang28 October 8, 2025 19:33
self.service = service
self.endpoint_name = endpoint_name

async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
Contributor
Do we want to schedule the task to start immediately by making this return a future?
In particular, we could fire and forget without awaiting it - this is actually what happens with the Monarch native actor API, iiuc.
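
A sketch of what that could look like (hedged; assumes route wraps an internal _route coroutine, which is not this PR's actual structure):

```python
import asyncio


class ServiceEndpoint:
    async def _route(self, *args, **kwargs):
        ...  # pick a replica and dispatch (elided)

    def route(self, *args, **kwargs) -> asyncio.Task:
        # Schedule the call immediately and return a handle; callers can
        # await it later or fire-and-forget without awaiting at all.
        return asyncio.create_task(self._route(*args, **kwargs))


# Caller side (inside a running event loop):
#   task = service.my_endpoint.route(x)  # starts running right away
#   result = await task                  # optional
```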

"Services only support route() and fanout()."
)

async def generate(self, *args: P.args, **kwargs: P.kwargs):
Contributor
?

"Services only support route() and fanout()."
)

async def generate(self, *args: P.args, **kwargs: P.kwargs):
Contributor
actor also doesn't have generate?

Comment on lines +145 to +153
def __init__(
self,
function: str,
inner_router: Router,
get_healthy_replicas: Callable[[], List["Replica"]],
get_session_map: Callable[[], Dict[str, int]],
batch_size: int = 16,
batch_timeout: float = 0.01,
):
Contributor
Suggested change:

-def __init__(
-    self,
-    function: str,
-    inner_router: Router,
-    get_healthy_replicas: Callable[[], List["Replica"]],
-    get_session_map: Callable[[], Dict[str, int]],
-    batch_size: int = 16,
-    batch_timeout: float = 0.01,
-):
+def __init__(
+    self,
+    function: str,
+    inner_router: Router,
+    *,
+    get_healthy_replicas: Callable[[], List["Replica"]],
+    get_session_map: Callable[[], Dict[str, int]],
+    batch_size: int = 16,
+    batch_timeout: float = 0.01,
+):
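
(The bare * makes every parameter after it keyword-only, so call sites must spell out get_healthy_replicas=... and get_session_map=..., which keeps the two callables from being swapped positionally. The same applies to the get_replica suggestion below.)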

Comment on lines +24 to +28
def get_replica(
self,
healthy_replicas: List[Replica],
sess_id: str | None = None,
session_map: Dict[str, int] | None = None,
Contributor
Suggested change:

-def get_replica(
-    self,
-    healthy_replicas: List[Replica],
-    sess_id: str | None = None,
-    session_map: Dict[str, int] | None = None,
+def get_replica(
+    self,
+    healthy_replicas: List[Replica],
+    *,
+    sess_id: str | None = None,
+    session_map: Dict[str, int] | None = None,

"""Add (args, kwargs) pair to queue, return a Future resolved when batch completes."""
# Queue the request for batching
fut = asyncio.Future()
self._queue.put_nowait((function, args, kwargs, fut))
Contributor
Why put_nowait? If there is no capacity limit on the queue, there isn't much of a difference; if there is, this will raise asyncio.QueueFull.
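
For reference, the difference only shows up on a bounded queue (standard asyncio semantics):

```python
import asyncio

q: asyncio.Queue = asyncio.Queue(maxsize=2)
q.put_nowait(1)
q.put_nowait(2)
try:
    q.put_nowait(3)  # raises immediately once the queue is full
except asyncio.QueueFull:
    pass
# `await q.put(3)` would instead suspend until space frees up,
# applying backpressure to the caller.
```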

Args:
inner_router: The underlying Router used to pick a replica.
get_healthy_replicas: Callable that returns the current list of healthy replicas.
get_session_map: Callable that returns the session-to-replica mapping.
Contributor
@allenwang28 this is a question for you: why doesn't the SessionRouter own the session_map in the first place, instead of us having to pass it down layer by layer?

Comment on lines +227 to +230
# One routing decision for the whole batch
replica = self.inner_router.get_replica(
healthy_replicas, None, session_map
)
Contributor
Since you only have one (batched) request in the first place, why do you need to keep track of the replica?

self._session_router = SessionRouter(fallback_router=self._default_router)

# This keeps the map between the registered endpoints and the routers
self.routers: dict[str, Router | Batcher] = {}
Contributor
I don't think Batcher should be considered a router. See above.

Contributor
Suggested change:

-self.routers: dict[str, Router | Batcher] = {}
+self._routers: dict[str, Router] = {}
+self._batchers: dict[str, Batcher] = {}

assert set(session_map.values()) == {0, 1}

# With LeastLoadedRouter as the fallback, r1 and r2 should be assigned to the same replica
replicas = [make_replica(0, load=0), make_replica(1, load=5)]
Contributor
Please test for behaviors.
E.g., if there is a bug such that self.active_requests never gets incremented, this test won't catch it.
Also, if the implementation of LeastLoadedRouter changes, this test will break unnecessarily.
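
For example, a behavior-level test could pin down the load-accounting contract directly (a sketch only: make_replica is the existing test helper, but replica.process and pytest-asyncio are assumptions here):

```python
import asyncio

import pytest


@pytest.mark.asyncio
async def test_active_requests_tracks_in_flight_calls():
    replica = make_replica(0, load=0)
    started = asyncio.Event()

    async def slow_call():
        started.set()
        await asyncio.sleep(0.05)

    # Hypothetical entry point that runs one request on the replica.
    task = asyncio.create_task(replica.process(slow_call))
    await started.wait()
    assert replica.active_requests == 1  # incremented while in flight
    await task
    assert replica.active_requests == 0  # decremented on completion
```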

