Skip to content

Commit a2c1b4e

Browse files
committed
futures: TaskGuard
## Description An abstraction over `JoinHandle<T>` that aborts the task when it is dropped. ## Test plan New unit tests: ``` sui-futures$ cargo nextest run ```
1 parent 980f02f commit a2c1b4e

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

crates/sui-futures/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
pub mod future;
55
pub mod service;
66
pub mod stream;
7+
pub mod task;

crates/sui-futures/src/task.rs

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) Mysten Labs, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
use std::{
5+
pin::Pin,
6+
task::{Context, Poll},
7+
};
8+
9+
use tokio::task::{JoinError, JoinHandle};
10+
11+
/// A wrapper around `JoinHandle` that aborts the task when dropped.
12+
///
13+
/// The abort on drop does not wait for the task to finish, it simply sends the abort signal.
14+
#[must_use = "Dropping the handle aborts the task immediately"]
15+
#[derive(Debug)]
16+
pub struct TaskGuard<T>(JoinHandle<T>);
17+
18+
impl<T> TaskGuard<T> {
19+
pub fn new(handle: JoinHandle<T>) -> Self {
20+
Self(handle)
21+
}
22+
}
23+
24+
impl<T> Future for TaskGuard<T> {
25+
type Output = Result<T, JoinError>;
26+
27+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
28+
Pin::new(&mut self.0).poll(cx)
29+
}
30+
}
31+
32+
impl<T> AsRef<JoinHandle<T>> for TaskGuard<T> {
33+
fn as_ref(&self) -> &JoinHandle<T> {
34+
&self.0
35+
}
36+
}
37+
38+
impl<T> Drop for TaskGuard<T> {
39+
fn drop(&mut self) {
40+
self.0.abort();
41+
}
42+
}
43+
44+
#[cfg(test)]
45+
mod tests {
46+
use std::time::Duration;
47+
48+
use tokio::sync::oneshot;
49+
50+
use super::*;
51+
52+
#[tokio::test]
53+
async fn test_abort_on_drop() {
54+
let (mut tx, rx) = oneshot::channel::<()>();
55+
56+
let guard = TaskGuard::new(tokio::spawn(async move {
57+
let _ = rx.await;
58+
}));
59+
60+
// When the guard is dropped, the task should be aborted, cleaning up its future, which
61+
// will close the receiving side of the channel.
62+
drop(guard);
63+
tokio::time::timeout(Duration::from_millis(100), tx.closed())
64+
.await
65+
.unwrap();
66+
}
67+
}

0 commit comments

Comments
 (0)