Skip to content

Commit cb804d7

Browse files
authored
feat: optimize zmq receive (#131)
* updates * ci fix
1 parent 0ceafb4 commit cb804d7

File tree

2 files changed

+44
-32
lines changed

2 files changed

+44
-32
lines changed

src/inference_endpoint/async_utils/transport/zmq/transport.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ class _ZMQSocketConfig:
9595
recv_buffer_size: int = 4 * 1024 * 1024 # 4MB
9696
send_buffer_size: int = 4 * 1024 * 1024 # 4MB
9797

98+
def apply_recv(self, sock: zmq.Socket) -> None:
99+
"""Apply receiver socket options."""
100+
sock.setsockopt(zmq.LINGER, self.linger)
101+
sock.setsockopt(zmq.RCVHWM, self.high_water_mark)
102+
sock.setsockopt(zmq.RCVBUF, self.recv_buffer_size)
103+
104+
def apply_send(self, sock: zmq.Socket) -> None:
105+
"""Apply sender socket options."""
106+
sock.setsockopt(zmq.LINGER, self.linger)
107+
sock.setsockopt(zmq.SNDHWM, self.high_water_mark)
108+
sock.setsockopt(zmq.SNDBUF, self.send_buffer_size)
109+
sock.setsockopt(zmq.IMMEDIATE, self.immediate)
110+
98111

99112
class _ZmqReceiverTransport(ReceiverTransport):
100113
"""
@@ -122,6 +135,8 @@ class _ZmqReceiverTransport(ReceiverTransport):
122135
"_waiter",
123136
"_closing",
124137
"_soon_call",
138+
"_recv_buf",
139+
"_recv_view",
125140
)
126141

127142
def __init__(
@@ -140,6 +155,13 @@ def __init__(
140155
self._closing = False
141156
self._soon_call: asyncio.Handle | None = None
142157

158+
# NOTE(vir):
159+
# zmq recv_into with Pre-allocated buffer.
160+
# msgspec can decode in-place, avoiding per-message bytes allocation.
161+
recv_buf_size = sock.getsockopt(zmq.RCVBUF)
162+
self._recv_buf = bytearray(recv_buf_size)
163+
self._recv_view = memoryview(self._recv_buf)
164+
143165
self._loop.add_reader(self._fd, self._on_readable)
144166

145167
def _on_readable(self) -> None:
@@ -170,18 +192,26 @@ def _on_readable(self) -> None:
170192
return
171193

172194
count = 0
195+
recv_buf = self._recv_buf
196+
recv_view = self._recv_view
197+
buf_len = len(recv_buf)
173198
try:
174199
while True:
175-
data = self._sock.recv(zmq.NOBLOCK, copy=False, track=False)
176-
self._deque.append(self._decoder.decode(data))
200+
nbytes = self._sock.recv_into(recv_buf, flags=zmq.NOBLOCK)
201+
if nbytes > buf_len:
202+
raise RuntimeError(
203+
f"ZMQ message truncated ({nbytes} > {buf_len} bytes). "
204+
f"Increase recv_buffer_size in _ZMQSocketConfig."
205+
)
206+
self._deque.append(self._decoder.decode(recv_view[:nbytes]))
177207
count += 1
178208
except zmq.Again:
179209
# Normal: no more messages
180210
pass
181211
except zmq.ZMQError as e:
182212
if e.errno not in (errno.EAGAIN, errno.EINTR, errno.ENOTSOCK):
183213
logger.error(f"ZMQ recv error: {e}")
184-
except Exception as e:
214+
except msgspec.DecodeError as e:
185215
logger.error(f"Decode error: {e}")
186216

187217
# Wake waiter once after draining (not per message)
@@ -402,9 +432,7 @@ def _create_receiver(
402432
Configured receiver transport.
403433
"""
404434
sock = zmq_context.socket(zmq.PULL)
405-
sock.setsockopt(zmq.LINGER, config.linger)
406-
sock.setsockopt(zmq.RCVHWM, config.high_water_mark)
407-
sock.setsockopt(zmq.RCVBUF, config.recv_buffer_size)
435+
config.apply_recv(sock)
408436

409437
if bind:
410438
sock.bind(address)
@@ -429,10 +457,7 @@ def _create_sender(
429457
) -> _ZmqSenderTransport:
430458
"""Create a ZMQ sender transport."""
431459
sock = zmq_context.socket(zmq.PUSH)
432-
sock.setsockopt(zmq.LINGER, config.linger)
433-
sock.setsockopt(zmq.SNDHWM, config.high_water_mark)
434-
sock.setsockopt(zmq.SNDBUF, config.send_buffer_size)
435-
sock.setsockopt(zmq.IMMEDIATE, config.immediate)
460+
config.apply_send(sock)
436461

437462
if bind:
438463
sock.bind(address)

tests/performance/async_utils/transport/test_zmq.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
import msgspec
2929
import pytest
3030
import uvloop
31-
import zmq
3231
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
3332
from inference_endpoint.async_utils.transport.zmq.transport import (
34-
_ZmqReceiverTransport,
35-
_ZmqSenderTransport,
33+
_create_receiver,
34+
_create_sender,
35+
_ZMQSocketConfig,
3636
)
3737
from inference_endpoint.core.types import Query, QueryResult, StreamChunk
3838

@@ -44,7 +44,6 @@
4444
TEST_DURATION_SECONDS = 5.0
4545

4646
WARMUP_MESSAGES = 100
47-
BUFFER_SIZE = 10 * 1024 * 1024
4847

4948
# Payload sizes in chars
5049
PAYLOAD_SIZES_CHARS = [32, 128, 512, 1024, 4096, 16384, 32768]
@@ -140,27 +139,15 @@ async def benchmark(
140139

141140
loop = asyncio.get_running_loop()
142141

143-
with ManagedZMQContext.scoped(io_threads=4) as zmq_ctx:
142+
config = _ZMQSocketConfig()
143+
144+
with ManagedZMQContext.scoped(io_threads=config.io_threads) as zmq_ctx:
144145
with tempfile.TemporaryDirectory(prefix="zmq_") as tmp:
145146
addr = f"ipc://{tmp}/bench"
146147

147-
# Sender (main proc perspective)
148-
push = zmq_ctx.socket(zmq.PUSH)
149-
push.setsockopt(zmq.LINGER, -1)
150-
push.setsockopt(zmq.SNDHWM, 0)
151-
push.setsockopt(zmq.SNDBUF, BUFFER_SIZE)
152-
push.setsockopt(zmq.IMMEDIATE, 1)
153-
push.bind(addr)
154-
sender = _ZmqSenderTransport(loop, push, msgspec.msgpack.Encoder())
155-
156-
# Receiver (worker proc perspective)
157-
pull = zmq_ctx.socket(zmq.PULL)
158-
pull.setsockopt(zmq.LINGER, -1)
159-
pull.setsockopt(zmq.RCVHWM, 0)
160-
pull.setsockopt(zmq.RCVBUF, BUFFER_SIZE)
161-
pull.connect(addr)
162-
receiver = _ZmqReceiverTransport(
163-
loop, pull, msgspec.msgpack.Decoder(type=msg_type)
148+
sender = _create_sender(loop, addr, zmq_ctx, config, bind=True)
149+
receiver = _create_receiver(
150+
loop, addr, zmq_ctx, config, msg_type, bind=False
164151
)
165152

166153
await asyncio.sleep(0.01)

0 commit comments

Comments
 (0)