Skip to content

Commit 78cec9b

Browse files
sanggusti, pre-commit-ci[bot], and bhimrazy
authored
Feat/workers per api instance (#646)
* Feat: Add test on MultiRouteAPI with configurations of workers
* Feat: fix tests and add test resolve
  - `_resolve_workers_per_device_config` makes `resolve_workers_per_device` into dict[api_path, workers_per_device_int]
  - `_inference_workers_config_for_api` to instantiate workers per device
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix formatting on FakeCtx
* Fix E501 "line too long" check from pre-commit
* Fix monitor function to correctly calculate `lit_api_id` and `worker_id` by iterating through APIs, since the number of workers can now vary per API
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Pre-commit issue
* Fix: add uvicorn mock to resolve SystemExit
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* doc: comment commit for test
* clean: remove commented-out code
* Docs: Add additional documentation in the script for the per-route and per-api-position scenarios, as requested

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <bhimrajyadav977@gmail.com>
1 parent 37eebd2 commit 78cec9b

File tree

2 files changed

+176
-7
lines changed

2 files changed

+176
-7
lines changed

src/litserve/server.py

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import warnings
3030
from abc import ABC, abstractmethod
3131
from collections import deque
32-
from collections.abc import Callable, Iterable, Sequence
32+
from collections.abc import Callable, Iterable, Mapping, Sequence
3333
from contextlib import asynccontextmanager
3434
from queue import Queue
3535
from typing import TYPE_CHECKING, Literal, Optional, Union
@@ -638,6 +638,29 @@ def predict(self, prompt):
638638
server = ls.LitServer(StreamingAPI(stream=True))
639639
```
640640
641+
Per-route using dict
642+
```python
643+
server = ls.LitServer(
644+
[sentiment_api, generate_api],
645+
accelerator="cuda",
646+
devices=[0, 1],
647+
workers_per_device={
648+
"/sentiment": 2, # 2 workers per GPU for sentiment
649+
"/generate": 3, # 3 workers per GPU for generation
650+
},
651+
)
652+
```
653+
654+
Per-api position
655+
```python
656+
server = ls.LitServer(
657+
[sentiment_api, generate_api],
658+
accelerator="cuda",
659+
devices=[0, 1],
660+
workers_per_device=[2, 3], # sentiment then generate (same order as API list)
661+
)
662+
```
663+
641664
Deployment:
642665
Self-hosted:
643666
```bash
@@ -797,6 +820,7 @@ def __init__(
797820
self.lit_api = lit_api
798821
self.enable_shutdown_api = enable_shutdown_api
799822
self.workers_per_device = workers_per_device
823+
self._workers_per_device_by_api_path = self._resolve_workers_per_device_config(workers_per_device)
800824
self.max_payload_size = max_payload_size
801825
self.model_metadata = model_metadata
802826
self._connector = _Connector(accelerator=accelerator, devices=devices)
@@ -822,12 +846,15 @@ def __init__(
822846
device_list = range(devices)
823847
self.devices = [self.device_identifiers(accelerator, device) for device in device_list]
824848

825-
self.inference_workers_config = self.devices * self.workers_per_device
826849
self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp")
827850
self.register_endpoints()
828851
# register middleware
829852
self._register_middleware()
830853

854+
def _inference_workers_config_for_api(self, api_path: str):
855+
wpd = self._workers_per_device_by_api_path[api_path]
856+
return self.devices * wpd
857+
831858
def launch_inference_worker(self, lit_api: LitAPI):
832859
specs = [lit_api.spec] if lit_api.spec else []
833860
for spec in specs:
@@ -839,7 +866,10 @@ def launch_inference_worker(self, lit_api: LitAPI):
839866

840867
process_list = []
841868
endpoint = lit_api.api_path.split("/")[-1]
842-
for worker_id, device in enumerate(self.inference_workers_config):
869+
870+
inference_workers_config = self._inference_workers_config_for_api(lit_api.api_path)
871+
872+
for worker_id, device in enumerate(inference_workers_config):
843873
if len(device) == 1:
844874
device = device[0]
845875

@@ -873,7 +903,8 @@ def launch_single_inference_worker(self, lit_api: LitAPI, worker_id: int):
873903
del server_copy.app, server_copy.transport_config, server_copy.litapi_connector
874904
spec.setup(server_copy)
875905

876-
device = self.inference_workers_config[worker_id]
906+
inference_workers_config = self._inference_workers_config_for_api(lit_api.api_path)
907+
device = inference_workers_config[worker_id]
877908
endpoint = lit_api.api_path.split("/")[-1]
878909
if len(device) == 1:
879910
device = device[0]
@@ -1183,6 +1214,42 @@ def _perform_graceful_shutdown(
11831214

11841215
manager.shutdown()
11851216

1217+
def _resolve_workers_per_device_config(self, workers_per_device):
1218+
"""Resolve workers_per_device into a dict[api_path, workers_per_device_int]."""
1219+
api_paths = [api.api_path for api in self.litapi_connector]
1220+
1221+
if isinstance(workers_per_device, int):
1222+
if workers_per_device < 1:
1223+
raise ValueError("workers_per_device must be >= 1")
1224+
return dict.fromkeys(api_paths, workers_per_device)
1225+
1226+
if isinstance(workers_per_device, (list, tuple)):
1227+
if len(workers_per_device) != len(api_paths):
1228+
raise ValueError(
1229+
f"workers_per_device list length must match number of APIs \n"
1230+
f"({len(api_paths)}), got {len(workers_per_device)}"
1231+
)
1232+
cfg = {}
1233+
for p, w in zip(api_paths, workers_per_device):
1234+
if not isinstance(w, int) or w < 1:
1235+
raise ValueError("workers_per_device values must be integers >= 1")
1236+
cfg[p] = w
1237+
return cfg
1238+
1239+
if isinstance(workers_per_device, Mapping):
1240+
unknown = sorted(set(workers_per_device.keys()) - set(api_paths))
1241+
if unknown:
1242+
raise ValueError(f"workers_per_device contains unknown api_path values: {unknown} (unknown api_path)")
1243+
cfg = {}
1244+
for p in api_paths:
1245+
w = workers_per_device.get(p, 1)
1246+
if not isinstance(w, int) or w < 1:
1247+
raise ValueError("workers_per_device values must be integers >= 1")
1248+
cfg[p] = w
1249+
return cfg
1250+
1251+
raise TypeError("workers_per_device must be an int, a list/tuple of ints, or a mapping of api_path -> int")
1252+
11861253
def run(
11871254
self,
11881255
host: str = "0.0.0.0",
@@ -1388,7 +1455,10 @@ def run(
13881455
sockets = [config.bind_socket()]
13891456

13901457
if num_api_servers is None:
1391-
num_api_servers = len(self.inference_workers_config)
1458+
total_workers = 0
1459+
for lit_api in self.litapi_connector:
1460+
total_workers += len(self._inference_workers_config_for_api(lit_api.api_path))
1461+
num_api_servers = total_workers
13921462

13931463
if num_api_servers < 1:
13941464
raise ValueError("num_api_servers must be greater than 0")
@@ -1528,8 +1598,25 @@ def monitor():
15281598
broken_workers[i] = proc
15291599

15301600
for idx, proc in broken_workers.items():
1531-
lit_api_id = idx // len(self.inference_workers_config)
1532-
worker_id = idx % len(self.inference_workers_config)
1601+
lit_api_id = 0
1602+
worker_id = 0
1603+
count = 0
1604+
found = False
1605+
1606+
for i, lit_api in enumerate(self.litapi_connector):
1607+
workers_conf = self._inference_workers_config_for_api(lit_api.api_path)
1608+
num_workers_for_api = len(workers_conf)
1609+
1610+
if idx < count + num_workers_for_api:
1611+
lit_api_id = i
1612+
worker_id = idx - count
1613+
found = True
1614+
break
1615+
count += num_workers_for_api
1616+
1617+
if not found:
1618+
logger.error(f"Could not map worker index {idx} to an API.")
1619+
continue
15331620

15341621
for uid, resp in self.response_buffer.items():
15351622
if resp.worker_id is None or resp.worker_id != worker_id:

tests/unit/test_lit_server.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,85 @@ async def test_worker_restart_and_server_shutdown_streaming():
837837
):
838838
resp = await ac.post("/predict", json={"input": 0})
839839
assert resp.status_code == 200
840+
841+
842+
class MultiRouteAPI(ls.test_examples.SimpleLitAPI):
    """Minimal mock API used to exercise multi-route server behavior."""

    def __init__(self, api_path="/predict"):
        super().__init__(api_path=api_path)
846+
847+
848+
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
@pytest.mark.parametrize(
    ("workers_cfg", "expected_total_by_path"),
    [
        # dict form: workers configured explicitly per route
        ({"/sentiment": 2, "/generate": 3}, {"/sentiment": 4, "/generate": 6}),
        # list form: workers configured per API, in connector order
        ([2, 3], {"/sentiment": 4, "/generate": 6}),
    ],
)
def test_workers_per_device_can_be_configured_per_route(monkeypatch, workers_cfg, expected_total_by_path):
    """Server must spawn len(devices) * workers_per_device workers per route."""
    monkeypatch.setattr("litserve.server.uvicorn", MagicMock())

    server = LitServer(
        [MultiRouteAPI(api_path="/sentiment"), MultiRouteAPI(api_path="/generate")],
        accelerator="cuda",
        devices=[0, 1],
        workers_per_device=workers_cfg,
    )

    spawned = []  # one (api_path, worker_id, device) tuple per fake process

    class RecordingProcess:
        """Stand-in for mp.Process that records worker creation instead of forking."""

        def __init__(self, target, args, name):
            # inference_worker args = (lit_api, device, worker_id, request_q, transport, ...)
            lit_api, device, worker_id = args[0], args[1], args[2]
            spawned.append((lit_api.api_path, worker_id, device))
            self.pid = 123
            self.name = name

        def start(self):
            pass

        def terminate(self):
            pass

        def join(self, timeout=None):
            pass

        def is_alive(self):
            return False

        def kill(self):
            pass

    class RecordingCtx:
        def Process(self, target, args, name):  # noqa: N802
            return RecordingProcess(target=target, args=args, name=name)

    monkeypatch.setattr("litserve.server.mp.get_context", lambda *_args, **_kwargs: RecordingCtx())

    # Stub out everything that would make server.run() start uvicorn or block forever.
    server.verify_worker_status = MagicMock()
    server._start_server = MagicMock(return_value={})
    server._perform_graceful_shutdown = MagicMock()
    server._start_worker_monitoring = MagicMock()
    server._transport = MagicMock()
    server._shutdown_event = MagicMock()
    server._shutdown_event.wait = MagicMock(return_value=None)  # don't block

    # Avoid spinning up a real multiprocessing manager for the queues.
    with patch("litserve.server.mp.Manager", return_value=MagicMock()):
        server.run(api_server_worker_type="process", generate_client_file=False)

    # Tally how many workers were created for each api_path.
    tally = {}
    for api_path, _worker_id, _device in spawned:
        tally[api_path] = tally.get(api_path, 0) + 1

    assert tally == expected_total_by_path
908+
909+
910+
@pytest.mark.skipif(sys.platform == "win32", reason="Test is only for Unix")
def test_workers_per_device_per_route_raises_on_unknown_route():
    """A workers_per_device mapping keyed by an unregistered route must fail fast."""
    apis = [MultiRouteAPI(api_path="/sentiment"), MultiRouteAPI(api_path="/generate")]

    bad_config = {"/sentiment": 2, "/unknown": 1}  # "/unknown" is not a registered route
    with pytest.raises(ValueError, match="workers_per_device.*unknown api_path"):
        LitServer(apis, accelerator="cuda", devices=[0, 1], workers_per_device=bad_config)

0 commit comments

Comments
 (0)