@@ -3,12 +3,13 @@ pub mod pulsebeam {
33 include ! ( concat!( env!( "OUT_DIR" ) , "/pulsebeam.v1.rs" ) ) ;
44 }
55}
6- use moka:: sync :: Cache ;
6+ use moka:: future :: Cache ;
77pub use pulsebeam:: v1:: { self as rpc} ;
88use pulsebeam:: v1:: { IceServer , Message } ;
99use pulsebeam:: v1:: { PeerInfo , PrepareReq , PrepareResp , RecvReq , RecvResp , SendReq , SendResp } ;
1010use std:: sync:: Arc ;
1111use tokio:: time;
12+ use tracing:: warn;
1213use twirp:: async_trait:: async_trait;
1314
1415const SESSION_POLL_TIMEOUT : time:: Duration = time:: Duration :: from_secs ( 1200 ) ;
@@ -42,15 +43,18 @@ impl Server {
4243 }
4344
4445 #[ inline]
45- fn get ( & self , group_id : & str , peer_id : & str , conn_id : u32 ) -> Channel {
46+ async fn get ( & self , group_id : & str , peer_id : & str , conn_id : u32 ) -> Channel {
4647 let id = format ! ( "{}:{}:{}" , group_id, peer_id, conn_id) ;
4748 self . mailboxes
48- . get_with_by_ref ( & id, || flume:: bounded ( self . cfg . mailbox_capacity ) )
49+ . get_with_by_ref ( & id, async { flume:: bounded ( self . cfg . mailbox_capacity ) } )
50+ . await
4951 }
5052
5153 async fn recv_batch ( & self , src : & PeerInfo ) -> Vec < Message > {
52- let ( _, discovery_ch) = self . get ( & src. group_id , & src. peer_id , RESERVED_CONN_ID_DISCOVERY ) ;
53- let ( _, payload_ch) = self . get ( & src. group_id , & src. peer_id , src. conn_id ) ;
54+ let ( _, discovery_ch) = self
55+ . get ( & src. group_id , & src. peer_id , RESERVED_CONN_ID_DISCOVERY )
56+ . await ;
57+ let ( _, payload_ch) = self . get ( & src. group_id , & src. peer_id , src. conn_id ) . await ;
5458 let mut msgs = Vec :: new ( ) ;
5559 let mut poll_timeout = time:: interval_at (
5660 time:: Instant :: now ( ) + SESSION_POLL_TIMEOUT - SESSION_POLL_LATENCY_TOLERANCE ,
@@ -60,20 +64,22 @@ impl Server {
6064 loop {
6165 let mut msg: Option < Message > = None ;
6266 tokio:: select! {
63- Ok ( m ) = discovery_ch. recv_async( ) => {
64- msg = Some ( m ) ;
67+ res = discovery_ch. recv_async( ) => {
68+ msg = res . ok ( ) ;
6569 poll_timeout. reset( ) ;
6670 }
67- Ok ( m ) = payload_ch. recv_async( ) => {
68- msg = Some ( m ) ;
71+ res = payload_ch. recv_async( ) => {
72+ msg = res . ok ( ) ;
6973 poll_timeout. reset( ) ;
7074 }
7175 _ = poll_timeout. tick( ) => { }
7276 }
7377
78+ // TODO: deduplicate messages
7479 if let Some ( msg) = msg {
7580 msgs. push ( msg) ;
7681 } else {
82+ warn ! ( "senders have dropped" ) ;
7783 return msgs;
7884 }
7985 }
@@ -116,7 +122,7 @@ impl rpc::Tunnel for Server {
116122 . as_ref ( )
117123 . ok_or ( twirp:: invalid_argument ( "dst is required" ) ) ?;
118124
119- let ( ch, _) = self . get ( & dst. group_id , & dst. peer_id , dst. conn_id ) ;
125+ let ( ch, _) = self . get ( & dst. group_id , & dst. peer_id , dst. conn_id ) . await ;
120126 ch. send ( msg)
121127 . map_err ( |err| twirp:: internal ( err. to_string ( ) ) ) ?;
122128
@@ -254,31 +260,26 @@ mod test {
254260 }
255261
256262 #[ tokio:: test]
257- async fn recv_after_timeout ( ) {
263+ async fn recv_after_ttl ( ) {
258264 let ( s, peer1, peer2) = setup ( ) ;
259- let msgs = vec ! [ dummy_msg( peer1. clone( ) , peer2. clone( ) , 0 ) ] ;
260- s. send (
261- dummy_ctx ( ) ,
262- SendReq {
263- msg : Some ( msgs[ 0 ] . clone ( ) ) ,
264- } ,
265- )
266- . await
267- . unwrap ( ) ;
268-
269265 s. mailboxes . invalidate_all ( ) ;
270- let resp = time:: timeout (
271- time:: Duration :: from_millis ( 5 ) ,
272- s. recv (
273- dummy_ctx ( ) ,
274- RecvReq {
275- src : Some ( peer2. clone ( ) ) ,
276- } ,
277- ) ,
278- )
279- . await ;
266+ s. mailboxes . run_pending_tasks ( ) . await ;
267+ let sent = vec ! [ dummy_msg( peer1. clone( ) , peer2. clone( ) , 0 ) ] ;
268+ let msg = sent[ 0 ] . clone ( ) ;
269+
270+ let cloned = Arc :: clone ( & s) ;
271+ tokio:: spawn ( async move {
272+ cloned
273+ . send ( dummy_ctx ( ) , SendReq { msg : Some ( msg) } )
274+ . await
275+ . unwrap ( ) ;
276+ } ) ;
277+ let received = s. recv_batch ( & peer2) . await ;
278+ assert_msgs ( & received, & sent) ;
280279
281- // expect timeout to hit because there's no message
282- assert ! ( resp. is_err( ) ) ;
280+ // at this point, cache has been refreshed. So, it will stuck until it receives more
281+ // messages
282+ let res = time:: timeout ( time:: Duration :: from_millis ( 5 ) , s. recv_batch ( & peer2) ) . await ;
283+ assert ! ( res. is_err( ) ) ;
283284 }
284285}
0 commit comments