Skip to content

Commit d273435

Browse files
authored
Drop task future before fulfill join handle (#4)
This way after `task::spawn(future).await`, we know that all values contained in the `future` are dropped. Resolves #3.
1 parent 6d75a28 commit d273435

File tree

1 file changed

+64
-24
lines changed

1 file changed

+64
-24
lines changed

spawns-core/src/task.rs

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -278,28 +278,33 @@ impl<F: Future> Future for IdFuture<F> {
278278
}
279279

280280
struct TaskFuture<F: Future> {
281-
ready: bool,
282281
waker: Option<Box<Waker>>,
283282
cancellation: Arc<Cancellation>,
284283
joint: Arc<Joint<F::Output>>,
285-
future: F,
284+
future: Option<F>,
286285
}
287286

288287
impl<F: Future> TaskFuture<F> {
289288
fn new(future: F) -> Self {
290289
Self {
291-
ready: false,
292290
waker: None,
293291
joint: Arc::new(Joint::new()),
294292
cancellation: Arc::new(Default::default()),
295-
future,
293+
future: Some(future),
296294
}
297295
}
296+
297+
fn finish(&mut self, value: Result<F::Output, InnerJoinError>) -> Poll<()> {
298+
self.future = None;
299+
self.joint.wake(value);
300+
Poll::Ready(())
301+
}
298302
}
303+
299304
impl<F: Future> Drop for TaskFuture<F> {
300305
fn drop(&mut self) {
301-
if !self.ready {
302-
self.joint.wake(Err(InnerJoinError::Cancelled));
306+
if self.future.is_some() {
307+
let _ = self.finish(Err(InnerJoinError::Cancelled));
303308
}
304309
}
305310
}
@@ -309,14 +314,12 @@ impl<F: Future> Future for TaskFuture<F> {
309314

310315
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
311316
let task = unsafe { self.get_unchecked_mut() };
312-
if task.ready {
317+
if task.future.is_none() {
313318
return Poll::Ready(());
314319
} else if task.cancellation.is_cancelled() {
315-
task.joint.wake(Err(InnerJoinError::Cancelled));
316-
task.ready = true;
317-
return Poll::Ready(());
320+
return task.finish(Err(InnerJoinError::Cancelled));
318321
}
319-
let future = unsafe { Pin::new_unchecked(&mut task.future) };
322+
let future = unsafe { Pin::new_unchecked(task.future.as_mut().unwrap_unchecked()) };
320323
match panic::catch_unwind(AssertUnwindSafe(|| future.poll(cx))) {
321324
Ok(Poll::Pending) => {
322325
let waker = match task.waker.take() {
@@ -327,23 +330,13 @@ impl<F: Future> Future for TaskFuture<F> {
327330
}
328331
};
329332
let Ok(waker) = task.cancellation.update(waker) else {
330-
task.joint.wake(Err(InnerJoinError::Cancelled));
331-
task.ready = true;
332-
return Poll::Ready(());
333+
return task.finish(Err(InnerJoinError::Cancelled));
333334
};
334335
task.waker = waker;
335336
Poll::Pending
336337
}
337-
Ok(Poll::Ready(value)) => {
338-
task.joint.wake(Ok(value));
339-
task.ready = true;
340-
Poll::Ready(())
341-
}
342-
Err(err) => {
343-
task.joint.wake(Err(InnerJoinError::Panic(err)));
344-
task.ready = true;
345-
Poll::Ready(())
346-
}
338+
Ok(Poll::Ready(value)) => task.finish(Ok(value)),
339+
Err(err) => task.finish(Err(InnerJoinError::Panic(err))),
347340
}
348341
}
349342
}
@@ -999,4 +992,51 @@ mod tests {
999992
block_on(Box::into_pin(task.future));
1000993
assert_eq!(cancelled.load(Ordering::Relaxed), false);
1001994
}
995+
996+
struct CustomFuture {
997+
_shared: Arc<()>,
998+
}
999+
1000+
impl Future for CustomFuture {
1001+
type Output = ();
1002+
1003+
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
1004+
Poll::Ready(())
1005+
}
1006+
}
1007+
1008+
#[test]
1009+
fn future_dropped_before_ready() {
1010+
let shared = Arc::new(());
1011+
let (mut task, _handle) = Task::new(
1012+
Name::default(),
1013+
CustomFuture {
1014+
_shared: shared.clone(),
1015+
},
1016+
);
1017+
let pinned = unsafe { Pin::new_unchecked(task.future.as_mut()) };
1018+
let poll = pinned.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));
1019+
assert!(poll.is_ready());
1020+
assert_eq!(Arc::strong_count(&shared), 1);
1021+
}
1022+
1023+
#[test]
1024+
fn future_dropped_before_joined() {
1025+
let shared = Arc::new(());
1026+
let (mut task, handle) = Task::new(
1027+
Name::default(),
1028+
CustomFuture {
1029+
_shared: shared.clone(),
1030+
},
1031+
);
1032+
std::thread::spawn(move || {
1033+
let pinned = unsafe { Pin::new_unchecked(task.future.as_mut()) };
1034+
let _poll = pinned.poll(&mut Context::from_waker(futures::task::noop_waker_ref()));
1035+
1036+
// Let join handle complete before task drop.
1037+
std::thread::sleep(Duration::from_millis(10));
1038+
});
1039+
block_on(handle).unwrap();
1040+
assert_eq!(Arc::strong_count(&shared), 1);
1041+
}
10021042
}

0 commit comments

Comments
 (0)