@@ -278,28 +278,33 @@ impl<F: Future> Future for IdFuture<F> {
278
278
}
279
279
280
280
struct TaskFuture < F : Future > {
281
- ready : bool ,
282
281
waker : Option < Box < Waker > > ,
283
282
cancellation : Arc < Cancellation > ,
284
283
joint : Arc < Joint < F :: Output > > ,
285
- future : F ,
284
+ future : Option < F > ,
286
285
}
287
286
288
287
impl < F : Future > TaskFuture < F > {
289
288
fn new ( future : F ) -> Self {
290
289
Self {
291
- ready : false ,
292
290
waker : None ,
293
291
joint : Arc :: new ( Joint :: new ( ) ) ,
294
292
cancellation : Arc :: new ( Default :: default ( ) ) ,
295
- future,
293
+ future : Some ( future ) ,
296
294
}
297
295
}
296
+
297
+ fn finish ( & mut self , value : Result < F :: Output , InnerJoinError > ) -> Poll < ( ) > {
298
+ self . future = None ;
299
+ self . joint . wake ( value) ;
300
+ Poll :: Ready ( ( ) )
301
+ }
298
302
}
303
+
299
304
impl < F : Future > Drop for TaskFuture < F > {
300
305
fn drop ( & mut self ) {
301
- if ! self . ready {
302
- self . joint . wake ( Err ( InnerJoinError :: Cancelled ) ) ;
306
+ if self . future . is_some ( ) {
307
+ let _ = self . finish ( Err ( InnerJoinError :: Cancelled ) ) ;
303
308
}
304
309
}
305
310
}
@@ -309,14 +314,12 @@ impl<F: Future> Future for TaskFuture<F> {
309
314
310
315
fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
311
316
let task = unsafe { self . get_unchecked_mut ( ) } ;
312
- if task. ready {
317
+ if task. future . is_none ( ) {
313
318
return Poll :: Ready ( ( ) ) ;
314
319
} else if task. cancellation . is_cancelled ( ) {
315
- task. joint . wake ( Err ( InnerJoinError :: Cancelled ) ) ;
316
- task. ready = true ;
317
- return Poll :: Ready ( ( ) ) ;
320
+ return task. finish ( Err ( InnerJoinError :: Cancelled ) ) ;
318
321
}
319
- let future = unsafe { Pin :: new_unchecked ( & mut task. future ) } ;
322
+ let future = unsafe { Pin :: new_unchecked ( task. future . as_mut ( ) . unwrap_unchecked ( ) ) } ;
320
323
match panic:: catch_unwind ( AssertUnwindSafe ( || future. poll ( cx) ) ) {
321
324
Ok ( Poll :: Pending ) => {
322
325
let waker = match task. waker . take ( ) {
@@ -327,23 +330,13 @@ impl<F: Future> Future for TaskFuture<F> {
327
330
}
328
331
} ;
329
332
let Ok ( waker) = task. cancellation . update ( waker) else {
330
- task. joint . wake ( Err ( InnerJoinError :: Cancelled ) ) ;
331
- task. ready = true ;
332
- return Poll :: Ready ( ( ) ) ;
333
+ return task. finish ( Err ( InnerJoinError :: Cancelled ) ) ;
333
334
} ;
334
335
task. waker = waker;
335
336
Poll :: Pending
336
337
}
337
- Ok ( Poll :: Ready ( value) ) => {
338
- task. joint . wake ( Ok ( value) ) ;
339
- task. ready = true ;
340
- Poll :: Ready ( ( ) )
341
- }
342
- Err ( err) => {
343
- task. joint . wake ( Err ( InnerJoinError :: Panic ( err) ) ) ;
344
- task. ready = true ;
345
- Poll :: Ready ( ( ) )
346
- }
338
+ Ok ( Poll :: Ready ( value) ) => task. finish ( Ok ( value) ) ,
339
+ Err ( err) => task. finish ( Err ( InnerJoinError :: Panic ( err) ) ) ,
347
340
}
348
341
}
349
342
}
@@ -999,4 +992,51 @@ mod tests {
999
992
block_on ( Box :: into_pin ( task. future ) ) ;
1000
993
assert_eq ! ( cancelled. load( Ordering :: Relaxed ) , false ) ;
1001
994
}
995
+
996
+ struct CustomFuture {
997
+ _shared : Arc < ( ) > ,
998
+ }
999
+
1000
+ impl Future for CustomFuture {
1001
+ type Output = ( ) ;
1002
+
1003
+ fn poll ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
1004
+ Poll :: Ready ( ( ) )
1005
+ }
1006
+ }
1007
+
1008
+ #[ test]
1009
+ fn future_dropped_before_ready ( ) {
1010
+ let shared = Arc :: new ( ( ) ) ;
1011
+ let ( mut task, _handle) = Task :: new (
1012
+ Name :: default ( ) ,
1013
+ CustomFuture {
1014
+ _shared : shared. clone ( ) ,
1015
+ } ,
1016
+ ) ;
1017
+ let pinned = unsafe { Pin :: new_unchecked ( task. future . as_mut ( ) ) } ;
1018
+ let poll = pinned. poll ( & mut Context :: from_waker ( futures:: task:: noop_waker_ref ( ) ) ) ;
1019
+ assert ! ( poll. is_ready( ) ) ;
1020
+ assert_eq ! ( Arc :: strong_count( & shared) , 1 ) ;
1021
+ }
1022
+
1023
+ #[ test]
1024
+ fn future_dropped_before_joined ( ) {
1025
+ let shared = Arc :: new ( ( ) ) ;
1026
+ let ( mut task, handle) = Task :: new (
1027
+ Name :: default ( ) ,
1028
+ CustomFuture {
1029
+ _shared : shared. clone ( ) ,
1030
+ } ,
1031
+ ) ;
1032
+ std:: thread:: spawn ( move || {
1033
+ let pinned = unsafe { Pin :: new_unchecked ( task. future . as_mut ( ) ) } ;
1034
+ let _poll = pinned. poll ( & mut Context :: from_waker ( futures:: task:: noop_waker_ref ( ) ) ) ;
1035
+
1036
+ // Let join handle complete before task drop.
1037
+ std:: thread:: sleep ( Duration :: from_millis ( 10 ) ) ;
1038
+ } ) ;
1039
+ block_on ( handle) . unwrap ( ) ;
1040
+ assert_eq ! ( Arc :: strong_count( & shared) , 1 ) ;
1041
+ }
1002
1042
}
0 commit comments