55use std:: { io, net:: SocketAddr } ;
66
77use bytes:: Bytes ;
8- use quinn_udp:: RecvMeta ;
9-
10- use crate :: net:: { tcp:: TcpTransport , udp:: UdpTransport } ;
118
129pub const BATCH_SIZE : usize = quinn_udp:: BATCH_SIZE ;
1310// Fit allocator page size and Linux GRO limit
@@ -84,49 +81,6 @@ impl<'a> Iterator for RecvPacketBatchIter<'a> {
8481 }
8582}
8683
87- pub struct RecvPacketBatcher {
88- meta : [ RecvMeta ; BATCH_SIZE ] ,
89- batch_buffer : Vec < u8 > ,
90- }
91-
92- impl RecvPacketBatcher {
93- pub fn new ( ) -> Self {
94- Self {
95- meta : [ RecvMeta :: default ( ) ; BATCH_SIZE ] ,
96- batch_buffer : Vec :: with_capacity ( BATCH_SIZE * CHUNK_SIZE ) ,
97- }
98- }
99-
100- fn collect ( & mut self , local_addr : SocketAddr , count : usize , out : & mut Vec < RecvPacketBatch > ) {
101- let new_buffer = Vec :: with_capacity ( BATCH_SIZE * CHUNK_SIZE ) ;
102- let filled_buffer = std:: mem:: replace ( & mut self . batch_buffer , new_buffer) ;
103- let master_bytes = Bytes :: from ( filled_buffer) ;
104-
105- for i in 0 ..count {
106- let m = & self . meta [ i] ;
107-
108- let start = i * CHUNK_SIZE ;
109- let end = start + m. len ;
110-
111- // Safety: Ensure we don't slice past the buffer (e.g., if kernel lied about len)
112- if end > master_bytes. len ( ) {
113- continue ;
114- }
115-
116- let buf = master_bytes. slice ( start..end) ;
117-
118- out. push ( RecvPacketBatch {
119- src : m. addr ,
120- dst : local_addr,
121- buf,
122- stride : m. stride ,
123- len : m. len ,
124- transport : Transport :: Udp ,
125- } ) ;
126- }
127- }
128- }
129-
13084/// A packet to send.
13185#[ derive( Debug , Clone ) ]
13286pub struct SendPacket {
@@ -141,26 +95,44 @@ pub struct SendPacketBatch<'a> {
14195 pub segment_size : usize ,
14296}
14397
144- /// UnifiedSocket enum for different transport types
145- pub enum UnifiedSocket {
146- Udp ( UdpTransport ) ,
147- Tcp ( TcpTransport ) ,
98+ /// Binds a socket to the given address and transport type.
99+ pub async fn bind (
100+ addr : SocketAddr ,
101+ transport : Transport ,
102+ external_addr : Option < SocketAddr > ,
103+ ) -> io:: Result < ( UnifiedSocketReader , UnifiedSocketWriter ) > {
104+ let socks = match transport {
105+ Transport :: Udp => {
106+ let ( reader, writer) = udp:: bind ( addr, external_addr) ?;
107+ (
108+ UnifiedSocketReader :: Udp ( Box :: new ( reader) ) ,
109+ UnifiedSocketWriter :: Udp ( writer) ,
110+ )
111+ }
112+
113+ Transport :: Tcp => {
114+ let ( reader, writer) = tcp:: bind ( addr, external_addr) . await ?;
115+ (
116+ UnifiedSocketReader :: Tcp ( reader) ,
117+ UnifiedSocketWriter :: Tcp ( writer) ,
118+ )
119+ }
120+ _ => todo ! ( ) ,
121+ } ;
122+ tracing:: debug!( "bound to {addr} ({transport:?})" ) ;
123+ Ok ( socks)
148124}
149125
150- impl UnifiedSocket {
151- /// Binds a socket to the given address and transport type.
152- pub async fn bind (
153- addr : SocketAddr ,
154- transport : Transport ,
155- external_addr : Option < SocketAddr > ,
156- ) -> io:: Result < Self > {
157- let sock = match transport {
158- Transport :: Udp => Self :: Udp ( UdpTransport :: bind ( addr, external_addr) ?) ,
159- Transport :: Tcp => Self :: Tcp ( TcpTransport :: bind ( addr, external_addr) . await ?) ,
160- _ => todo ! ( ) ,
161- } ;
162- tracing:: debug!( "bound to {addr} ({transport:?})" ) ;
163- Ok ( sock)
126+ pub enum UnifiedSocketReader {
127+ Udp ( Box < udp:: UdpTransportReader > ) ,
128+ Tcp ( tcp:: TcpTransportReader ) ,
129+ }
130+
131+ impl UnifiedSocketReader {
132+ pub fn close_peer ( & mut self , peer_addr : & SocketAddr ) {
133+ if let Self :: Tcp ( inner) = self {
134+ inner. close_peer ( peer_addr) ;
135+ }
164136 }
165137
166138 pub fn local_addr ( & self ) -> SocketAddr {
@@ -170,13 +142,6 @@ impl UnifiedSocket {
170142 }
171143 }
172144
173- pub fn max_gso_segments ( & self ) -> usize {
174- match self {
175- Self :: Udp ( inner) => inner. max_gso_segments ( ) ,
176- Self :: Tcp ( inner) => inner. max_gso_segments ( ) ,
177- }
178- }
179-
180145 /// Waits until the socket is readable.
181146 #[ inline]
182147 pub async fn readable ( & self ) -> io:: Result < ( ) > {
@@ -186,25 +151,36 @@ impl UnifiedSocket {
186151 }
187152 }
188153
189- /// Waits until the socket is writable .
154+ /// Receives a batch of packets into pre-allocated buffers .
190155 #[ inline]
191- pub async fn writable ( & self ) -> io:: Result < ( ) > {
156+ pub fn try_recv_batch ( & mut self , packets : & mut Vec < RecvPacketBatch > ) -> std :: io:: Result < ( ) > {
192157 match self {
193- Self :: Udp ( inner) => inner. writable ( ) . await ,
194- Self :: Tcp ( inner) => inner. writable ( ) . await ,
158+ Self :: Udp ( inner) => inner. try_recv_batch ( packets ) ,
159+ Self :: Tcp ( inner) => inner. try_recv_batch ( packets ) ,
195160 }
196161 }
162+ }
197163
198- /// Receives a batch of packets into pre-allocated buffers.
164+ #[ derive( Clone ) ]
165+ pub enum UnifiedSocketWriter {
166+ Udp ( udp:: UdpTransportWriter ) ,
167+ Tcp ( tcp:: TcpTransportWriter ) ,
168+ }
169+
170+ impl UnifiedSocketWriter {
171+ pub fn max_gso_segments ( & self ) -> usize {
172+ match self {
173+ Self :: Udp ( inner) => inner. max_gso_segments ( ) ,
174+ Self :: Tcp ( inner) => inner. max_gso_segments ( ) ,
175+ }
176+ }
177+
178+ /// Waits until the socket is writable.
199179 #[ inline]
200- pub fn try_recv_batch (
201- & self ,
202- batch : & mut RecvPacketBatcher ,
203- packets : & mut Vec < RecvPacketBatch > ,
204- ) -> std:: io:: Result < ( ) > {
180+ pub async fn writable ( & self ) -> io:: Result < ( ) > {
205181 match self {
206- Self :: Udp ( inner) => inner. try_recv_batch ( batch , packets ) ,
207- Self :: Tcp ( inner) => inner. try_recv_batch ( packets ) ,
182+ Self :: Udp ( inner) => inner. writable ( ) . await ,
183+ Self :: Tcp ( inner) => inner. writable ( ) . await ,
208184 }
209185 }
210186
@@ -223,6 +199,13 @@ impl UnifiedSocket {
223199 Self :: Tcp ( _) => Transport :: Tcp ,
224200 }
225201 }
202+
203+ pub fn local_addr ( & self ) -> SocketAddr {
204+ match self {
205+ Self :: Udp ( inner) => inner. local_addr ( ) ,
206+ Self :: Tcp ( inner) => inner. local_addr ( ) ,
207+ }
208+ }
226209}
227210
228211fn fmt_bytes ( b : usize ) -> String {
@@ -250,13 +233,14 @@ mod tests {
250233
251234 async fn test_transport ( transport_type : Transport ) {
252235 let bind_addr: SocketAddr = "127.0.0.1:0" . parse ( ) . unwrap ( ) ;
253- let mut server = UnifiedSocket :: bind ( bind_addr, transport_type, None )
254- . await
255- . unwrap ( ) ;
256- let actual_server_addr = server. local_addr ( ) ;
257236
258- // --- 1. External Client Setup ---
259- // (We use raw sockets here to simulate an external browser/client)
237+ // 1. Bind now returns a split Reader and Writer
238+ let ( mut reader, writer) = bind ( bind_addr, transport_type, None ) . await . unwrap ( ) ;
239+
240+ // We can get the address from either, but writer is usually the "identity"
241+ let actual_server_addr = writer. local_addr ( ) ;
242+
243+ // --- 2. External Client Setup ---
260244 let mut tcp_client: Option < TcpStream > = None ;
261245 let mut udp_client: Option < UdpSocket > = None ;
262246
@@ -267,8 +251,7 @@ mod tests {
267251 udp_client = Some ( UdpSocket :: bind ( "127.0.0.1:0" ) . await . unwrap ( ) ) ;
268252 }
269253
270- // --- 2. Handshake: Client -> Server ---
271- // This allows the Server to discover the client's ephemeral port
254+ // --- 3. Handshake: Client -> Server ---
272255 let handshake_payload = b"hello-sfu" ;
273256 if let Some ( ref mut tcp) = tcp_client {
274257 let mut buf = Vec :: new ( ) ;
@@ -284,20 +267,23 @@ mod tests {
284267 . unwrap ( ) ;
285268 }
286269
287- // Server: Wait for handshake
288- server. readable ( ) . await . unwrap ( ) ;
289- let mut batcher = RecvPacketBatcher :: new ( ) ;
270+ // Server: Wait for handshake using the Reader
271+ reader. readable ( ) . await . unwrap ( ) ;
272+
273+ // Note: RecvPacketBatcher is now internal to the reader and not seen here
290274 let mut out = Vec :: new ( ) ;
291275
292276 // Retry loop for UDP loopback jitter
293277 let remote_peer_addr = loop {
294- if server. try_recv_batch ( & mut batcher, & mut out) . is_ok ( ) && !out. is_empty ( ) {
278+ out. clear ( ) ;
279+ // try_recv_batch now takes mut self and handles its own batcher
280+ if reader. try_recv_batch ( & mut out) . is_ok ( ) && !out. is_empty ( ) {
295281 break out[ 0 ] . src ;
296282 }
297283 tokio:: time:: sleep ( Duration :: from_millis ( 10 ) ) . await ;
298284 } ;
299285
300- // --- 3 . Data Transfer: Server -> Client ---
286+ // --- 4 . Data Transfer: Server -> Client ---
301287 let num_packets = 100 ;
302288 let packet_payload = b"important-media-data" ;
303289
@@ -328,22 +314,27 @@ mod tests {
328314 count
329315 } ) ;
330316
331- // Server: Send packets using the unified interface
317+ // Server: Send packets using the Writer
318+ // We can even clone the writer to show multi-owner capability
319+ let writer_tx = writer. clone ( ) ;
332320 let mut sent = 0 ;
333321 while sent < num_packets {
334- server . writable ( ) . await . unwrap ( ) ;
322+ writer_tx . writable ( ) . await . unwrap ( ) ;
335323 let batch = SendPacketBatch {
336324 dst : remote_peer_addr,
337325 buf : packet_payload,
338326 segment_size : packet_payload. len ( ) ,
339327 } ;
340328
341- match server . try_send_batch ( & batch) {
329+ match writer_tx . try_send_batch ( & batch) {
342330 Ok ( true ) => sent += 1 ,
343- Ok ( false ) | Err ( _ ) => {
331+ Ok ( false ) => {
344332 // Handle backpressure/WouldBlock
345333 tokio:: task:: yield_now ( ) . await ;
346334 }
335+ Err ( e) => {
336+ panic ! ( "Send failed: {e}" ) ;
337+ }
347338 }
348339 }
349340
0 commit comments