1- use super :: { BATCH_SIZE , CHUNK_SIZE , RecvPacketBatch , SendPacketBatch } ;
1+ use super :: { BATCH_SIZE , RecvPacketBatch , SendPacketBatch } ;
22use crate :: net:: Transport ;
33use bytes:: { Buf , BufMut , Bytes , BytesMut } ;
44use dashmap:: DashMap ;
55use std:: sync:: Arc ;
6- use std:: { io, net:: SocketAddr } ;
6+ use std:: {
7+ io,
8+ net:: { IpAddr , SocketAddr } ,
9+ time:: Duration ,
10+ } ;
711use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
812use tokio:: net:: { TcpListener , TcpStream } ;
13+ use tokio_util:: sync:: CancellationToken ;
14+
15+ const MAX_FRAME_SIZE : usize = 1500 ;
16+ const MAX_CONNECTIONS : usize = 10_000 ;
17+ const MAX_CONNS_PER_IP : usize = 20 ;
18+ const READ_TIMEOUT : Duration = Duration :: from_secs ( 30 ) ;
19+
20+ /// Shared state for connection management
21+ type ConnMap = Arc < DashMap < SocketAddr , ( async_channel:: Sender < Bytes > , CancellationToken ) > > ;
922
10- /// Creates a TCP Reader and Writer pair.
1123pub async fn bind (
1224 addr : SocketAddr ,
1325 external_addr : Option < SocketAddr > ,
1426) -> io:: Result < ( TcpTransportReader , TcpTransportWriter ) > {
1527 let listener = TcpListener :: bind ( addr) . await ?;
1628 let local_addr = external_addr. unwrap_or ( listener. local_addr ( ) ?) ;
1729
18- let ( packet_tx, packet_rx) = async_channel:: bounded ( BATCH_SIZE * 64 ) ;
19- let conns = Arc :: new ( DashMap :: new ( ) ) ;
30+ let ( packet_tx, packet_rx) = async_channel:: bounded ( BATCH_SIZE * 128 ) ;
31+ let conns: ConnMap = Arc :: new ( DashMap :: new ( ) ) ;
32+ let ip_counts = Arc :: new ( DashMap :: < IpAddr , usize > :: new ( ) ) ;
33+
2034 let readable_notifier = Arc :: new ( tokio:: sync:: Notify :: new ( ) ) ;
2135 let writable_notifier = Arc :: new ( tokio:: sync:: Notify :: new ( ) ) ;
36+ let conn_semaphore = Arc :: new ( tokio:: sync:: Semaphore :: new ( MAX_CONNECTIONS ) ) ;
2237
2338 let conns_clone = conns. clone ( ) ;
2439 let r_notify = readable_notifier. clone ( ) ;
2540 let w_notify = writable_notifier. clone ( ) ;
41+ let ip_counts_clone = ip_counts. clone ( ) ;
42+ let semaphore_clone = conn_semaphore. clone ( ) ;
2643
27- // Passive Listener Task (RFC 6544)
44+ // Passive Listener Task
2845 tokio:: spawn ( async move {
2946 while let Ok ( ( stream, peer_addr) ) = listener. accept ( ) . await {
47+ let ip = peer_addr. ip ( ) ;
48+
49+ // 1. Global Connection Limit
50+ let permit = match semaphore_clone. clone ( ) . try_acquire_owned ( ) {
51+ Ok ( p) => p,
52+ Err ( _) => {
53+ tracing:: warn!( %peer_addr, "Rejecting: Global limit reached" ) ;
54+ continue ;
55+ }
56+ } ;
57+
58+ // 2. Per-IP Limit
59+ {
60+ let mut count = ip_counts_clone. entry ( ip) . or_insert ( 0 ) ;
61+ if * count >= MAX_CONNS_PER_IP {
62+ tracing:: warn!( %peer_addr, "Rejecting: Per-IP limit reached" ) ;
63+ continue ;
64+ }
65+ * count += 1 ;
66+ }
67+
3068 let _ = stream. set_nodelay ( true ) ;
31- let _ = stream. set_linger ( None ) ;
3269
3370 handle_new_connection (
3471 stream,
3572 peer_addr,
3673 local_addr,
3774 packet_tx. clone ( ) ,
3875 conns_clone. clone ( ) ,
76+ ip_counts_clone. clone ( ) ,
3977 r_notify. clone ( ) ,
4078 w_notify. clone ( ) ,
79+ permit,
4180 ) ;
4281 w_notify. notify_waiters ( ) ;
4382 }
@@ -47,6 +86,7 @@ pub async fn bind(
4786 local_addr,
4887 packet_rx,
4988 readable_notifier,
89+ conns : conns. clone ( ) , // Reader now has access to conns
5090 } ;
5191
5292 let writer = TcpTransportWriter {
@@ -62,14 +102,25 @@ pub struct TcpTransportReader {
62102 packet_rx : async_channel:: Receiver < RecvPacketBatch > ,
63103 readable_notifier : Arc < tokio:: sync:: Notify > ,
64104 local_addr : SocketAddr ,
105+ /// Shared connection map to allow forced disconnects
106+ conns : ConnMap ,
65107}
66108
67109impl TcpTransportReader {
110+ /// Closes the connection for a specific peer.
111+ /// Useful when high-level demux or auth fails.
112+ pub fn close_peer ( & self , peer_addr : & SocketAddr ) {
113+ if let Some ( ( _, ( tx, cancel) ) ) = self . conns . remove ( peer_addr) {
114+ cancel. cancel ( ) ; // Stops the background reader task
115+ tx. close ( ) ; // Stops the background writer task
116+ tracing:: info!( %peer_addr, "Reader forced connection close" ) ;
117+ }
118+ }
119+
68120 pub fn local_addr ( & self ) -> SocketAddr {
69121 self . local_addr
70122 }
71123
72- /// Waits until the socket is readable.
73124 pub async fn readable ( & self ) -> io:: Result < ( ) > {
74125 loop {
75126 if !self . packet_rx . is_empty ( ) {
@@ -83,8 +134,6 @@ impl TcpTransportReader {
83134 }
84135 }
85136
86- /// Pulls packets from the internal channel into the provided vector.
87- /// Takes `&mut self` as requested for the Reader API.
88137 #[ inline]
89138 pub fn try_recv_batch ( & mut self , out : & mut Vec < RecvPacketBatch > ) -> io:: Result < ( ) > {
90139 let mut count = 0 ;
@@ -105,7 +154,7 @@ impl TcpTransportReader {
105154#[ derive( Clone ) ]
106155pub struct TcpTransportWriter {
107156 local_addr : SocketAddr ,
108- conns : Arc < DashMap < SocketAddr , async_channel :: Sender < Bytes > > > ,
157+ conns : ConnMap ,
109158 writable_notifier : Arc < tokio:: sync:: Notify > ,
110159}
111160
@@ -126,7 +175,7 @@ impl TcpTransportWriter {
126175 }
127176 let mut any_available = false ;
128177 for c in self . conns . iter ( ) {
129- if !c. is_full ( ) {
178+ if !c. value ( ) . 0 . is_full ( ) {
130179 any_available = true ;
131180 break ;
132181 }
@@ -135,99 +184,112 @@ impl TcpTransportWriter {
135184 return Ok ( ( ) ) ;
136185 }
137186 let wait = self . writable_notifier . notified ( ) ;
138-
139- // Re-check logic to prevent race conditions
140- let mut any_available = false ;
141- for c in self . conns . iter ( ) {
142- if !c. is_full ( ) {
143- any_available = true ;
144- break ;
145- }
146- }
147- if any_available {
148- return Ok ( ( ) ) ;
149- }
150187 wait. await ;
151188 }
152189 }
153190
154191 #[ inline]
155192 pub fn try_send_batch ( & self , batch : & SendPacketBatch ) -> io:: Result < bool > {
156- let Some ( peer_tx) = self . conns . get ( & batch. dst ) else {
157- // If the peer is gone, we drop the packet (consistent with UDP behavior)
193+ let Some ( peer_entry) = self . conns . get ( & batch. dst ) else {
158194 return Ok ( true ) ;
159195 } ;
196+ let ( peer_tx, _) = peer_entry. value ( ) ;
160197
161- let required_slots = ( batch. buf . len ( ) + batch . segment_size - 1 ) / batch. segment_size ;
198+ let required_slots = batch. buf . len ( ) . div_ceil ( batch. segment_size ) ;
162199 if peer_tx. capacity ( ) . unwrap ( ) - peer_tx. len ( ) < required_slots {
163200 return Ok ( false ) ;
164201 }
165202
166203 let mut offset = 0 ;
167- let total_len = batch. buf . len ( ) ;
168- while offset < total_len {
169- let end = std:: cmp:: min ( offset + batch. segment_size , total_len) ;
170- let segment = & batch. buf [ offset..end] ;
171- let _ = peer_tx. try_send ( Bytes :: copy_from_slice ( segment) ) ;
204+ while offset < batch. buf . len ( ) {
205+ let end = std:: cmp:: min ( offset + batch. segment_size , batch. buf . len ( ) ) ;
206+ let _ = peer_tx. try_send ( Bytes :: copy_from_slice ( & batch. buf [ offset..end] ) ) ;
172207 offset = end;
173208 }
174209 Ok ( true )
175210 }
176211}
177212
178- /// Background connection handler
179213fn handle_new_connection (
180214 stream : TcpStream ,
181215 peer_addr : SocketAddr ,
182216 local_addr : SocketAddr ,
183217 packet_tx : async_channel:: Sender < RecvPacketBatch > ,
184- conns : Arc < DashMap < SocketAddr , async_channel:: Sender < Bytes > > > ,
218+ conns : ConnMap ,
219+ ip_counts : Arc < DashMap < IpAddr , usize > > ,
185220 r_notify : Arc < tokio:: sync:: Notify > ,
186221 w_notify : Arc < tokio:: sync:: Notify > ,
222+ permit : tokio:: sync:: OwnedSemaphorePermit ,
187223) {
188- let ( send_tx, send_rx) = async_channel:: bounded :: < Bytes > ( 8192 ) ;
189- conns . insert ( peer_addr , send_tx ) ;
224+ let ( send_tx, send_rx) = async_channel:: bounded :: < Bytes > ( 1024 ) ;
225+ let cancel_token = CancellationToken :: new ( ) ;
190226
191- let ( mut reader, writer) = stream. into_split ( ) ;
227+ conns. insert ( peer_addr, ( send_tx, cancel_token. clone ( ) ) ) ;
228+
229+ let ( mut tcp_reader, tcp_writer) = stream. into_split ( ) ;
230+ let peer_ip = peer_addr. ip ( ) ;
192231
193232 // Task: Receiver (RFC 4571 Un-framing)
233+ let r_cancel = cancel_token. clone ( ) ;
234+ let r_conns = conns. clone ( ) ;
235+ let r_ip_counts = ip_counts. clone ( ) ;
194236 tokio:: spawn ( async move {
195- let mut recv_buf = BytesMut :: with_capacity ( CHUNK_SIZE * 4 ) ;
196- while let Ok ( n) = reader. read_buf ( & mut recv_buf) . await {
197- if n == 0 {
198- break ;
199- }
200- let mut added = false ;
201- while recv_buf. len ( ) >= 2 {
202- let len = u16:: from_be_bytes ( [ recv_buf[ 0 ] , recv_buf[ 1 ] ] ) as usize ;
203- if recv_buf. len ( ) < 2 + len {
204- break ;
237+ // Guard to release semaphore and cleanup DashMap on task exit
238+ let _guard = ( permit, r_cancel) ;
239+ let mut recv_buf = BytesMut :: with_capacity ( MAX_FRAME_SIZE + 2 ) ;
240+
241+ loop {
242+ tokio:: select! {
243+ _ = _guard. 1 . cancelled( ) => break ,
244+ res = tokio:: time:: timeout( READ_TIMEOUT , tcp_reader. read_buf( & mut recv_buf) ) => {
245+ let n = match res {
246+ Ok ( Ok ( n) ) if n > 0 => n,
247+ _ => break , // Timeout, Error, or EOF
248+ } ;
249+
250+ while recv_buf. len( ) >= 2 {
251+ let len = u16 :: from_be_bytes( [ recv_buf[ 0 ] , recv_buf[ 1 ] ] ) as usize ;
252+
253+ if len > MAX_FRAME_SIZE || len == 0 {
254+ tracing:: warn!( %peer_addr, len, "Invalid TCP frame size, dropping connection" ) ;
255+ return ;
256+ }
257+
258+ if recv_buf. len( ) < 2 + len { break ; }
259+
260+ recv_buf. advance( 2 ) ;
261+ let data = recv_buf. split_to( len) . freeze( ) ;
262+
263+ // Use try_send to prevent reader task from blocking if SFU logic lags
264+ if let Err ( _) = packet_tx. try_send( RecvPacketBatch {
265+ src: peer_addr,
266+ dst: local_addr,
267+ buf: data,
268+ stride: len,
269+ len,
270+ transport: Transport :: Tcp ,
271+ } ) {
272+ tracing:: debug!( "TCP packet dropped: Global queue full" ) ;
273+ } else {
274+ r_notify. notify_waiters( ) ;
275+ }
276+ }
205277 }
206- recv_buf. advance ( 2 ) ;
207- let data = recv_buf. split_to ( len) . freeze ( ) ;
208- let _ = packet_tx
209- . send ( RecvPacketBatch {
210- src : peer_addr,
211- dst : local_addr,
212- buf : data,
213- stride : len,
214- len,
215- transport : Transport :: Tcp ,
216- } )
217- . await ;
218- added = true ;
219- }
220- if added {
221- r_notify. notify_waiters ( ) ;
222278 }
223279 }
224- conns. remove ( & peer_addr) ;
280+
281+ // Final cleanup
282+ r_conns. remove ( & peer_addr) ;
283+ if let Some ( mut count) = r_ip_counts. get_mut ( & peer_ip) {
284+ * count = count. saturating_sub ( 1 ) ;
285+ }
225286 } ) ;
226287
227- // Task: Sender (RFC 4571 Framing + Syscall Batching)
288+ // Task: Sender
228289 tokio:: spawn ( async move {
229- let mut write_buf = Vec :: with_capacity ( CHUNK_SIZE * 2 ) ;
230- let mut writer = writer;
290+ let mut write_buf = Vec :: with_capacity ( MAX_FRAME_SIZE + 2 ) ;
291+ let mut writer = tcp_writer;
292+
231293 while let Ok ( first) = send_rx. recv ( ) . await {
232294 write_buf. clear ( ) ;
233295 write_buf. put_u16 ( first. len ( ) as u16 ) ;
@@ -236,7 +298,7 @@ fn handle_new_connection(
236298 while let Ok ( next) = send_rx. try_recv ( ) {
237299 write_buf. put_u16 ( next. len ( ) as u16 ) ;
238300 write_buf. put_slice ( & next) ;
239- if write_buf. len ( ) > 65535 {
301+ if write_buf. len ( ) > 16384 {
240302 break ;
241303 }
242304 }
@@ -248,6 +310,7 @@ fn handle_new_connection(
248310 }
249311 } ) ;
250312}
313+
251314#[ cfg( test) ]
252315mod tests {
253316 use super :: * ;
0 commit comments