
Conversation

DNXie
Member

@DNXie DNXie commented Sep 18, 2025

Now in #304


Add batch routing to Service to improve request throughput and maintain session-aware routing.

  • Added new @service_endpoint decorator that supports routing configuration (router, batch_size, batch_timeout).

  • Introduced ServiceEndpointProperty to distinguish between @endpoint and @service_endpoint.

  • Centralized endpoint-to-router mapping in Service (self.routers) with support for both plain routers and batchers.

  • Updated ServiceInterface to register endpoints through _set_router, ensuring consistent setup for both standard and service endpoints.

  • Extended _call and _get_replica to handle batch routing, session routing, and fallback routing in a unified way.

  • Enhanced Service.stop to gracefully shut down any active batchers in addition to replicas.

  • Added integration tests to validate:

    • Round-robin distribution with and without batching
    • Correct batch flushing when batch_size is reached
    • Independent coexistence of multiple endpoints with different batch sizes/routers
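
The decorator-based configuration above can be sketched in miniature (names follow the PR; the implementation is purely illustrative, not forge's actual code):

```python
def service_endpoint(*, router=None, batch_size: int = 1, batch_timeout: float = 0.01):
    """Toy sketch: tag a method with per-endpoint routing/batching config.
    The service would read these attributes when registering endpoints."""
    def mark(fn):
        fn._router = router            # router (or router constructor)
        fn._batch_size = batch_size    # flat attributes rather than a config dict
        fn._batch_timeout = batch_timeout
        return fn
    return mark

class MyActor:
    @service_endpoint(batch_size=16, batch_timeout=0.05)
    def predict(self, x):
        return x * 2

print(MyActor.predict._batch_size)   # 16
print(MyActor().predict(3))          # 6
```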

Test

pytest tests/unit_tests/test_service.py
pytest tests/unit_tests/test_router.py
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

@DNXie DNXie requested a review from allenwang28 September 18, 2025 19:42
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 18, 2025
@DNXie DNXie requested a review from allenwang28 September 22, 2025 22:58
@DNXie DNXie force-pushed the batch_router branch 2 times, most recently from fdc8d7c to 8ee75a0 Compare September 24, 2025 18:46
@DNXie DNXie changed the title Add BatchRouter for async batch dispatch Support batch routing into Service with configurable batch size and timeout Sep 24, 2025
self._session_router = SessionRouter(fallback_router=LeastLoadedRouter())

# Batching
self._max_batch_size = self._cfg.max_batch_size
Contributor

So a complication here is that max_batch_size and batch_max_wait_s depend on the endpoint itself! They are not a global service-level concept.

So we would need to introduce some way to mark the endpoints we want batched, because not all of them should be. Here's what I had before:

class ReferenceActor(Actor):
    @service_endpoint(router=BatchedRouter(timeout_in_s=5, batch_size=4)) # batch either when you hit 5s, or batch size is 4
    def forward(self, token_batch: list[int]) -> list[torch.Tensor]:
         ...

Member Author

@DNXie DNXie Sep 24, 2025

I see. One way I can think of is to move the routing logic entirely to ServiceEndpoint?

import asyncio
from typing import Generic, List, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class ServiceEndpoint(Generic[P, R]):
    """Service endpoint that owns its own routing + batching logic."""

    def __init__(
        self,
        service,
        endpoint_name: str,
        router=None,
        max_batch_size: int = 1,
        batch_max_wait_s: float = 0.01,
    ):
        self.service = service
        self.endpoint_name = endpoint_name
        self.router = router or RoundRobinRouter()  # default
        self.max_batch_size = max_batch_size
        self.batch_max_wait_s = batch_max_wait_s

        # Only set up batching infra if batching is enabled
        if self.max_batch_size > 1:
            self._batch_queue = asyncio.Queue()
            self._batch_task = asyncio.create_task(self._batch_loop())
        else:
            self._batch_queue = None
            self._batch_task = None

    async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
        """Route request to one replica using this endpoint’s router."""
        sess_id = kwargs.pop("sess_id", None)

        if sess_id is not None:
            replica = await self.get_replica(sess_id)
            return await self.service._call(sess_id, replica, self.endpoint_name, *args, **kwargs)

        if self.max_batch_size > 1:
            fut = asyncio.get_event_loop().create_future()
            # enqueue the future and request payload
            await self._batch_queue.put((fut, args, kwargs))
            return await fut

        # STATELESS and no batching -> route & call immediately
        replica = await self.get_replica(None)
        return await self.service._call(None, replica, self.endpoint_name, *args, **kwargs)


    async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
        # stay the same because it iterates all replicas
        ...

    async def get_replica(self, sess_id: str | None) -> "Replica":
        """Encapsulate replica selection logic (session vs stateless)."""
        healthy = self.service._get_healthy_replicas()

        return await self.router.get_replica(healthy, sess_id, self.service._session_replica_map)

    async def _batch_loop(self):
        ...
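
The _batch_loop elided here can be sketched as a standalone toy (illustrative only, not forge's implementation): requests are enqueued as (future, payload) pairs, the loop flushes when either max_batch_size items have accumulated or batch_max_wait_s elapses, and each caller's future is resolved with its own result:

```python
import asyncio

async def batch_loop(queue, handle_batch, max_batch_size=4, batch_max_wait_s=0.05):
    """Toy flush loop: collect up to max_batch_size (future, item) pairs,
    or whatever arrived within batch_max_wait_s, then dispatch once."""
    while True:
        fut, item = await queue.get()          # block for the first request
        batch = [(fut, item)]
        deadline = asyncio.get_running_loop().time() + batch_max_wait_s
        while len(batch) < max_batch_size:
            remaining = deadline - asyncio.get_running_loop().time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(queue.get(), remaining))
            except asyncio.TimeoutError:
                break
        results = await handle_batch([i for _, i in batch])  # one call per batch
        for (f, _), r in zip(batch, results):
            f.set_result(r)                    # fan results back out per caller

async def demo():
    queue = asyncio.Queue()

    async def handle_batch(items):             # stands in for one replica call
        return [i * 10 for i in items]

    task = asyncio.create_task(batch_loop(queue, handle_batch, max_batch_size=2))

    async def route(item):
        fut = asyncio.get_running_loop().create_future()
        await queue.put((fut, item))
        return await fut

    out = await asyncio.gather(route(1), route(2), route(3))
    task.cancel()
    return out

print(asyncio.run(demo()))  # [10, 20, 30]
```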

Contributor

hmm yeah I think that makes sense!

Member Author

Done!

@DNXie DNXie requested a review from allenwang28 September 26, 2025 04:04
@DNXie
Member Author

DNXie commented Sep 26, 2025

@allenwang28 I found out that this try-except clause actually doesn't catch the error.

Since I moved this logic to ServiceEndpoint with self.max_attempts=2, our existing test cases don't cover it. I was trying to check whether the retry really works with this test case:

class Counter(ForgeActor):
    @endpoint()
    async def flaky(self):
        if not hasattr(self, "_flaky_calls"):
            self._flaky_calls = 0
        self._flaky_calls += 1
        if self._flaky_calls == 1:
            raise RuntimeError("fail first attempt")
        return "ok"


@pytest.mark.timeout(10)
@pytest.mark.asyncio
async def test():
    service = await Counter.options(procs=1, num_replicas=2).as_service(v=0)

    try:
        service.max_attempts = 2

        result = await service.flaky.route()
        assert result == "ok"  # success after retry
    finally:
        await service.shutdown()

But somehow it doesn't go through. It seems like this error is handled at the Monarch level?

INFO     forge.controller.actor:actor.py:123 Spawning Service Actor for Counter
INFO     forge.controller.actor:actor.py:207 Spawning single actor Counter
INFO     forge.controller.actor:actor.py:207 Spawning single actor Counter
INFO     forge.controller.service.endpoint:endpoint.py:109 Get replica 1 on attempt 0
[0] CRITICAL:root:Unhandled exception in actor endpoint
[0] Traceback (most recent call last):
[0]   File "/home/dxie/.fbpkg_conda_envs/forge-a7401c7/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 831, in instrumented
[0]     result = await the_method(*args, **kwargs)
[0]   File "/home/dxie/forge/tests/unit_tests/test_router.py", line 64, in flaky
[0]     raise RuntimeError("fail first attempt")
[0] RuntimeError: fail first attempt
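
For reference, here is a minimal retry loop of the kind max_attempts=2 implies. This is a toy sketch assuming the endpoint's exception actually propagates to the caller, which the log above suggests Monarch may intercept instead:

```python
import asyncio

async def call_with_retries(fn, *args, max_attempts=2, **kwargs):
    """Toy retry wrapper: re-invoke fn until it succeeds or attempts run out."""
    last_exc = None
    for _attempt in range(max_attempts):
        try:
            return await fn(*args, **kwargs)
        except Exception as exc:   # only works if the failure surfaces here
            last_exc = exc
    raise last_exc

calls = 0

async def flaky():
    """Fails on the first call, succeeds afterwards."""
    global calls
    calls += 1
    if calls == 1:
        raise RuntimeError("fail first attempt")
    return "ok"

print(asyncio.run(call_with_retries(flaky, max_attempts=2)))  # ok
```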

Contributor

@allenwang28 allenwang28 left a comment

hmm, @DNXie we should discuss this further. I understand why the router logic is placed within the endpoint, but there is a challenge. At some point in the near future, I will want the services themselves to be actors (this is what you see now as ServiceV2; this is an implementation/design detail).

However, if the endpoints themselves handle the load balancing etc., that logic shouldn't be owned by the ServiceInterface; it should be owned by the Service. The interface should be a really lightweight handle that can get passed around, while the true load balancing and batching logic is handled only by the service.

The idea of creating the service_endpoint decorator as we do here makes sense, and having a router/batcher per endpoint makes sense, but I want all of that to materialize within the service.

Practically, that probably means taking what you're calling ServiceEndpoint here and placing it within Service, while keeping the actual ServiceEndpoints lightweight and properly connecting the dots to the Service.

Does that make sense? Either way I scheduled a time for us to chat today

@DNXie DNXie changed the title Support batch routing into Service with configurable batch size and timeout Add Batch routing support via @service_endpoint with configurable batch size and timeout Sep 29, 2025
batch_timeout: float = 0.01,
) -> None:
super().__init__(method, propagator, explicit_response_port)
self._service_endpoint_config = dict(
Contributor

why does this need to be a dict? Can we not just do self.router = router etc.?

Member Author

Because if it is a batcher, we need to pass in two functions, get_healthy_replicas and get_session_map:

class Batcher:
    def __init__(
        self,
        inner_router: Router,
        get_healthy_replicas: Callable[[], List["Replica"]],
        get_session_map: Callable[[], Dict[str, int]],
        batch_size: int = 16,
        batch_timeout: float = 0.01,
    ):

These two functions come from a Service object, so the Batcher needs to be initialized after the Service is initialized.
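
A minimal sketch of that late binding (hypothetical names beyond those in the snippet above): the Batcher holds callables rather than snapshots, so it always sees the service's current state:

```python
class Batcher:
    """Toy sketch: pull live service state through callables instead of
    holding references captured at decoration time."""
    def __init__(self, get_healthy_replicas, get_session_map,
                 batch_size=16, batch_timeout=0.01):
        self._get_healthy_replicas = get_healthy_replicas
        self._get_session_map = get_session_map
        self.batch_size = batch_size
        self.batch_timeout = batch_timeout

    def snapshot(self):
        # Evaluated at call time, so it always reflects the current service.
        return self._get_healthy_replicas(), self._get_session_map()

class Service:
    def __init__(self):
        self._replicas = ["r0", "r1"]
        self._session_map = {}
        # The batcher can only be built once the service exists.
        self.batcher = Batcher(lambda: self._replicas, lambda: self._session_map)

svc = Service()
svc._replicas.append("r2")          # state changes after wiring
print(svc.batcher.snapshot()[0])    # ['r0', 'r1', 'r2']
```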

Contributor

my question is more about why it is a dictionary

Member Author

Oh I see. Yeah we don't need a dict here. We can keep things flat. Done.

self._default_router = RoundRobinRouter()
self._session_router = SessionRouter(fallback_router=LeastLoadedRouter())
self._session_router = SessionRouter(fallback_router=self._default_router)
self.routers: dict[str, Router | Batcher] = {}
Contributor

let's note in a comment that this keeps the mapping between the registered endpoints and their routing functions

@DNXie DNXie requested a review from allenwang28 September 30, 2025 00:36

def service_endpoint(
*,
router: Router = RoundRobinRouter(),
Contributor

can this instead be a router constructor? Right now we pass in a full Router object, which is probably fine to start with, but typically for efficiency reasons we pass in a constructor function and create the instance later.

So it'd look like

router: Callable[[], Router]

and

        class MyForgeActor(ForgeActor):
            @service_endpoint(router=RoundRobinRouter, batch_size=16, batch_timeout=0.05)
            async def predict(self, x): ...
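
To illustrate the point (a toy sketch, not forge code): a Router instance used as a decorator default is created once at class-definition time and shared, while passing the constructor lets the service build a fresh, independently stateful router per endpoint:

```python
class RoundRobinRouter:
    """Toy stateful router: cycles through replicas in order."""
    def __init__(self):
        self.next_idx = 0

    def pick(self, replicas):
        r = replicas[self.next_idx % len(replicas)]
        self.next_idx += 1
        return r

# Passing an instance: the default is evaluated once, so every endpoint
# declared with it would share the same router state.
shared = RoundRobinRouter()
routers_shared = [shared, shared]

# Passing the constructor: the service instantiates one router per endpoint.
make_router = RoundRobinRouter
routers_fresh = [make_router(), make_router()]

print(routers_shared[0] is routers_shared[1])  # True
print(routers_fresh[0] is routers_fresh[1])    # False
```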

Member Author

Got it. Done!


# One routing decision for the whole batch
replica = self.inner_router.get_replica(healthy_replicas, None, session_map)

Contributor

hmm, technically what should happen here is that a batch of requests is sent to the replica. I think what's happening right now is that you're calling the replica's endpoint batch_size times, which is exactly the opposite of what we want!

Member Author

Yeah you are right. Just updated to make the batcher send the batched requests to the replica. Please check!

@DNXie DNXie requested a review from allenwang28 September 30, 2025 21:31
# Accept requests in all other states - let the processing loop handle the rest
await self.request_queue.put(request)

async def enqueue_batch(self, requests: list[ServiceRequest]):
Contributor

hmm, ok this is complicated. The right behavior is actually that, as a user, I want to be able to say:

class MyActor(ForgeActor):
    @service_endpoint(batch_size=4, duration_time_s=1) # (you can ignore the wording I'm using here)
    def forward(self, batch: list[torch.Tensor]) -> torch.Tensor:
        ...

but be able to use it as

service = MyActor.as_service()

result: torch.Tensor = await service.forward.route(torch.zeros(10))

i.e., you define it as a batch function but get individual results back
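
A toy sketch of that contract (hypothetical names, not forge code): the endpoint body is defined over a list, callers submit single items, and each caller's future is resolved with the element at its position in the batch result:

```python
import asyncio

def scatter(batch_fn):
    """Toy demux: callers submit single items; batch_fn sees a list and
    returns a list; each caller gets back the element at its position."""
    pending = []

    async def route(item):
        fut = asyncio.get_running_loop().create_future()
        pending.append((fut, item))
        if len(pending) == 2:              # toy trigger: flush at batch size 2
            group, pending[:] = pending[:], []
            results = batch_fn([i for _, i in group])
            for (f, _), r in zip(group, results):
                f.set_result(r)            # individual result per caller
        return await fut
    return route

def forward(batch):                        # defined over a whole batch
    return [x + 100 for x in batch]

async def demo():
    route = scatter(forward)
    return await asyncio.gather(route(1), route(2))

print(asyncio.run(demo()))  # [101, 102]
```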

Member Author

This is working properly. See the test case in test_router.py::test_batch_endpoint_returns_individual_results.
