@@ -135,57 +135,131 @@ def _start_proxy(self, virtual: int, real: int, family: int) -> None:
135135
136136 self ._proxy_sockets .append (listener )
137137
138- def _proxy_loop ():
139- while not self ._proxy_stop .is_set ():
140- try :
141- client , _ = listener .accept ()
142- except socket .timeout :
143- continue
144- except OSError :
145- break
146- # Connect to the sandbox's real port
147- try :
148- backend = socket .socket (af , socket .SOCK_STREAM )
149- backend_addr = "::1" if af == socket .AF_INET6 else "127.0.0.1"
150- backend .connect ((backend_addr , real ))
151- except OSError :
152- client .close ()
153- continue
154- # Bidirectional forwarding in threads
155- t1 = threading .Thread (
156- target = _forward , args = (client , backend , self ._proxy_stop ),
157- daemon = True ,
158- )
159- t2 = threading .Thread (
160- target = _forward , args = (backend , client , self ._proxy_stop ),
161- daemon = True ,
162- )
163- t1 .start ()
164- t2 .start ()
165-
166- listener .close ()
167-
168- t = threading .Thread (target = _proxy_loop , daemon = True )
138+ t = threading .Thread (
139+ target = _proxy_event_loop ,
140+ args = (listener , real , af , self ._proxy_stop ),
141+ daemon = True ,
142+ )
169143 t .start ()
170144 self ._proxy_threads .append (t )
171145
172146
173- def _forward (src : socket .socket , dst : socket .socket ,
174- stop : threading .Event ) -> None :
175- """Forward data from src to dst until EOF or stop."""
147+ def _proxy_event_loop (listener : socket .socket , real_port : int ,
148+ af : int , stop : threading .Event ) -> None :
149+ """Single-thread event loop: accept connections, splice data.
150+
151+ Uses poll + splice so one thread handles all connections with
152+ zero-copy forwarding. No per-connection threads needed.
153+ """
154+ import select
155+
156+ poller = select .poll ()
157+ listener_fd = listener .fileno ()
158+ poller .register (listener_fd , select .POLLIN )
159+
160+ # Per-fd state: fd → (peer_fd, pipe_r, pipe_w)
161+ pipes : dict [int , tuple [int , int , int ]] = {}
162+ # Track socket objects to prevent GC
163+ sockets : dict [int , socket .socket ] = {}
164+
165+ def _add_pair (client : socket .socket , backend : socket .socket ) -> None :
166+ client .setblocking (False )
167+ backend .setblocking (False )
168+ client .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
169+ backend .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
170+ c_fd = client .fileno ()
171+ b_fd = backend .fileno ()
172+ c2b_r , c2b_w = os .pipe ()
173+ b2c_r , b2c_w = os .pipe ()
174+ pipes [c_fd ] = (b_fd , c2b_r , c2b_w )
175+ pipes [b_fd ] = (c_fd , b2c_r , b2c_w )
176+ sockets [c_fd ] = client
177+ sockets [b_fd ] = backend
178+ poller .register (c_fd , select .POLLIN )
179+ poller .register (b_fd , select .POLLIN )
180+
181+ def _remove_fd (fd : int ) -> None :
182+ if fd not in pipes :
183+ return
184+ peer_fd , pipe_r , pipe_w = pipes .pop (fd )
185+ os .close (pipe_r )
186+ os .close (pipe_w )
187+ try :
188+ poller .unregister (fd )
189+ except (KeyError , OSError ):
190+ pass
191+ s = sockets .pop (fd , None )
192+ if s :
193+ try :
194+ s .close ()
195+ except OSError :
196+ pass
197+ # Also remove peer
198+ if peer_fd in pipes :
199+ p_peer , p_r , p_w = pipes .pop (peer_fd )
200+ os .close (p_r )
201+ os .close (p_w )
202+ try :
203+ poller .unregister (peer_fd )
204+ except (KeyError , OSError ):
205+ pass
206+ ps = sockets .pop (peer_fd , None )
207+ if ps :
208+ try :
209+ ps .close ()
210+ except OSError :
211+ pass
212+
213+ backend_addr = "::1" if af == socket .AF_INET6 else "127.0.0.1"
214+ _SPLICE_F_NONBLOCK = 0x02
215+
176216 try :
177217 while not stop .is_set ():
178- data = src .recv (65536 )
179- if not data :
218+ try :
219+ events = poller .poll (500 )
220+ except OSError :
180221 break
181- dst .sendall (data )
182- except OSError :
183- pass
222+ for fd , event in events :
223+ if fd == listener_fd :
224+ # Accept new connection
225+ try :
226+ client , _ = listener .accept ()
227+ except OSError :
228+ continue
229+ try :
230+ backend = socket .socket (af , socket .SOCK_STREAM )
231+ backend .connect ((backend_addr , real_port ))
232+ except OSError :
233+ client .close ()
234+ continue
235+ _add_pair (client , backend )
236+ continue
237+
238+ if fd not in pipes :
239+ continue
240+
241+ if event & (select .POLLERR | select .POLLNVAL ):
242+ _remove_fd (fd )
243+ continue
244+
245+ if event & (select .POLLIN | select .POLLHUP ):
246+ peer_fd , pipe_r , pipe_w = pipes [fd ]
247+ try :
248+ n = os .splice (fd , pipe_w , 65536 ,
249+ flags = _SPLICE_F_NONBLOCK )
250+ if n == 0 :
251+ _remove_fd (fd )
252+ continue
253+ while n > 0 :
254+ n -= os .splice (pipe_r , peer_fd , n )
255+ except BlockingIOError :
256+ pass
257+ except OSError :
258+ _remove_fd (fd )
184259 finally :
185- try :
186- dst .shutdown (socket .SHUT_WR )
187- except OSError :
188- pass
260+ for fd in list (pipes ):
261+ _remove_fd (fd )
262+ listener .close ()
189263
190264
191265def get_port_map (proxy : bool = False ) -> PortMap :
0 commit comments