Add Batch routing support via `@service_endpoint` with configurable batch size and timeout #177
```python
self._session_router = SessionRouter(fallback_router=LeastLoadedRouter())

# Batching
self._max_batch_size = self._cfg.max_batch_size
```
so a complication here is that the `max_batch_size` and `batch_max_wait_s` are dependent on the endpoint itself! Not a global service-level concept. So we would need to introduce some way to mark the endpoints we want batched, because not all of them should be. Here's what I had before:

```python
class ReferenceActor(Actor):
    # batch either when you hit 5s, or when batch size reaches 4
    @service_endpoint(router=BatchedRouter(timeout_in_s=5, batch_size=4))
    def forward(self, token_batch: list[int]) -> list[torch.Tensor]:
        ...
```
I see. One way I can think of is to move the routing logic entirely to ServiceEndpoint?
```python
import asyncio
from typing import Generic, List, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


class ServiceEndpoint(Generic[P, R]):
    """Service endpoint that owns its own routing + batching logic."""

    def __init__(
        self,
        service,
        endpoint_name: str,
        router=None,
        max_batch_size: int = 1,
        batch_max_wait_s: float = 0.01,
    ):
        self.service = service
        self.endpoint_name = endpoint_name
        self.router = router or RoundRobinRouter()  # default
        self.max_batch_size = max_batch_size
        self.batch_max_wait_s = batch_max_wait_s
        # Only set up batching infra if batching is enabled
        if self.max_batch_size > 1:
            self._batch_queue = asyncio.Queue()
            self._batch_task = asyncio.create_task(self._batch_loop())
        else:
            self._batch_queue = None
            self._batch_task = None

    async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
        """Route a request to one replica using this endpoint's router."""
        sess_id = kwargs.pop("sess_id", None)
        if sess_id is not None:
            replica = await self.get_replica(sess_id)
            return await self.service._call(
                sess_id, replica, self.endpoint_name, *args, **kwargs
            )
        if self.max_batch_size > 1:
            fut = asyncio.get_event_loop().create_future()
            # Enqueue the future and request payload
            await self._batch_queue.put((fut, args, kwargs))
            return await fut
        # STATELESS and no batching -> route & call immediately
        replica = await self.get_replica(None)
        return await self.service._call(None, replica, self.endpoint_name, *args, **kwargs)

    async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
        # Stays the same because it iterates all replicas
        ...

    async def get_replica(self, sess_id: str | None) -> "Replica":
        """Encapsulate replica selection logic (session vs stateless)."""
        healthy = self.service._get_healthy_replicas()
        return await self.router.get_replica(healthy, sess_id, self.service._session_replica_map)

    async def _batch_loop(self):
        ...
```
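For readers following along, here is a minimal sketch of what the elided `_batch_loop` could look like under the assumptions above; the batched `_call` signature (forwarding a list payload in one call) is a guess, not the PR's actual API:

```python
# Hypothetical sketch of ServiceEndpoint._batch_loop (a method of the class above):
async def _batch_loop(self):
    while True:
        # Block for the first request, then keep filling the batch until
        # either max_batch_size is hit or batch_max_wait_s elapses.
        batch = [await self._batch_queue.get()]
        loop = asyncio.get_event_loop()
        deadline = loop.time() + self.batch_max_wait_s
        while len(batch) < self.max_batch_size:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                batch.append(await asyncio.wait_for(self._batch_queue.get(), remaining))
            except asyncio.TimeoutError:
                break
        # One routing decision for the whole batch.
        replica = await self.get_replica(None)
        payload = [args for _, args, _ in batch]
        try:
            # Assumed: the service can forward a list payload in a single call.
            results = await self.service._call(None, replica, self.endpoint_name, payload)
            for (fut, _, _), result in zip(batch, results):
                fut.set_result(result)
        except Exception as exc:
            # Fail every waiter in the batch rather than hanging their futures.
            for fut, _, _ in batch:
                if not fut.done():
                    fut.set_exception(exc)
```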
hmm yeah I think that makes sense!
Done!
@allenwang28 I found out that this try-except clause actually doesn't catch the error since I moved this logic into the endpoint. Repro:

```python
class Counter(ForgeActor):
    @endpoint()
    async def flaky(self):
        if not hasattr(self, "_flaky_calls"):
            self._flaky_calls = 0
        self._flaky_calls += 1
        if self._flaky_calls == 1:
            raise RuntimeError("fail first attempt")
        return "ok"


@pytest.mark.timeout(10)
@pytest.mark.asyncio
async def test():
    service = await Counter.options(procs=1, num_replicas=2).as_service(v=0)
    try:
        service.max_attempts = 2
        result = await service.flaky.route()
        assert result == "ok"  # success after retry
    finally:
        await service.shutdown()
```

But somehow it doesn't go through. It seems like this error is handled at the Monarch level?

```
INFO forge.controller.actor:actor.py:123 Spawning Service Actor for Counter
INFO forge.controller.actor:actor.py:207 Spawning single actor Counter
INFO forge.controller.actor:actor.py:207 Spawning single actor Counter
INFO forge.controller.service.endpoint:endpoint.py:109 Get replica 1 on attempt 0
[0] CRITICAL:root:Unhandled exception in actor endpoint
[0] Traceback (most recent call last):
[0]   File "/home/dxie/.fbpkg_conda_envs/forge-a7401c7/lib/python3.10/site-packages/monarch/_src/actor/actor_mesh.py", line 831, in instrumented
[0]     result = await the_method(*args, **kwargs)
[0]   File "/home/dxie/forge/tests/unit_tests/test_router.py", line 64, in flaky
[0]     raise RuntimeError("fail first attempt")
[0] RuntimeError: fail first attempt
```
hmm, @DNXie we should discuss this further. I understand why the router logic is placed within the endpoint, but there is a challenge. At some point in the near future, I will want the services themselves to be actors (this is what you see now as ServiceV2 - an implementation/design detail).

However, if the endpoints themselves handle the load balancing etc., it shouldn't be owned by the ServiceInterface - it should be owned by the Service. The interface should be a really lightweight handle that can get passed around, while the true load balancing and batching logic is handled only by the service.

The idea of creating the service_endpoint decorator as we do here makes sense, and having a router/batcher per endpoint makes sense, but I want all of that to materialize within the service. Practically, that probably means we take what you're calling ServiceEndpoint here and place it within Service, while keeping the actual ServiceEndpoints lightweight and properly connecting the dots to the Service.

Does that make sense? Either way, I scheduled a time for us to chat today.
```python
    batch_timeout: float = 0.01,
) -> None:
    super().__init__(method, propagator, explicit_response_port)
    self._service_endpoint_config = dict(
```
why does this need to be a dict? Can we not just do `self.router = router` etc.?
Because if it is a batcher, we need to pass two functions, `get_healthy_replicas` and `get_session_map`:

```python
class Batcher:
    def __init__(
        self,
        inner_router: Router,
        get_healthy_replicas: Callable[[], List["Replica"]],
        get_session_map: Callable[[], Dict[str, int]],
        batch_size: int = 16,
        batch_timeout: float = 0.01,
    ):
```

These two functions come from a `Service` object, so `Batcher` needs to be initialized after `Service` is initialized.
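For concreteness, the deferred wiring being described might look roughly like this; every name below (`self._endpoint_configs`, `cfg`, `self._session_replica_map`) is an assumption about the PR's internals, not confirmed code:

```python
# Hypothetical sketch: inside Service initialization, once replicas and
# session state exist, build a Batcher (or plain Router) per endpoint.
for name, cfg in self._endpoint_configs.items():
    if cfg.batch_size > 1:
        self.routers[name] = Batcher(
            inner_router=cfg.router,
            get_healthy_replicas=self._get_healthy_replicas,
            get_session_map=lambda: self._session_replica_map,
            batch_size=cfg.batch_size,
            batch_timeout=cfg.batch_timeout,
        )
    else:
        self.routers[name] = cfg.router
```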
my question is more about why it is a dictionary
Oh I see. Yeah we don't need a dict here. We can keep things flat. Done.
```diff
  self._default_router = RoundRobinRouter()
- self._session_router = SessionRouter(fallback_router=LeastLoadedRouter())
+ self._session_router = SessionRouter(fallback_router=self._default_router)
+ self.routers: dict[str, Router | Batcher] = {}
```
let's mark with a comment that this keeps the map between the registered endpoints and the routing functions
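For example, something like this (comment wording hypothetical):

```python
# Maps each registered endpoint name to its routing strategy:
# a plain Router, or a Batcher when batching is enabled.
self.routers: dict[str, Router | Batcher] = {}
```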
```python
def service_endpoint(
    *,
    router: Router = RoundRobinRouter(),
```
can this instead be a router constructor? Right now we pass in a full `Router` object - probably fine to start with, but typically for efficiency reasons we pass in a constructor function, which we create later. So it'd look like `router: Callable[[], Router]` and:

```python
class MyForgeActor(ForgeActor):
    @service_endpoint(router=RoundRobinRouter, batch_size=16, batch_timeout=0.05)
    async def predict(self, x): ...
```
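A small sketch of the factory pattern being suggested, with hypothetical storage details (the tuple attribute is illustrative, not the PR's implementation):

```python
from typing import Callable

def service_endpoint(*, router: Callable[[], Router] = RoundRobinRouter,
                     batch_size: int = 1, batch_timeout: float = 0.01):
    def decorator(fn):
        # Store the constructor, not an instance; the Service builds the
        # Router later, so each endpoint gets its own fresh instance.
        fn._service_endpoint_config = (router, batch_size, batch_timeout)
        return fn
    return decorator
```

This also avoids evaluating `RoundRobinRouter()` once at decoration time and sharing that single mutable instance across every decorated endpoint.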
Got it. Done!
```python
# One routing decision for the whole batch
replica = self.inner_router.get_replica(healthy_replicas, None, session_map)
```
hmm, technically what should happen here is that a batch of requests is sent to the replica. I think what's happening right now is that you're calling the replica's endpoint `batch_size` times, which is exactly the opposite of what we want!
Yeah you are right. Just updated to make the batcher send the batched requests to the replica. Please check!
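A sketch of what "send the batched requests to the replica" might look like inside the batcher; `replica.call_batched` is an assumed helper, not a real API from this repo:

```python
async def _dispatch(self, batch):
    healthy = self.get_healthy_replicas()
    session_map = self.get_session_map()
    # One routing decision for the whole batch
    replica = self.inner_router.get_replica(healthy, None, session_map)
    futures, payloads = zip(*batch)
    # Single call carrying the whole batch, instead of batch_size separate calls
    results = await replica.call_batched(self.endpoint_name, list(payloads))
    for fut, result in zip(futures, results):
        fut.set_result(result)
```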
```python
        # Accept requests in all other states - let the processing loop handle the rest
        await self.request_queue.put(request)

    async def enqueue_batch(self, requests: list[ServiceRequest]):
```
hmm, ok this is complicated. The right behavior is actually that, as a user, I want to be able to say:

```python
class MyActor(ForgeActor):
    @service_endpoint(batch_size=4, duration_time_s=1)  # (you can ignore the wording I'm using here)
    def forward(self, batch: list[torch.Tensor]) -> torch.Tensor:
        ...
```

but be able to use it as:

```python
service = MyActor.as_service()
result: torch.Tensor = await service.forward.route(torch.zeros(10))
```

i.e., you define it as a batch function but get individual results back.
This is working properly. See the test case in `test_router.py::test_batch_endpoint_returns_individual_results`. Now in #304.
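For reference, a hedged sketch of what that test plausibly asserts (actor and endpoint names assumed; the real test lives in `test_router.py`):

```python
@pytest.mark.asyncio
async def test_batch_endpoint_returns_individual_results():
    service = await MyActor.options(procs=1, num_replicas=1).as_service()
    try:
        # Four concurrent route() calls: the actor sees one batch,
        # but each caller gets back only its own result.
        results = await asyncio.gather(
            *(service.forward.route(torch.zeros(10)) for _ in range(4))
        )
        assert len(results) == 4
    finally:
        await service.shutdown()
```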
Add batch routing to `Service` to improve request throughput and maintain session-aware routing.

- Added a new `@service_endpoint` decorator that supports routing configuration (`router`, `batch_size`, `batch_timeout`).
- Introduced `ServiceEndpointProperty` to distinguish between `@endpoint` and `@service_endpoint`.
- Centralized the endpoint-to-router mapping in `Service` (`self.routers`), with support for both plain routers and batchers.
- Updated `ServiceInterface` to register endpoints through `_set_router`, ensuring consistent setup for both standard and service endpoints.
- Extended `_call` and `_get_replica` to handle batch routing, session routing, and fallback routing in a unified way.
- Enhanced `Service.stop` to gracefully shut down any active batchers in addition to replicas.
- Added integration tests to validate batching behavior, including that a pending batch is dispatched once `batch_size` is reached.