Skip to content

Commit 5aef366

Browse files
authored
enable multiple workers for ZMQ (#411)
* minor refactor * fix test * add zmq * remove broker from server_copy * enable multi workers * test e2e * add tests * add tests * add tests * update docs * enable uvicorn workers * enable zmq for parity test * apply feedback
1 parent d1fefd7 commit 5aef366

17 files changed

+478
-311
lines changed

src/litserve/loops/base.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,20 @@
1414
import asyncio
1515
import inspect
1616
import logging
17+
import signal
1718
import sys
1819
import time
1920
from abc import ABC
2021
from queue import Empty, Queue
2122
from typing import Any, Dict, List, Optional, Tuple, Union
2223

23-
import zmq
2424
from starlette.formparsers import MultiPartParser
2525

2626
from litserve import LitAPI
2727
from litserve.callbacks import CallbackRunner
2828
from litserve.specs.base import LitSpec
2929
from litserve.utils import LitAPIStatus
30+
from litserve.zmq_queue import Producer
3031

3132
logger = logging.getLogger(__name__)
3233
# FastAPI writes form files to disk over 1MB by default, which prevents serialization by multiprocessing
@@ -129,9 +130,6 @@ def run(
129130
130131
"""
131132

132-
zmq_ctx: Optional[zmq.Context] = None
133-
socket: Optional[zmq.Socket] = None
134-
135133
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
136134
pass
137135

@@ -159,9 +157,7 @@ def __call__(
159157
stream: bool,
160158
workers_setup_status: Dict[int, str],
161159
callback_runner: CallbackRunner,
162-
socket: Optional[zmq.Socket],
163160
):
164-
self.socket = socket
165161
if asyncio.iscoroutinefunction(self.run):
166162
event_loop = asyncio.new_event_loop()
167163

@@ -226,7 +222,9 @@ def run(
226222

227223
class LitLoop(_BaseLoop):
228224
def __init__(self):
225+
self.producer: Optional[Producer] = None
229226
self._context = {}
227+
self._setup_signal_handlers()
230228

231229
def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
232230
batches, timed_out_uids = collate_requests(
@@ -250,23 +248,29 @@ def populate_context(self, lit_spec: LitSpec, request: Any):
250248
def put_response(
251249
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
252250
) -> None:
253-
if self.socket:
254-
self.socket.send_pyobj((uid, (response_data, status)))
251+
if self.producer:
252+
self.producer.put((uid, (response_data, status)), consumer_id=response_queue_id)
255253
else:
256254
response_queues[response_queue_id].put((uid, (response_data, status)), block=False)
257255

258256
def put_error_response(
259257
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
260258
) -> None:
261-
if self.socket:
262-
self.socket.send_pyobj((uid, (error, LitAPIStatus.ERROR)))
263-
else:
264-
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)), block=False)
259+
self.put_response(response_queues, response_queue_id, uid, error, LitAPIStatus.ERROR)
265260

266261
def __del__(self):
267-
if self.socket:
268-
self.socket.close(linger=0)
269-
self.zmq_ctx.term()
262+
if self.producer:
263+
self.producer.close()
264+
265+
def _setup_signal_handlers(self):
266+
def cleanup_handler(signum=None, frame=None):
267+
logging.debug("Worker process received shutdown signal")
268+
if self.producer:
269+
self.producer.close()
270+
sys.exit(0)
271+
272+
signal.signal(signal.SIGINT, cleanup_handler)
273+
signal.signal(signal.SIGTERM, cleanup_handler)
270274

271275

272276
class DefaultLoop(LitLoop):

src/litserve/loops/loops.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,14 @@
1515
from queue import Queue
1616
from typing import Dict, List, Optional, Union
1717

18-
import zmq
19-
import zmq.asyncio
20-
2118
from litserve import LitAPI
2219
from litserve.callbacks import CallbackRunner, EventTypes
2320
from litserve.loops.base import _BaseLoop
2421
from litserve.loops.simple_loops import BatchedLoop, SingleLoop
2522
from litserve.loops.streaming_loops import BatchedStreamingLoop, StreamingLoop
2623
from litserve.specs.base import LitSpec
2724
from litserve.utils import WorkerSetupStatus
25+
from litserve.zmq_queue import Producer
2826

2927
logger = logging.getLogger(__name__)
3028

@@ -78,14 +76,10 @@ def inference_worker(
7876
if loop == "auto":
7977
loop = get_default_loop(stream, max_batch_size)
8078

81-
socket = None
8279
if use_zmq:
83-
ctx = zmq.Context()
84-
socket = ctx.socket(zmq.PUB)
85-
logger.debug(f"Inference worker binding to {zmq_addr}")
86-
socket.bind(zmq_addr)
87-
loop.socket = socket
88-
loop.zmq_context = ctx
80+
producer = Producer(address=zmq_addr)
81+
producer.wait_for_subscribers(timeout=5)
82+
loop.producer = producer
8983

9084
loop(
9185
lit_api,
@@ -99,8 +93,4 @@ def inference_worker(
9993
stream,
10094
workers_setup_status,
10195
callback_runner,
102-
socket,
10396
)
104-
if use_zmq:
105-
socket.close()
106-
loop.zmq_context.term()

src/litserve/loops/simple_loops.py

Lines changed: 14 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from queue import Empty, Queue
1717
from typing import Dict, List, Optional
1818

19-
import zmq
2019
from fastapi import HTTPException
2120

2221
from litserve import LitAPI
@@ -28,116 +27,14 @@
2827
logger = logging.getLogger(__name__)
2928

3029

31-
def run_batched_loop(
32-
lit_api: LitAPI,
33-
lit_spec: LitSpec,
34-
request_queue: Queue,
35-
response_queues: List[Queue],
36-
max_batch_size: int,
37-
batch_timeout: float,
38-
callback_runner: CallbackRunner,
39-
socket: Optional[zmq.Socket],
40-
):
41-
while True:
42-
batches, timed_out_uids = collate_requests(
43-
lit_api,
44-
request_queue,
45-
max_batch_size,
46-
batch_timeout,
47-
)
48-
49-
for response_queue_id, uid in timed_out_uids:
50-
logger.error(
51-
f"Request {uid} was waiting in the queue for too long ({lit_api.request_timeout} seconds) and "
52-
"has been timed out. "
53-
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
54-
)
55-
if socket:
56-
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
57-
else:
58-
response_queues[response_queue_id].put((
59-
uid,
60-
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
61-
))
62-
63-
if not batches:
64-
continue
65-
logger.debug(f"{len(batches)} batched requests received")
66-
response_queue_ids, uids, inputs = zip(*batches)
67-
num_inputs = len(inputs)
68-
try:
69-
contexts = [{}] * num_inputs
70-
if hasattr(lit_spec, "populate_context"):
71-
for input, context in zip(inputs, contexts):
72-
lit_spec.populate_context(context, input)
73-
74-
callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
75-
x = [
76-
_inject_context(
77-
context,
78-
lit_api.decode_request,
79-
input,
80-
)
81-
for input, context in zip(inputs, contexts)
82-
]
83-
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
84-
85-
x = lit_api.batch(x)
86-
87-
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
88-
y = _inject_context(contexts, lit_api.predict, x)
89-
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
90-
91-
outputs = lit_api.unbatch(y)
92-
93-
if len(outputs) != num_inputs:
94-
logger.error(
95-
"LitAPI.predict/unbatch returned {len(outputs)} outputs, but expected {num_inputs}. "
96-
"Please check the predict/unbatch method of the LitAPI implementation."
97-
)
98-
raise HTTPException(500, "Batch size mismatch")
99-
100-
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
101-
y_enc_list = []
102-
for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
103-
y_enc = _inject_context(context, lit_api.encode_response, y)
104-
y_enc_list.append((response_queue_id, uid, y_enc))
105-
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
106-
107-
for response_queue_id, uid, y_enc in y_enc_list:
108-
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
109-
110-
except HTTPException as e:
111-
for response_queue_id, uid in zip(response_queue_ids, uids):
112-
if socket:
113-
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
114-
else:
115-
response_queues[response_queue_id].put((
116-
uid,
117-
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
118-
))
119-
120-
except Exception as e:
121-
logger.exception(
122-
"LitAPI ran into an error while processing the batched request.\n"
123-
"Please check the error trace for more details."
124-
)
125-
for response_queue_id, uid in zip(response_queue_ids, uids):
126-
if socket:
127-
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
128-
else:
129-
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
130-
131-
13230
class SingleLoop(DefaultLoop):
13331
def run_single_loop(
13432
self,
13533
lit_api: LitAPI,
136-
lit_spec: LitSpec,
34+
lit_spec: Optional[LitSpec],
13735
request_queue: Queue,
13836
response_queues: List[Queue],
13937
callback_runner: CallbackRunner,
140-
socket: Optional[zmq.Socket],
14138
):
14239
while True:
14340
try:
@@ -233,9 +130,8 @@ def __call__(
233130
stream: bool,
234131
workers_setup_status: Dict[int, str],
235132
callback_runner: CallbackRunner,
236-
socket: Optional[zmq.Socket],
237133
):
238-
self.run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner, socket)
134+
self.run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
239135

240136

241137
class BatchedLoop(DefaultLoop):
@@ -248,7 +144,6 @@ def run_batched_loop(
248144
max_batch_size: int,
249145
batch_timeout: float,
250146
callback_runner: CallbackRunner,
251-
socket: Optional[zmq.Socket],
252147
):
253148
while True:
254149
batches, timed_out_uids = collate_requests(
@@ -264,13 +159,9 @@ def run_batched_loop(
264159
"has been timed out. "
265160
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
266161
)
267-
if socket:
268-
socket.send_pyobj((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))
269-
else:
270-
response_queues[response_queue_id].put((
271-
uid,
272-
(HTTPException(504, "Request timed out"), LitAPIStatus.ERROR),
273-
))
162+
self.put_response(
163+
response_queues, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
164+
)
274165

275166
if not batches:
276167
continue
@@ -317,28 +208,25 @@ def run_batched_loop(
317208
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
318209

319210
for response_queue_id, uid, y_enc in y_enc_list:
320-
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
211+
self.put_response(response_queues, response_queue_id, uid, y_enc, LitAPIStatus.OK)
321212

322213
except HTTPException as e:
323214
for response_queue_id, uid in zip(response_queue_ids, uids):
324-
if socket:
325-
socket.send_pyobj((uid, (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR)))
326-
else:
327-
response_queues[response_queue_id].put((
328-
uid,
329-
(PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
330-
))
215+
self.put_response(
216+
response_queues,
217+
response_queue_id,
218+
uid,
219+
PickleableHTTPException.from_exception(e),
220+
LitAPIStatus.ERROR,
221+
)
331222

332223
except Exception as e:
333224
logger.exception(
334225
"LitAPI ran into an error while processing the batched request.\n"
335226
"Please check the error trace for more details."
336227
)
337228
for response_queue_id, uid in zip(response_queue_ids, uids):
338-
if socket:
339-
socket.send_pyobj((uid, (e, LitAPIStatus.ERROR)))
340-
else:
341-
response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
229+
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
342230

343231
def __call__(
344232
self,
@@ -353,7 +241,6 @@ def __call__(
353241
stream: bool,
354242
workers_setup_status: Dict[int, str],
355243
callback_runner: CallbackRunner,
356-
socket: Optional[zmq.Socket],
357244
):
358245
self.run_batched_loop(
359246
lit_api,
@@ -363,5 +250,4 @@ def __call__(
363250
max_batch_size,
364251
batch_timeout,
365252
callback_runner,
366-
socket,
367253
)

0 commit comments

Comments
 (0)