Skip to content

Commit a866e26

Browse files
committed
Allow FutureSpawner to return the result of the spawned future
`tokio::spawn` can be use both to spawn a forever-running background task or to spawn a task which gets `poll`ed independently and eventually returns a result which the callsite wants. In LDK, we have only ever needed the first, and thus didn't bother defining a return type for `FutureSpawner::spawn`. However, in the next commit we'll start using `FutureSpawner` in a context where we actually do want the spawned future's result. Thus, here, we add a result output to `FutureSpawner::spawn`, mirroring the `tokio::spawn` API.
1 parent afad0d3 commit a866e26

File tree

3 files changed

+123
-13
lines changed

3 files changed

+123
-13
lines changed

lightning-block-sync/src/gossip.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,12 @@ pub trait UtxoSource: BlockSource + 'static {
4747
pub struct TokioSpawner;
4848
#[cfg(feature = "tokio")]
4949
impl FutureSpawner for TokioSpawner {
50-
fn spawn<T: Future<Output = ()> + Send + 'static>(&self, future: T) {
51-
tokio::spawn(future);
50+
type E = tokio::task::JoinError;
51+
type SpawnedFutureResult<O> = tokio::task::JoinHandle<O>;
52+
fn spawn<O: Send + 'static, F: Future<Output = O> + Send + 'static>(
53+
&self, future: F,
54+
) -> Self::SpawnedFutureResult<O> {
55+
tokio::spawn(future)
5256
}
5357
}
5458

@@ -273,7 +277,7 @@ where
273277
let gossiper = Arc::clone(&self.gossiper);
274278
let block_cache = Arc::clone(&self.block_cache);
275279
let pmw = Arc::clone(&self.peer_manager_wake);
276-
self.spawn.spawn(async move {
280+
let _ = self.spawn.spawn(async move {
277281
let res = Self::retrieve_utxo(source, block_cache, short_channel_id).await;
278282
fut.resolve(gossiper.network_graph(), &*gossiper, res);
279283
(pmw)();

lightning/src/util/native_async.rs

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,39 @@
88
//! environment.
99
1010
#[cfg(all(test, feature = "std"))]
11-
use crate::sync::Mutex;
11+
use crate::sync::{Arc, Mutex};
1212
use crate::util::async_poll::{MaybeSend, MaybeSync};
1313

1414
#[cfg(all(test, not(feature = "std")))]
15-
use core::cell::RefCell;
15+
use core::cell::{Rc, RefCell};
16+
#[cfg(test)]
17+
use core::convert::Infallible;
1618
use core::future::Future;
1719
#[cfg(test)]
1820
use core::pin::Pin;
21+
#[cfg(test)]
22+
use core::task::{Context, Poll};
1923

20-
/// A generic trait which is able to spawn futures in the background.
24+
/// A generic trait which is able to spawn futures to be polled in the background.
25+
///
26+
/// When the spawned future completes, the returned [`Self::SpawnedFutureResult`] should resolve
27+
/// with the output of the spawned future.
28+
///
29+
/// Spawned futures must be polled independently in the background even if the returned
30+
/// [`Self::SpawnedFutureResult`] is dropped without being polled. This matches the semantics of
31+
/// `tokio::spawn`.
2132
pub trait FutureSpawner: MaybeSend + MaybeSync + 'static {
33+
/// The error type of [`Self::SpawnedFutureResult`]. This can be used to indicate that the
34+
/// spawned future was cancelled or panicked.
35+
type E;
36+
/// The result of [`Self::spawn`], a future which completes when the spawned future completes.
37+
type SpawnedFutureResult<O>: Future<Output = Result<O, Self::E>> + Unpin;
2238
/// Spawns the given future as a background task.
2339
///
2440
/// This method MUST NOT block on the given future immediately.
25-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T);
41+
fn spawn<O: MaybeSend + 'static, T: Future<Output = O> + MaybeSend + 'static>(
42+
&self, future: T,
43+
) -> Self::SpawnedFutureResult<O>;
2644
}
2745

2846
#[cfg(test)]
@@ -37,6 +55,69 @@ pub(crate) struct FutureQueue(Mutex<Vec<Pin<Box<dyn MaybeSendableFuture>>>>);
3755
#[cfg(all(test, not(feature = "std")))]
3856
pub(crate) struct FutureQueue(RefCell<Vec<Pin<Box<dyn MaybeSendableFuture>>>>);
3957

58+
#[cfg(all(test, feature = "std"))]
59+
pub struct FutureQueueCompletion<O>(Arc<Mutex<Option<O>>>);
60+
#[cfg(all(test, not(feature = "std")))]
61+
pub struct FutureQueueCompletion<O>(Rc<RefCell<Option<O>>>);
62+
63+
#[cfg(all(test, feature = "std"))]
64+
impl<O> FutureQueueCompletion<O> {
65+
fn new() -> Self {
66+
Self(Arc::new(Mutex::new(None)))
67+
}
68+
69+
fn complete(&self, o: O) {
70+
*self.0.lock().unwrap() = Some(o);
71+
}
72+
}
73+
74+
#[cfg(all(test, feature = "std"))]
75+
impl<O> Clone for FutureQueueCompletion<O> {
76+
fn clone(&self) -> Self {
77+
Self(self.0.clone())
78+
}
79+
}
80+
81+
#[cfg(all(test, not(feature = "std")))]
82+
impl<O> FutureQueueCompletion<O> {
83+
fn new() -> Self {
84+
Self(Rc::new(Mutex::new(None)))
85+
}
86+
87+
fn complete(&self, o: O) {
88+
*self.0.lock().unwrap() = Some(o);
89+
}
90+
}
91+
92+
#[cfg(all(test, not(feature = "std")))]
93+
impl<O> Clone for FutureQueueCompletion<O> {
94+
fn clone(&self) -> Self {
95+
Self(self.0.clone())
96+
}
97+
}
98+
99+
#[cfg(all(test, feature = "std"))]
100+
impl<O> Future for FutureQueueCompletion<O> {
101+
type Output = Result<O, Infallible>;
102+
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<O, Infallible>> {
103+
match Pin::into_inner(self).0.lock().unwrap().take() {
104+
None => Poll::Pending,
105+
Some(o) => Poll::Ready(Ok(o)),
106+
}
107+
}
108+
}
109+
110+
#[cfg(all(test, not(feature = "std")))]
111+
impl<O> Future for FutureQueueCompletion<O> {
112+
type Output = Result<O, Infallible>;
113+
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<O, Infallible>> {
114+
match Pin::into_inner(self).0.get_mut().take() {
115+
None => Poll::Pending,
116+
Some(o) => Poll::Ready(Ok(o)),
117+
}
118+
}
119+
}
120+
40121
#[cfg(test)]
41122
impl FutureQueue {
42123
pub(crate) fn new() -> Self {
@@ -84,7 +165,16 @@ impl FutureQueue {
84165

85166
#[cfg(test)]
86167
impl FutureSpawner for FutureQueue {
87-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
168+
type E = Infallible;
169+
type SpawnedFutureResult<O> = FutureQueueCompletion<O>;
170+
fn spawn<O: MaybeSend + 'static, F: Future<Output = O> + MaybeSend + 'static>(
171+
&self, f: F,
172+
) -> FutureQueueCompletion<O> {
173+
let completion = FutureQueueCompletion::new();
174+
let compl_ref = completion.clone();
175+
let future = async move {
176+
compl_ref.complete(f.await);
177+
};
88178
#[cfg(feature = "std")]
89179
{
90180
self.0.lock().unwrap().push(Box::pin(future));
@@ -93,14 +183,24 @@ impl FutureSpawner for FutureQueue {
93183
{
94184
self.0.borrow_mut().push(Box::pin(future));
95185
}
186+
completion
96187
}
97188
}
98189

99190
#[cfg(test)]
100191
impl<D: core::ops::Deref<Target = FutureQueue> + MaybeSend + MaybeSync + 'static> FutureSpawner
101192
for D
102193
{
103-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, future: T) {
194+
type E = Infallible;
195+
type SpawnedFutureResult<O> = FutureQueueCompletion<O>;
196+
fn spawn<O: MaybeSend + 'static, F: Future<Output = O> + MaybeSend + 'static>(
197+
&self, f: F,
198+
) -> FutureQueueCompletion<O> {
199+
let completion = FutureQueueCompletion::new();
200+
let compl_ref = completion.clone();
201+
let future = async move {
202+
compl_ref.complete(f.await);
203+
};
104204
#[cfg(feature = "std")]
105205
{
106206
self.0.lock().unwrap().push(Box::pin(future));
@@ -109,5 +209,6 @@ impl<D: core::ops::Deref<Target = FutureQueue> + MaybeSend + MaybeSync + 'static
109209
{
110210
self.0.borrow_mut().push(Box::pin(future));
111211
}
212+
completion
112213
}
113214
}

lightning/src/util/persist.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use alloc::sync::Arc;
1616
use bitcoin::hashes::hex::FromHex;
1717
use bitcoin::{BlockHash, Txid};
1818

19+
use core::convert::Infallible;
1920
use core::future::Future;
2021
use core::mem;
2122
use core::ops::Deref;
@@ -407,7 +408,11 @@ where
407408

408409
struct PanicingSpawner;
409410
impl FutureSpawner for PanicingSpawner {
410-
fn spawn<T: Future<Output = ()> + MaybeSend + 'static>(&self, _: T) {
411+
type E = Infallible;
412+
type SpawnedFutureResult<O> = Box<dyn Future<Output = Result<O, Infallible>> + Unpin>;
413+
fn spawn<O, T: Future<Output = O> + MaybeSend + 'static>(
414+
&self, _: T,
415+
) -> Self::SpawnedFutureResult<O> {
411416
unreachable!();
412417
}
413418
}
@@ -865,7 +870,7 @@ where
865870
let future = inner.persist_new_channel(monitor_name, monitor);
866871
let channel_id = monitor.channel_id();
867872
let completion = (monitor.channel_id(), monitor.get_latest_update_id());
868-
self.0.future_spawner.spawn(async move {
873+
let _ = self.0.future_spawner.spawn(async move {
869874
match future.await {
870875
Ok(()) => inner.async_completed_updates.lock().unwrap().push(completion),
871876
Err(e) => {
@@ -893,7 +898,7 @@ where
893898
None
894899
};
895900
let inner = Arc::clone(&self.0);
896-
self.0.future_spawner.spawn(async move {
901+
let _ = self.0.future_spawner.spawn(async move {
897902
match future.await {
898903
Ok(()) => if let Some(completion) = completion {
899904
inner.async_completed_updates.lock().unwrap().push(completion);
@@ -910,7 +915,7 @@ where
910915

911916
pub(crate) fn spawn_async_archive_persisted_channel(&self, monitor_name: MonitorName) {
912917
let inner = Arc::clone(&self.0);
913-
self.0.future_spawner.spawn(async move {
918+
let _ = self.0.future_spawner.spawn(async move {
914919
inner.archive_persisted_channel(monitor_name).await;
915920
});
916921
}

0 commit comments

Comments
 (0)