@@ -20,24 +20,29 @@ use crate::lsps0;
2020use crate :: lsps1;
2121use crate :: lsps2;
2222use crate :: prelude:: { Vec , VecDeque } ;
23- use crate :: sync:: Mutex ;
23+ use crate :: sync:: { Arc , Mutex } ;
24+
25+ use core:: future:: Future ;
26+ use core:: task:: { Poll , Waker } ;
2427
2528pub ( crate ) struct EventQueue {
26- queue : Mutex < VecDeque < Event > > ,
29+ queue : Arc < Mutex < VecDeque < Event > > > ,
30+ waker : Arc < Mutex < Option < Waker > > > ,
2731 #[ cfg( feature = "std" ) ]
2832 condvar : std:: sync:: Condvar ,
2933}
3034
3135impl EventQueue {
3236 pub fn new ( ) -> Self {
33- let queue = Mutex :: new ( VecDeque :: new ( ) ) ;
37+ let queue = Arc :: new ( Mutex :: new ( VecDeque :: new ( ) ) ) ;
38+ let waker = Arc :: new ( Mutex :: new ( None ) ) ;
3439 #[ cfg( feature = "std" ) ]
3540 {
3641 let condvar = std:: sync:: Condvar :: new ( ) ;
37- Self { queue, condvar }
42+ Self { queue, waker , condvar }
3843 }
3944 #[ cfg( not( feature = "std" ) ) ]
40- Self { queue }
45+ Self { queue, waker }
4146 }
4247
4348 pub fn enqueue ( & self , event : Event ) {
@@ -46,6 +51,9 @@ impl EventQueue {
4651 queue. push_back ( event) ;
4752 }
4853
54+ if let Some ( waker) = self . waker . lock ( ) . unwrap ( ) . take ( ) {
55+ waker. wake ( ) ;
56+ }
4957 #[ cfg( feature = "std" ) ]
5058 self . condvar . notify_one ( ) ;
5159 }
@@ -54,6 +62,10 @@ impl EventQueue {
5462 self . queue . lock ( ) . unwrap ( ) . pop_front ( )
5563 }
5664
65+ pub async fn next_event_async ( & self ) -> Event {
66+ EventFuture { event_queue : Arc :: clone ( & self . queue ) , waker : Arc :: clone ( & self . waker ) } . await
67+ }
68+
5769 #[ cfg( feature = "std" ) ]
5870 pub fn wait_next_event ( & self ) -> Event {
5971 let mut queue =
@@ -65,6 +77,10 @@ impl EventQueue {
6577 drop ( queue) ;
6678
6779 if should_notify {
80+ if let Some ( waker) = self . waker . lock ( ) . unwrap ( ) . take ( ) {
81+ waker. wake ( ) ;
82+ }
83+
6884 self . condvar . notify_one ( ) ;
6985 }
7086
@@ -92,3 +108,132 @@ pub enum Event {
92108 /// An LSPS2 (JIT Channel) server event.
93109 LSPS2Service ( lsps2:: event:: LSPS2ServiceEvent ) ,
94110}
111+
112+ struct EventFuture {
113+ event_queue : Arc < Mutex < VecDeque < Event > > > ,
114+ waker : Arc < Mutex < Option < Waker > > > ,
115+ }
116+
117+ impl Future for EventFuture {
118+ type Output = Event ;
119+
120+ fn poll (
121+ self : core:: pin:: Pin < & mut Self > , cx : & mut core:: task:: Context < ' _ > ,
122+ ) -> core:: task:: Poll < Self :: Output > {
123+ if let Some ( event) = self . event_queue . lock ( ) . unwrap ( ) . pop_front ( ) {
124+ Poll :: Ready ( event)
125+ } else {
126+ * self . waker . lock ( ) . unwrap ( ) = Some ( cx. waker ( ) . clone ( ) ) ;
127+ Poll :: Pending
128+ }
129+ }
130+ }
131+
132+ #[ cfg( test) ]
133+ mod tests {
134+ use super :: * ;
135+ use crate :: lsps0:: event:: LSPS0ClientEvent ;
136+ use bitcoin:: secp256k1:: { PublicKey , Secp256k1 , SecretKey } ;
137+
138+ #[ tokio:: test]
139+ #[ cfg( feature = "std" ) ]
140+ async fn event_queue_works ( ) {
141+ use core:: sync:: atomic:: { AtomicU16 , Ordering } ;
142+ use std:: sync:: Arc ;
143+ use std:: time:: Duration ;
144+
145+ let event_queue = Arc :: new ( EventQueue :: new ( ) ) ;
146+ assert_eq ! ( event_queue. next_event( ) , None ) ;
147+
148+ let secp_ctx = Secp256k1 :: new ( ) ;
149+ let counterparty_node_id =
150+ PublicKey :: from_secret_key ( & secp_ctx, & SecretKey :: from_slice ( & [ 42 ; 32 ] ) . unwrap ( ) ) ;
151+ let expected_event = Event :: LSPS0Client ( LSPS0ClientEvent :: ListProtocolsResponse {
152+ counterparty_node_id,
153+ protocols : Vec :: new ( ) ,
154+ } ) ;
155+
156+ for _ in 0 ..3 {
157+ event_queue. enqueue ( expected_event. clone ( ) ) ;
158+ }
159+
160+ assert_eq ! ( event_queue. wait_next_event( ) , expected_event) ;
161+ assert_eq ! ( event_queue. next_event_async( ) . await , expected_event) ;
162+ assert_eq ! ( event_queue. next_event( ) , Some ( expected_event. clone( ) ) ) ;
163+ assert_eq ! ( event_queue. next_event( ) , None ) ;
164+
165+ // Check `next_event_async` won't return if the queue is empty and always rather timeout.
166+ tokio:: select! {
167+ _ = tokio:: time:: sleep( Duration :: from_millis( 10 ) ) => {
168+ // Timeout
169+ }
170+ _ = event_queue. next_event_async( ) => {
171+ panic!( ) ;
172+ }
173+ }
174+ assert_eq ! ( event_queue. next_event( ) , None ) ;
175+
176+ // Check we get the expected number of events when polling/enqueuing concurrently.
177+ let enqueued_events = AtomicU16 :: new ( 0 ) ;
178+ let received_events = AtomicU16 :: new ( 0 ) ;
179+ let mut delayed_enqueue = false ;
180+
181+ for _ in 0 ..25 {
182+ event_queue. enqueue ( expected_event. clone ( ) ) ;
183+ enqueued_events. fetch_add ( 1 , Ordering :: SeqCst ) ;
184+ }
185+
186+ loop {
187+ tokio:: select! {
188+ _ = tokio:: time:: sleep( Duration :: from_millis( 10 ) ) , if !delayed_enqueue => {
189+ event_queue. enqueue( expected_event. clone( ) ) ;
190+ enqueued_events. fetch_add( 1 , Ordering :: SeqCst ) ;
191+ delayed_enqueue = true ;
192+ }
193+ e = event_queue. next_event_async( ) => {
194+ assert_eq!( e, expected_event) ;
195+ received_events. fetch_add( 1 , Ordering :: SeqCst ) ;
196+
197+ event_queue. enqueue( expected_event. clone( ) ) ;
198+ enqueued_events. fetch_add( 1 , Ordering :: SeqCst ) ;
199+ }
200+ e = event_queue. next_event_async( ) => {
201+ assert_eq!( e, expected_event) ;
202+ received_events. fetch_add( 1 , Ordering :: SeqCst ) ;
203+ }
204+ }
205+
206+ if delayed_enqueue
207+ && received_events. load ( Ordering :: SeqCst ) == enqueued_events. load ( Ordering :: SeqCst )
208+ {
209+ break ;
210+ }
211+ }
212+ assert_eq ! ( event_queue. next_event( ) , None ) ;
213+
214+ // Check we operate correctly, even when mixing and matching blocking and async API calls.
215+ let ( tx, mut rx) = tokio:: sync:: watch:: channel ( ( ) ) ;
216+ let thread_queue = Arc :: clone ( & event_queue) ;
217+ let thread_event = expected_event. clone ( ) ;
218+ std:: thread:: spawn ( move || {
219+ let e = thread_queue. wait_next_event ( ) ;
220+ assert_eq ! ( e, thread_event) ;
221+ tx. send ( ( ) ) . unwrap ( ) ;
222+ } ) ;
223+
224+ let thread_queue = Arc :: clone ( & event_queue) ;
225+ let thread_event = expected_event. clone ( ) ;
226+ std:: thread:: spawn ( move || {
227+ // Sleep a bit before we enqueue the events everybody is waiting for.
228+ std:: thread:: sleep ( Duration :: from_millis ( 20 ) ) ;
229+ thread_queue. enqueue ( thread_event. clone ( ) ) ;
230+ thread_queue. enqueue ( thread_event. clone ( ) ) ;
231+ } ) ;
232+
233+ let e = event_queue. next_event_async ( ) . await ;
234+ assert_eq ! ( e, expected_event. clone( ) ) ;
235+
236+ rx. changed ( ) . await . unwrap ( ) ;
237+ assert_eq ! ( event_queue. next_event( ) , None ) ;
238+ }
239+ }
0 commit comments