11use std:: future:: Future ;
22use std:: hash:: Hash ;
3- use std:: mem;
43use std:: pin:: Pin ;
54use std:: task:: { Context , Poll , Waker } ;
65use std:: time:: Duration ;
6+ use std:: { future, mem} ;
77
88use futures_timer:: Delay ;
99use futures_util:: future:: BoxFuture ;
@@ -38,6 +38,7 @@ impl<ID, O> FuturesMap<ID, O> {
3838impl < ID , O > FuturesMap < ID , O >
3939where
4040 ID : Clone + Hash + Eq + Send + Unpin + ' static ,
41+ O : ' static ,
4142{
4243 /// Push a future into the map.
4344 ///
@@ -58,32 +59,30 @@ where
5859 waker. wake ( ) ;
5960 }
6061
61- match self . inner . iter_mut ( ) . find ( |tagged| tagged. tag == future_id) {
62- None => {
63- self . inner . push ( TaggedFuture {
64- tag : future_id,
65- inner : TimeoutFuture {
66- inner : future. boxed ( ) ,
67- timeout : Delay :: new ( self . timeout ) ,
68- } ,
69- } ) ;
70-
71- Ok ( ( ) )
72- }
73- Some ( existing) => {
74- let old_future = mem:: replace (
75- & mut existing. inner ,
76- TimeoutFuture {
77- inner : future. boxed ( ) ,
78- timeout : Delay :: new ( self . timeout ) ,
79- } ,
80- ) ;
81-
82- Err ( PushError :: Replaced ( old_future. inner ) )
83- }
62+ let old = self . remove ( future_id. clone ( ) ) ;
63+ self . inner . push ( TaggedFuture {
64+ tag : future_id,
65+ inner : TimeoutFuture {
66+ inner : future. boxed ( ) ,
67+ timeout : Delay :: new ( self . timeout ) ,
68+ cancelled : false ,
69+ } ,
70+ } ) ;
71+ match old {
72+ None => Ok ( ( ) ) ,
73+ Some ( old) => Err ( PushError :: Replaced ( old) ) ,
8474 }
8575 }
8676
77+ pub fn remove ( & mut self , id : ID ) -> Option < BoxFuture < ' static , O > > {
78+ let tagged = self . inner . iter_mut ( ) . find ( |s| s. tag == id) ?;
79+
80+ let inner = mem:: replace ( & mut tagged. inner . inner , future:: pending ( ) . boxed ( ) ) ;
81+ tagged. inner . cancelled = true ;
82+
83+ Some ( inner)
84+ }
85+
8786 pub fn len ( & self ) -> usize {
8887 self . inner . len ( )
8988 }
@@ -104,39 +103,55 @@ where
104103 }
105104
106105 pub fn poll_unpin ( & mut self , cx : & mut Context < ' _ > ) -> Poll < ( ID , Result < O , Timeout > ) > {
107- let maybe_result = futures_util:: ready!( self . inner. poll_next_unpin( cx) ) ;
106+ loop {
107+ let maybe_result = futures_util:: ready!( self . inner. poll_next_unpin( cx) ) ;
108108
109- match maybe_result {
110- None => {
111- self . empty_waker = Some ( cx. waker ( ) . clone ( ) ) ;
112- Poll :: Pending
109+ match maybe_result {
110+ None => {
111+ self . empty_waker = Some ( cx. waker ( ) . clone ( ) ) ;
112+ return Poll :: Pending ;
113+ }
114+ Some ( ( id, Ok ( output) ) ) => return Poll :: Ready ( ( id, Ok ( output) ) ) ,
115+ Some ( ( id, Err ( TimeoutError :: Timeout ) ) ) => {
116+ return Poll :: Ready ( ( id, Err ( Timeout :: new ( self . timeout ) ) ) )
117+ }
118+ Some ( ( _, Err ( TimeoutError :: Cancelled ) ) ) => continue ,
113119 }
114- Some ( ( id, Ok ( output) ) ) => Poll :: Ready ( ( id, Ok ( output) ) ) ,
115- Some ( ( id, Err ( _timeout) ) ) => Poll :: Ready ( ( id, Err ( Timeout :: new ( self . timeout ) ) ) ) ,
116120 }
117121 }
118122}
119123
120124struct TimeoutFuture < F > {
121125 inner : F ,
122126 timeout : Delay ,
127+
128+ cancelled : bool ,
123129}
124130
125131impl < F > Future for TimeoutFuture < F >
126132where
127133 F : Future + Unpin ,
128134{
129- type Output = Result < F :: Output , ( ) > ;
135+ type Output = Result < F :: Output , TimeoutError > ;
130136
131137 fn poll ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
138+ if self . cancelled {
139+ return Poll :: Ready ( Err ( TimeoutError :: Cancelled ) ) ;
140+ }
141+
132142 if self . timeout . poll_unpin ( cx) . is_ready ( ) {
133- return Poll :: Ready ( Err ( ( ) ) ) ;
143+ return Poll :: Ready ( Err ( TimeoutError :: Timeout ) ) ;
134144 }
135145
136146 self . inner . poll_unpin ( cx) . map ( Ok )
137147 }
138148}
139149
150+ enum TimeoutError {
151+ Timeout ,
152+ Cancelled ,
153+ }
154+
140155struct TaggedFuture < T , F > {
141156 tag : T ,
142157 inner : F ,
@@ -158,6 +173,8 @@ where
158173
159174#[ cfg( test) ]
160175mod tests {
176+ use futures:: channel:: oneshot;
177+ use futures_util:: task:: noop_waker_ref;
161178 use std:: future:: { pending, poll_fn, ready} ;
162179 use std:: pin:: Pin ;
163180 use std:: time:: Instant ;
@@ -197,6 +214,45 @@ mod tests {
197214 assert ! ( result. is_err( ) )
198215 }
199216
217+ #[ test]
218+ fn resources_of_removed_future_are_cleaned_up ( ) {
219+ let mut futures = FuturesMap :: new ( Duration :: from_millis ( 100 ) , 1 ) ;
220+
221+ let _ = futures. try_push ( "ID" , pending :: < ( ) > ( ) ) ;
222+ futures. remove ( "ID" ) ;
223+
224+ let poll = futures. poll_unpin ( & mut Context :: from_waker ( noop_waker_ref ( ) ) ) ;
225+ assert ! ( poll. is_pending( ) ) ;
226+
227+ assert_eq ! ( futures. len( ) , 0 ) ;
228+ }
229+
230+ #[ tokio:: test]
231+ async fn replaced_pending_future_is_polled ( ) {
232+ let mut streams = FuturesMap :: new ( Duration :: from_millis ( 100 ) , 3 ) ;
233+
234+ let ( _tx1, rx1) = oneshot:: channel ( ) ;
235+ let ( tx2, rx2) = oneshot:: channel ( ) ;
236+
237+ let _ = streams. try_push ( "ID1" , rx1) ;
238+ let _ = streams. try_push ( "ID2" , rx2) ;
239+
240+ let _ = tx2. send ( 2 ) ;
241+ let ( id, res) = poll_fn ( |cx| streams. poll_unpin ( cx) ) . await ;
242+ assert_eq ! ( id, "ID2" ) ;
243+ assert_eq ! ( res. unwrap( ) . unwrap( ) , 2 ) ;
244+
245+ let ( new_tx1, new_rx1) = oneshot:: channel ( ) ;
246+ let replaced = streams. try_push ( "ID1" , new_rx1) ;
247+ assert ! ( matches!( replaced. unwrap_err( ) , PushError :: Replaced ( _) ) ) ;
248+
249+ let _ = new_tx1. send ( 4 ) ;
250+ let ( id, res) = poll_fn ( |cx| streams. poll_unpin ( cx) ) . await ;
251+
252+ assert_eq ! ( id, "ID1" ) ;
253+ assert_eq ! ( res. unwrap( ) . unwrap( ) , 4 ) ;
254+ }
255+
200256 // Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
201257 // We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES.
202258 #[ tokio:: test]
0 commit comments