11use futures_lite:: StreamExt ;
22use std:: collections:: HashMap ;
3+ use std:: io;
34use std:: net:: { IpAddr , SocketAddr } ;
45use std:: sync:: Arc ;
56use std:: time:: Duration ;
@@ -10,7 +11,6 @@ use str0m::{
1011 media:: { Direction , MediaAdded , MediaKind , Mid , Pt } ,
1112 net:: { Protocol , Receive } ,
1213} ;
13- use tokio:: net:: UdpSocket ;
1414use tokio:: sync:: { Notify , mpsc} ;
1515use tokio:: time:: Instant ;
1616use tokio_stream:: StreamMap ;
@@ -74,17 +74,17 @@ struct TrackRequest {
7474 simulcast_layers : Option < Vec < SimulcastLayer > > ,
7575}
7676
77- pub struct AgentBuilder {
77+ pub struct AgentBuilder < S > {
7878 signaling : HttpSignalingClient ,
79- udp_socket : Option < UdpSocket > ,
79+ udp_socket : S ,
8080 tracks : Vec < TrackRequest > ,
8181}
8282
83- impl AgentBuilder {
84- pub fn new ( signaling : HttpSignalingClient ) -> Self {
83+ impl < S : UdpSocket > AgentBuilder < S > {
84+ pub fn new ( signaling : HttpSignalingClient , udp_socket : S ) -> AgentBuilder < S > {
8585 Self {
8686 signaling,
87- udp_socket : None ,
87+ udp_socket,
8888 tracks : Vec :: new ( ) ,
8989 }
9090 }
@@ -103,18 +103,8 @@ impl AgentBuilder {
103103 self
104104 }
105105
106- pub fn with_udp_socket ( mut self , socket : UdpSocket ) -> Self {
107- self . udp_socket = Some ( socket) ;
108- self
109- }
110-
111106 pub async fn join ( mut self , room_id : & str ) -> Result < Agent , AgentError > {
112- let socket = if let Some ( socket) = self . udp_socket {
113- socket
114- } else {
115- UdpSocket :: bind ( "0.0.0.0:0" ) . await ?
116- } ;
117- let port = socket. local_addr ( ) ?. port ( ) ;
107+ let port = self . udp_socket . local_addr ( ) ?. port ( ) ;
118108
119109 let local_ips: Vec < IpAddr > = if_addrs:: get_if_addrs ( ) ?
120110 . into_iter ( )
@@ -198,7 +188,7 @@ impl AgentBuilder {
198188 let actor = AgentActor {
199189 addr,
200190 rtc,
201- socket,
191+ socket : self . udp_socket ,
202192 buf : vec ! [ 0u8 ; 2048 ] ,
203193 event_tx,
204194 senders : StreamMap :: new ( ) ,
@@ -225,10 +215,6 @@ pub struct Agent {
225215}
226216
227217impl Agent {
228- pub fn builder ( signaling : HttpSignalingClient ) -> AgentBuilder {
229- AgentBuilder :: new ( signaling)
230- }
231-
232218 pub async fn next_event ( & mut self ) -> Option < AgentEvent > {
233219 self . events . recv ( ) . await
234220 }
@@ -239,10 +225,10 @@ impl Agent {
239225 }
240226}
241227
242- struct AgentActor {
228+ struct AgentActor < S > {
243229 addr : SocketAddr ,
244230 rtc : Rtc ,
245- socket : UdpSocket ,
231+ socket : S ,
246232 buf : Vec < u8 > ,
247233 event_tx : mpsc:: Sender < AgentEvent > ,
248234
@@ -252,7 +238,7 @@ struct AgentActor {
252238 shutdown : Arc < Notify > ,
253239}
254240
255- impl AgentActor {
241+ impl < S : UdpSocket > AgentActor < S > {
256242 async fn run ( mut self , medias : Vec < MediaAdded > ) {
257243 for media in medias {
258244 self . handle_media_added ( media) ;
@@ -368,3 +354,47 @@ impl AgentActor {
368354 let _ = self . event_tx . try_send ( event) ;
369355 }
370356}
357+
358+ pub trait UdpSocket : Send + Sync + ' static {
359+ fn try_send_to ( & self , buf : & [ u8 ] , target : SocketAddr ) -> io:: Result < usize > ;
360+ fn recv_from (
361+ & self ,
362+ buf : & mut [ u8 ] ,
363+ ) -> impl Future < Output = io:: Result < ( usize , SocketAddr ) > > + Send ;
364+ fn local_addr ( & self ) -> io:: Result < SocketAddr > ;
365+ }
366+
367+ impl UdpSocket for tokio:: net:: UdpSocket {
368+ fn try_send_to ( & self , buf : & [ u8 ] , target : SocketAddr ) -> io:: Result < usize > {
369+ self . try_send_to ( buf, target)
370+ }
371+
372+ fn recv_from (
373+ & self ,
374+ buf : & mut [ u8 ] ,
375+ ) -> impl Future < Output = io:: Result < ( usize , SocketAddr ) > > + Send {
376+ self . recv_from ( buf)
377+ }
378+
379+ fn local_addr ( & self ) -> io:: Result < SocketAddr > {
380+ self . local_addr ( )
381+ }
382+ }
383+
384+ #[ cfg( feature = "turmoil" ) ]
385+ impl UdpSocket for turmoil:: net:: UdpSocket {
386+ fn try_send_to ( & self , buf : & [ u8 ] , target : SocketAddr ) -> io:: Result < usize > {
387+ self . try_send_to ( buf, target)
388+ }
389+
390+ fn recv_from (
391+ & self ,
392+ buf : & mut [ u8 ] ,
393+ ) -> impl Future < Output = io:: Result < ( usize , SocketAddr ) > > + Send {
394+ self . recv_from ( buf)
395+ }
396+
397+ fn local_addr ( & self ) -> io:: Result < SocketAddr > {
398+ self . local_addr ( )
399+ }
400+ }
0 commit comments