Skip to content

Commit d66f877

Browse files
committed
Replace port remap proxy with single-thread poll+splice event loop
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 31d6932 commit d66f877

File tree

1 file changed

+117
-43
lines changed

1 file changed

+117
-43
lines changed

src/sandlock/_port_remap.py

Lines changed: 117 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -135,57 +135,131 @@ def _start_proxy(self, virtual: int, real: int, family: int) -> None:
135135

136136
self._proxy_sockets.append(listener)
137137

138-
def _proxy_loop():
139-
while not self._proxy_stop.is_set():
140-
try:
141-
client, _ = listener.accept()
142-
except socket.timeout:
143-
continue
144-
except OSError:
145-
break
146-
# Connect to the sandbox's real port
147-
try:
148-
backend = socket.socket(af, socket.SOCK_STREAM)
149-
backend_addr = "::1" if af == socket.AF_INET6 else "127.0.0.1"
150-
backend.connect((backend_addr, real))
151-
except OSError:
152-
client.close()
153-
continue
154-
# Bidirectional forwarding in threads
155-
t1 = threading.Thread(
156-
target=_forward, args=(client, backend, self._proxy_stop),
157-
daemon=True,
158-
)
159-
t2 = threading.Thread(
160-
target=_forward, args=(backend, client, self._proxy_stop),
161-
daemon=True,
162-
)
163-
t1.start()
164-
t2.start()
165-
166-
listener.close()
167-
168-
t = threading.Thread(target=_proxy_loop, daemon=True)
138+
t = threading.Thread(
139+
target=_proxy_event_loop,
140+
args=(listener, real, af, self._proxy_stop),
141+
daemon=True,
142+
)
169143
t.start()
170144
self._proxy_threads.append(t)
171145

172146

173-
def _forward(src: socket.socket, dst: socket.socket,
174-
stop: threading.Event) -> None:
175-
"""Forward data from src to dst until EOF or stop."""
147+
def _proxy_event_loop(listener: socket.socket, real_port: int,
148+
af: int, stop: threading.Event) -> None:
149+
"""Single-thread event loop: accept connections, splice data.
150+
151+
Uses poll + splice so one thread handles all connections with
152+
zero-copy forwarding. No per-connection threads needed.
153+
"""
154+
import select
155+
156+
poller = select.poll()
157+
listener_fd = listener.fileno()
158+
poller.register(listener_fd, select.POLLIN)
159+
160+
# Per-fd state: fd → (peer_fd, pipe_r, pipe_w)
161+
pipes: dict[int, tuple[int, int, int]] = {}
162+
# Track socket objects to prevent GC
163+
sockets: dict[int, socket.socket] = {}
164+
165+
def _add_pair(client: socket.socket, backend: socket.socket) -> None:
166+
client.setblocking(False)
167+
backend.setblocking(False)
168+
client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
169+
backend.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
170+
c_fd = client.fileno()
171+
b_fd = backend.fileno()
172+
c2b_r, c2b_w = os.pipe()
173+
b2c_r, b2c_w = os.pipe()
174+
pipes[c_fd] = (b_fd, c2b_r, c2b_w)
175+
pipes[b_fd] = (c_fd, b2c_r, b2c_w)
176+
sockets[c_fd] = client
177+
sockets[b_fd] = backend
178+
poller.register(c_fd, select.POLLIN)
179+
poller.register(b_fd, select.POLLIN)
180+
181+
def _remove_fd(fd: int) -> None:
182+
if fd not in pipes:
183+
return
184+
peer_fd, pipe_r, pipe_w = pipes.pop(fd)
185+
os.close(pipe_r)
186+
os.close(pipe_w)
187+
try:
188+
poller.unregister(fd)
189+
except (KeyError, OSError):
190+
pass
191+
s = sockets.pop(fd, None)
192+
if s:
193+
try:
194+
s.close()
195+
except OSError:
196+
pass
197+
# Also remove peer
198+
if peer_fd in pipes:
199+
p_peer, p_r, p_w = pipes.pop(peer_fd)
200+
os.close(p_r)
201+
os.close(p_w)
202+
try:
203+
poller.unregister(peer_fd)
204+
except (KeyError, OSError):
205+
pass
206+
ps = sockets.pop(peer_fd, None)
207+
if ps:
208+
try:
209+
ps.close()
210+
except OSError:
211+
pass
212+
213+
backend_addr = "::1" if af == socket.AF_INET6 else "127.0.0.1"
214+
_SPLICE_F_NONBLOCK = 0x02
215+
176216
try:
177217
while not stop.is_set():
178-
data = src.recv(65536)
179-
if not data:
218+
try:
219+
events = poller.poll(500)
220+
except OSError:
180221
break
181-
dst.sendall(data)
182-
except OSError:
183-
pass
222+
for fd, event in events:
223+
if fd == listener_fd:
224+
# Accept new connection
225+
try:
226+
client, _ = listener.accept()
227+
except OSError:
228+
continue
229+
try:
230+
backend = socket.socket(af, socket.SOCK_STREAM)
231+
backend.connect((backend_addr, real_port))
232+
except OSError:
233+
client.close()
234+
continue
235+
_add_pair(client, backend)
236+
continue
237+
238+
if fd not in pipes:
239+
continue
240+
241+
if event & (select.POLLERR | select.POLLNVAL):
242+
_remove_fd(fd)
243+
continue
244+
245+
if event & (select.POLLIN | select.POLLHUP):
246+
peer_fd, pipe_r, pipe_w = pipes[fd]
247+
try:
248+
n = os.splice(fd, pipe_w, 65536,
249+
flags=_SPLICE_F_NONBLOCK)
250+
if n == 0:
251+
_remove_fd(fd)
252+
continue
253+
while n > 0:
254+
n -= os.splice(pipe_r, peer_fd, n)
255+
except BlockingIOError:
256+
pass
257+
except OSError:
258+
_remove_fd(fd)
184259
finally:
185-
try:
186-
dst.shutdown(socket.SHUT_WR)
187-
except OSError:
188-
pass
260+
for fd in list(pipes):
261+
_remove_fd(fd)
262+
listener.close()
189263

190264

191265
def get_port_map(proxy: bool = False) -> PortMap:

0 commit comments

Comments
 (0)