1515
1616use  alloc:: sync:: Arc ; 
1717use  core:: mem; 
18- use  crate :: sync:: { Condvar ,  Mutex } ; 
18+ use  crate :: sync:: { Condvar ,  Mutex ,   MutexGuard } ; 
1919
2020use  crate :: prelude:: * ; 
2121
@@ -41,9 +41,22 @@ impl Notifier {
4141		} 
4242	} 
4343
44+ 	fn  propagate_future_state_to_notify_flag ( & self )  -> MutexGuard < ( bool ,  Option < Arc < Mutex < FutureState > > > ) >  { 
45+ 		let  mut  lock = self . notify_pending . lock ( ) . unwrap ( ) ; 
46+ 		if  let  Some ( existing_state)  = & lock. 1  { 
47+ 			if  existing_state. lock ( ) . unwrap ( ) . callbacks_made  { 
48+ 				// If the existing `FutureState` has completed and actually made callbacks, 
49+ 				// consider the notification flag to have been cleared and reset the future state. 
50+ 				lock. 1 . take ( ) ; 
51+ 				lock. 0  = false ; 
52+ 			} 
53+ 		} 
54+ 		lock
55+ 	} 
56+ 
4457	pub ( crate )  fn  wait ( & self )  { 
4558		loop  { 
46- 			let  mut  guard = self . notify_pending . lock ( ) . unwrap ( ) ; 
59+ 			let  mut  guard = self . propagate_future_state_to_notify_flag ( ) ; 
4760			if  guard. 0  { 
4861				guard. 0  = false ; 
4962				return ; 
@@ -61,7 +74,7 @@ impl Notifier {
6174	pub ( crate )  fn  wait_timeout ( & self ,  max_wait :  Duration )  -> bool  { 
6275		let  current_time = Instant :: now ( ) ; 
6376		loop  { 
64- 			let  mut  guard = self . notify_pending . lock ( ) . unwrap ( ) ; 
77+ 			let  mut  guard = self . propagate_future_state_to_notify_flag ( ) ; 
6578			if  guard. 0  { 
6679				guard. 0  = false ; 
6780				return  true ; 
@@ -88,17 +101,8 @@ impl Notifier {
88101	/// Wake waiters, tracking that wake needs to occur even if there are currently no waiters. 
89102pub ( crate )  fn  notify ( & self )  { 
90103		let  mut  lock = self . notify_pending . lock ( ) . unwrap ( ) ; 
91- 		let  mut  future_probably_generated_calls = false ; 
92- 		if  let  Some ( future_state)  = lock. 1 . take ( )  { 
93- 			future_probably_generated_calls |= future_state. lock ( ) . unwrap ( ) . complete ( ) ; 
94- 			future_probably_generated_calls |= Arc :: strong_count ( & future_state)  > 1 ; 
95- 		} 
96- 		if  future_probably_generated_calls { 
97- 			// If a future made some callbacks or has not yet been drop'd (i.e. the state has more 
98- 			// than the one reference we hold), assume the user was notified and skip setting the 
99- 			// notification-required flag. This will not cause the `wait` functions above to return 
100- 			// and avoid any future `Future`s starting in a completed state. 
101- 			return ; 
104+ 		if  let  Some ( future_state)  = & lock. 1  { 
105+ 			future_state. lock ( ) . unwrap ( ) . complete ( ) ; 
102106		} 
103107		lock. 0  = true ; 
104108		mem:: drop ( lock) ; 
@@ -107,20 +111,14 @@ impl Notifier {
107111
108112	/// Gets a [`Future`] that will get woken up with any waiters 
109113pub ( crate )  fn  get_future ( & self )  -> Future  { 
110- 		let  mut  lock = self . notify_pending . lock ( ) . unwrap ( ) ; 
111- 		if  lock. 0  { 
112- 			Future  { 
113- 				state :  Arc :: new ( Mutex :: new ( FutureState  { 
114- 					callbacks :  Vec :: new ( ) , 
115- 					complete :  true , 
116- 				} ) ) 
117- 			} 
118- 		}  else  if  let  Some ( existing_state)  = & lock. 1  { 
114+ 		let  mut  lock = self . propagate_future_state_to_notify_flag ( ) ; 
115+ 		if  let  Some ( existing_state)  = & lock. 1  { 
119116			Future  {  state :  Arc :: clone ( & existing_state)  } 
120117		}  else  { 
121118			let  state = Arc :: new ( Mutex :: new ( FutureState  { 
122119				callbacks :  Vec :: new ( ) , 
123- 				complete :  false , 
120+ 				complete :  lock. 0 , 
121+ 				callbacks_made :  false , 
124122			} ) ) ; 
125123			lock. 1  = Some ( Arc :: clone ( & state) ) ; 
126124			Future  {  state } 
@@ -153,17 +151,16 @@ impl<F: Fn() + Send> FutureCallback for F {
153151pub ( crate )  struct  FutureState  { 
154152	callbacks :  Vec < Box < dyn  FutureCallback > > , 
155153	complete :  bool , 
154+ 	callbacks_made :  bool , 
156155} 
157156
158157impl  FutureState  { 
159- 	fn  complete ( & mut  self )  -> bool  { 
160- 		let  mut  made_calls = false ; 
158+ 	fn  complete ( & mut  self )  { 
161159		for  callback in  self . callbacks . drain ( ..)  { 
162160			callback. call ( ) ; 
163- 			made_calls  = true ; 
161+ 			self . callbacks_made  = true ; 
164162		} 
165163		self . complete  = true ; 
166- 		made_calls
167164	} 
168165} 
169166
@@ -180,6 +177,7 @@ impl Future {
180177pub  fn  register_callback ( & self ,  callback :  Box < dyn  FutureCallback > )  { 
181178		let  mut  state = self . state . lock ( ) . unwrap ( ) ; 
182179		if  state. complete  { 
180+ 			state. callbacks_made  = true ; 
183181			mem:: drop ( state) ; 
184182			callback. call ( ) ; 
185183		}  else  { 
@@ -283,6 +281,28 @@ mod tests {
283281		assert ! ( !callback. load( Ordering :: SeqCst ) ) ; 
284282	} 
285283
284+ 	#[ test]  
285+ 	fn  new_future_wipes_notify_bit ( )  { 
286+ 		// Previously, if we were only using the `Future` interface to learn when a `Notifier` has 
287+ 		// been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is 
288+ 		// fetched after the notify bit has been set. 
289+ 		let  notifier = Notifier :: new ( ) ; 
290+ 		notifier. notify ( ) ; 
291+ 
292+ 		let  callback = Arc :: new ( AtomicBool :: new ( false ) ) ; 
293+ 		let  callback_ref = Arc :: clone ( & callback) ; 
294+ 		notifier. get_future ( ) . register_callback ( Box :: new ( move  || assert ! ( !callback_ref. fetch_or( true ,  Ordering :: SeqCst ) ) ) ) ; 
295+ 		assert ! ( callback. load( Ordering :: SeqCst ) ) ; 
296+ 
297+ 		let  callback = Arc :: new ( AtomicBool :: new ( false ) ) ; 
298+ 		let  callback_ref = Arc :: clone ( & callback) ; 
299+ 		notifier. get_future ( ) . register_callback ( Box :: new ( move  || assert ! ( !callback_ref. fetch_or( true ,  Ordering :: SeqCst ) ) ) ) ; 
300+ 		assert ! ( !callback. load( Ordering :: SeqCst ) ) ; 
301+ 
302+ 		notifier. notify ( ) ; 
303+ 		assert ! ( callback. load( Ordering :: SeqCst ) ) ; 
304+ 	} 
305+ 
286306	#[ cfg( feature = "std" ) ]  
287307	#[ test]  
288308	fn  test_wait_timeout ( )  { 
@@ -334,6 +354,7 @@ mod tests {
334354			state :  Arc :: new ( Mutex :: new ( FutureState  { 
335355				callbacks :  Vec :: new ( ) , 
336356				complete :  false , 
357+ 				callbacks_made :  false , 
337358			} ) ) 
338359		} ; 
339360		let  callback = Arc :: new ( AtomicBool :: new ( false ) ) ; 
@@ -352,6 +373,7 @@ mod tests {
352373			state :  Arc :: new ( Mutex :: new ( FutureState  { 
353374				callbacks :  Vec :: new ( ) , 
354375				complete :  false , 
376+ 				callbacks_made :  false , 
355377			} ) ) 
356378		} ; 
357379		future. state . lock ( ) . unwrap ( ) . complete ( ) ; 
@@ -389,6 +411,7 @@ mod tests {
389411			state :  Arc :: new ( Mutex :: new ( FutureState  { 
390412				callbacks :  Vec :: new ( ) , 
391413				complete :  false , 
414+ 				callbacks_made :  false , 
392415			} ) ) 
393416		} ; 
394417		let  mut  second_future = Future  {  state :  Arc :: clone ( & future. state )  } ; 
0 commit comments