1515
1616use alloc:: sync:: Arc ;
1717use core:: mem;
18- use crate :: sync:: { Condvar , Mutex } ;
18+ use crate :: sync:: { Condvar , Mutex , MutexGuard } ;
1919
2020use crate :: prelude:: * ;
2121
@@ -33,6 +33,20 @@ pub(crate) struct Notifier {
3333 condvar : Condvar ,
3434}
3535
36+ macro_rules! check_woken {
37+ ( $guard: expr, $retval: expr) => { {
38+ if $guard. 0 {
39+ $guard. 0 = false ;
40+ if $guard. 1 . as_ref( ) . map( |l| l. lock( ) . unwrap( ) . complete) . unwrap_or( false ) {
41+ // If we're about to return as woken, and the future state is marked complete, wipe
42+ // the future state and let the next future wait until we get a new notify.
43+ $guard. 1 . take( ) ;
44+ }
45+ return $retval;
46+ }
47+ } }
48+ }
49+
3650impl Notifier {
3751 pub ( crate ) fn new ( ) -> Self {
3852 Self {
@@ -41,45 +55,47 @@ impl Notifier {
4155 }
4256 }
4357
58+ fn propagate_future_state_to_notify_flag ( & self ) -> MutexGuard < ( bool , Option < Arc < Mutex < FutureState > > > ) > {
59+ let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
60+ if let Some ( existing_state) = & lock. 1 {
61+ if existing_state. lock ( ) . unwrap ( ) . callbacks_made {
62+ // If the existing `FutureState` has completed and actually made callbacks,
63+ // consider the notification flag to have been cleared and reset the future state.
64+ lock. 1 . take ( ) ;
65+ lock. 0 = false ;
66+ }
67+ }
68+ lock
69+ }
70+
4471 pub ( crate ) fn wait ( & self ) {
4572 loop {
46- let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
47- if guard. 0 {
48- guard. 0 = false ;
49- return ;
50- }
73+ let mut guard = self . propagate_future_state_to_notify_flag ( ) ;
74+ check_woken ! ( guard, ( ) ) ;
5175 guard = self . condvar . wait ( guard) . unwrap ( ) ;
52- let result = guard. 0 ;
53- if result {
54- guard. 0 = false ;
55- return
56- }
76+ check_woken ! ( guard, ( ) ) ;
5777 }
5878 }
5979
6080 #[ cfg( any( test, feature = "std" ) ) ]
6181 pub ( crate ) fn wait_timeout ( & self , max_wait : Duration ) -> bool {
6282 let current_time = Instant :: now ( ) ;
6383 loop {
64- let mut guard = self . notify_pending . lock ( ) . unwrap ( ) ;
65- if guard. 0 {
66- guard. 0 = false ;
67- return true ;
68- }
84+ let mut guard = self . propagate_future_state_to_notify_flag ( ) ;
85+ check_woken ! ( guard, true ) ;
6986 guard = self . condvar . wait_timeout ( guard, max_wait) . unwrap ( ) . 0 ;
87+ check_woken ! ( guard, true ) ;
7088 // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
7189 // desired wait time has actually passed, and if not then restart the loop with a reduced wait
7290 // time. Note that this logic can be highly simplified through the use of
7391 // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
7492 // 1.42.0.
7593 let elapsed = current_time. elapsed ( ) ;
76- let result = guard. 0 ;
77- if result || elapsed >= max_wait {
78- guard. 0 = false ;
79- return result;
94+ if elapsed >= max_wait {
95+ return false ;
8096 }
8197 match max_wait. checked_sub ( elapsed) {
82- None => return result ,
98+ None => return false ,
8399 Some ( _) => continue
84100 }
85101 }
@@ -88,17 +104,8 @@ impl Notifier {
88104 /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
89105 pub ( crate ) fn notify ( & self ) {
90106 let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
91- let mut future_probably_generated_calls = false ;
92- if let Some ( future_state) = lock. 1 . take ( ) {
93- future_probably_generated_calls |= future_state. lock ( ) . unwrap ( ) . complete ( ) ;
94- future_probably_generated_calls |= Arc :: strong_count ( & future_state) > 1 ;
95- }
96- if future_probably_generated_calls {
97- // If a future made some callbacks or has not yet been drop'd (i.e. the state has more
98- // than the one reference we hold), assume the user was notified and skip setting the
99- // notification-required flag. This will not cause the `wait` functions above to return
100- // and avoid any future `Future`s starting in a completed state.
101- return ;
107+ if let Some ( future_state) = & lock. 1 {
108+ future_state. lock ( ) . unwrap ( ) . complete ( ) ;
102109 }
103110 lock. 0 = true ;
104111 mem:: drop ( lock) ;
@@ -107,20 +114,14 @@ impl Notifier {
107114
108115 /// Gets a [`Future`] that will get woken up with any waiters
109116 pub ( crate ) fn get_future ( & self ) -> Future {
110- let mut lock = self . notify_pending . lock ( ) . unwrap ( ) ;
111- if lock. 0 {
112- Future {
113- state : Arc :: new ( Mutex :: new ( FutureState {
114- callbacks : Vec :: new ( ) ,
115- complete : true ,
116- } ) )
117- }
118- } else if let Some ( existing_state) = & lock. 1 {
117+ let mut lock = self . propagate_future_state_to_notify_flag ( ) ;
118+ if let Some ( existing_state) = & lock. 1 {
119119 Future { state : Arc :: clone ( & existing_state) }
120120 } else {
121121 let state = Arc :: new ( Mutex :: new ( FutureState {
122122 callbacks : Vec :: new ( ) ,
123- complete : false ,
123+ complete : lock. 0 ,
124+ callbacks_made : false ,
124125 } ) ) ;
125126 lock. 1 = Some ( Arc :: clone ( & state) ) ;
126127 Future { state }
@@ -151,19 +152,21 @@ impl<F: Fn() + Send> FutureCallback for F {
151152}
152153
153154pub ( crate ) struct FutureState {
154- callbacks : Vec < Box < dyn FutureCallback > > ,
155+ // When we're tracking whether a callback counts as having woken the user's code, we check the
156+ // first bool - set to false if we're just calling a Waker, and true if we're calling an actual
157+ // user-provided function.
158+ callbacks : Vec < ( bool , Box < dyn FutureCallback > ) > ,
155159 complete : bool ,
160+ callbacks_made : bool ,
156161}
157162
158163impl FutureState {
159- fn complete ( & mut self ) -> bool {
160- let mut made_calls = false ;
161- for callback in self . callbacks . drain ( ..) {
164+ fn complete ( & mut self ) {
165+ for ( counts_as_call, callback) in self . callbacks . drain ( ..) {
162166 callback. call ( ) ;
163- made_calls = true ;
167+ self . callbacks_made |= counts_as_call ;
164168 }
165169 self . complete = true ;
166- made_calls
167170 }
168171}
169172
@@ -180,10 +183,11 @@ impl Future {
180183 pub fn register_callback ( & self , callback : Box < dyn FutureCallback > ) {
181184 let mut state = self . state . lock ( ) . unwrap ( ) ;
182185 if state. complete {
186+ state. callbacks_made = true ;
183187 mem:: drop ( state) ;
184188 callback. call ( ) ;
185189 } else {
186- state. callbacks . push ( callback) ;
190+ state. callbacks . push ( ( true , callback) ) ;
187191 }
188192 }
189193
@@ -198,12 +202,10 @@ impl Future {
198202 }
199203}
200204
201- mod std_future {
202- use core:: task:: Waker ;
203- pub struct StdWaker ( pub Waker ) ;
204- impl super :: FutureCallback for StdWaker {
205- fn call ( & self ) { self . 0 . wake_by_ref ( ) }
206- }
205+ use core:: task:: Waker ;
206+ struct StdWaker ( pub Waker ) ;
207+ impl FutureCallback for StdWaker {
208+ fn call ( & self ) { self . 0 . wake_by_ref ( ) }
207209}
208210
209211/// (C-not exported) as Rust Futures aren't usable in language bindings.
@@ -213,10 +215,11 @@ impl<'a> StdFuture for Future {
213215 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
214216 let mut state = self . state . lock ( ) . unwrap ( ) ;
215217 if state. complete {
218+ state. callbacks_made = true ;
216219 Poll :: Ready ( ( ) )
217220 } else {
218221 let waker = cx. waker ( ) . clone ( ) ;
219- state. callbacks . push ( Box :: new ( std_future :: StdWaker ( waker) ) ) ;
222+ state. callbacks . push ( ( false , Box :: new ( StdWaker ( waker) ) ) ) ;
220223 Poll :: Pending
221224 }
222225 }
@@ -285,6 +288,28 @@ mod tests {
285288 assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
286289 }
287290
291+ #[ test]
292+ fn new_future_wipes_notify_bit ( ) {
293+ // Previously, if we were only using the `Future` interface to learn when a `Notifier` has
294+ // been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
295+ // fetched after the notify bit has been set.
296+ let notifier = Notifier :: new ( ) ;
297+ notifier. notify ( ) ;
298+
299+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
300+ let callback_ref = Arc :: clone ( & callback) ;
301+ notifier. get_future ( ) . register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
302+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
303+
304+ let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
305+ let callback_ref = Arc :: clone ( & callback) ;
306+ notifier. get_future ( ) . register_callback ( Box :: new ( move || assert ! ( !callback_ref. fetch_or( true , Ordering :: SeqCst ) ) ) ) ;
307+ assert ! ( !callback. load( Ordering :: SeqCst ) ) ;
308+
309+ notifier. notify ( ) ;
310+ assert ! ( callback. load( Ordering :: SeqCst ) ) ;
311+ }
312+
288313 #[ cfg( feature = "std" ) ]
289314 #[ test]
290315 fn test_wait_timeout ( ) {
@@ -336,6 +361,7 @@ mod tests {
336361 state : Arc :: new ( Mutex :: new ( FutureState {
337362 callbacks : Vec :: new ( ) ,
338363 complete : false ,
364+ callbacks_made : false ,
339365 } ) )
340366 } ;
341367 let callback = Arc :: new ( AtomicBool :: new ( false ) ) ;
@@ -354,6 +380,7 @@ mod tests {
354380 state : Arc :: new ( Mutex :: new ( FutureState {
355381 callbacks : Vec :: new ( ) ,
356382 complete : false ,
383+ callbacks_made : false ,
357384 } ) )
358385 } ;
359386 future. state . lock ( ) . unwrap ( ) . complete ( ) ;
@@ -391,6 +418,7 @@ mod tests {
391418 state : Arc :: new ( Mutex :: new ( FutureState {
392419 callbacks : Vec :: new ( ) ,
393420 complete : false ,
421+ callbacks_made : false ,
394422 } ) )
395423 } ;
396424 let mut second_future = Future { state : Arc :: clone ( & future. state ) } ;
@@ -409,4 +437,36 @@ mod tests {
409437 assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Ready ( ( ) ) ) ;
410438 assert_eq ! ( Pin :: new( & mut second_future) . poll( & mut Context :: from_waker( & second_waker) ) , Poll :: Ready ( ( ) ) ) ;
411439 }
440+
441+ #[ test]
442+ fn test_dropped_future_doesnt_count ( ) {
443+ // Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as
444+ // having been woken, leaving the notify-required flag set.
445+ let notifier = Notifier :: new ( ) ;
446+ notifier. notify ( ) ;
447+
448+ // If we get a future and don't touch it we're definitely still notify-required.
449+ notifier. get_future ( ) ;
450+ assert ! ( notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
451+ assert ! ( !notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
452+
453+ // Even if we poll'd once but didn't observe a `Ready`, we should be notify-required.
454+ let mut future = notifier. get_future ( ) ;
455+ let ( woken, waker) = create_waker ( ) ;
456+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Pending ) ;
457+
458+ notifier. notify ( ) ;
459+ assert ! ( woken. load( Ordering :: SeqCst ) ) ;
460+ assert ! ( notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
461+
462+ // However, once we do poll `Ready` it should wipe the notify-required flag.
463+ let mut future = notifier. get_future ( ) ;
464+ let ( woken, waker) = create_waker ( ) ;
465+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Pending ) ;
466+
467+ notifier. notify ( ) ;
468+ assert ! ( woken. load( Ordering :: SeqCst ) ) ;
469+ assert_eq ! ( Pin :: new( & mut future) . poll( & mut Context :: from_waker( & waker) ) , Poll :: Ready ( ( ) ) ) ;
470+ assert ! ( !notifier. wait_timeout( Duration :: from_millis( 1 ) ) ) ;
471+ }
412472}
0 commit comments