diff --git a/lightning-block-sync/src/gossip.rs b/lightning-block-sync/src/gossip.rs index 0fe221b9231..fb06ca3860f 100644 --- a/lightning-block-sync/src/gossip.rs +++ b/lightning-block-sync/src/gossip.rs @@ -47,8 +47,12 @@ pub trait UtxoSource: BlockSource + 'static { pub struct TokioSpawner; #[cfg(feature = "tokio")] impl FutureSpawner for TokioSpawner { - fn spawn + Send + 'static>(&self, future: T) { - tokio::spawn(future); + type E = tokio::task::JoinError; + type SpawnedFutureResult = tokio::task::JoinHandle; + fn spawn + Send + 'static>( + &self, future: F, + ) -> Self::SpawnedFutureResult { + tokio::spawn(future) } } @@ -273,7 +277,7 @@ where let gossiper = Arc::clone(&self.gossiper); let block_cache = Arc::clone(&self.block_cache); let pmw = Arc::clone(&self.peer_manager_wake); - self.spawn.spawn(async move { + let _ = self.spawn.spawn(async move { let res = Self::retrieve_utxo(source, block_cache, short_channel_id).await; fut.resolve(gossiper.network_graph(), &*gossiper, res); (pmw)(); diff --git a/lightning/src/util/async_poll.rs b/lightning/src/util/async_poll.rs index 3edfd5211fe..c39689094f3 100644 --- a/lightning/src/util/async_poll.rs +++ b/lightning/src/util/async_poll.rs @@ -16,26 +16,94 @@ use core::marker::Unpin; use core::pin::Pin; use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; -pub(crate) enum ResultFuture>, E: Unpin> { +pub(crate) enum ResultFuture + Unpin, O> { Pending(F), - Ready(Result<(), E>), + Ready(O), } -pub(crate) struct MultiResultFuturePoller> + Unpin, E: Unpin> { - futures_state: Vec>, +pub(crate) struct TwoFutureJoiner + Unpin, BF: Future + Unpin> { + a: Option>, + b: Option>, } -impl> + Unpin, E: Unpin> MultiResultFuturePoller { - pub fn new(futures_state: Vec>) -> Self { +impl + Unpin, BF: Future + Unpin> TwoFutureJoiner { + pub fn new(future_a: AF, future_b: BF) -> Self { + Self { + a: Some(ResultFuture::Pending(future_a)), + b: Some(ResultFuture::Pending(future_b)), + } + } +} + +impl + Unpin, BF: Future + Unpin> Future for TwoFutureJoiner { + type Output = (AO, BO); + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<(AO, BO)> { + let mut have_pending_futures = false; + // SAFETY: While we are pinned, we can't get direct access to our internal state because we + // aren't `Unpin`. However, we don't actually need the `Pin` - we only use it below on the + // `Future` in the `ResultFuture::Pending` case, and the `Future` is bound by `Unpin`. + // Thus, the `Pin` is not actually used, and its safe to bypass it and access the inner + // reference directly. + let state = unsafe { &mut self.get_unchecked_mut() }; + macro_rules! poll_future { + ($future: ident) => { + match state.$future { + Some(ResultFuture::Pending(ref mut fut)) => match Pin::new(fut).poll(cx) { + Poll::Ready(res) => { + state.$future = Some(ResultFuture::Ready(res)); + }, + Poll::Pending => { + have_pending_futures = true; + }, + }, + Some(ResultFuture::Ready(_)) => {}, + None => { + debug_assert!(false, "Future polled after Ready"); + return Poll::Pending; + }, + } + }; + } + poll_future!(a); + poll_future!(b); + + if have_pending_futures { + Poll::Pending + } else { + Poll::Ready(( + match state.a.take() { + Some(ResultFuture::Ready(a)) => a, + _ => unreachable!(), + }, + match state.b.take() { + Some(ResultFuture::Ready(b)) => b, + _ => unreachable!(), + } + )) + } + } +} + +pub(crate) struct MultiResultFuturePoller + Unpin, O> { + futures_state: Vec>, +} + +impl + Unpin, O> MultiResultFuturePoller { + pub fn new(futures_state: Vec>) -> Self { Self { futures_state } } } -impl> + Unpin, E: Unpin> Future for MultiResultFuturePoller { - type Output = Vec>; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { +impl + Unpin, O> Future for MultiResultFuturePoller { + type Output = Vec; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut have_pending_futures = false; - let futures_state = &mut self.get_mut().futures_state; + // SAFETY: While we are pinned, we can't get direct access to `futures_state` because we + // aren't `Unpin`. However, we don't actually need the `Pin` - we only use it below on the + // `Future` in the `ResultFuture::Pending` case, and the `Future` is bound by `Unpin`. + // Thus, the `Pin` is not actually used, and its safe to bypass it and access the inner + // reference directly. + let futures_state = unsafe { &mut self.get_unchecked_mut().futures_state }; for state in futures_state.iter_mut() { match state { ResultFuture::Pending(ref mut fut) => match Pin::new(fut).poll(cx) { diff --git a/lightning/src/util/native_async.rs b/lightning/src/util/native_async.rs index dc26cb42bd0..ecc39c215f5 100644 --- a/lightning/src/util/native_async.rs +++ b/lightning/src/util/native_async.rs @@ -8,21 +8,38 @@ //! environment. #[cfg(all(test, feature = "std"))] -use crate::sync::Mutex; +use crate::sync::{Arc, Mutex}; use crate::util::async_poll::{MaybeSend, MaybeSync}; #[cfg(all(test, not(feature = "std")))] -use core::cell::RefCell; +use core::cell::{Rc, RefCell}; +#[cfg(test)] +use core::convert::Infallible; use core::future::Future; #[cfg(test)] use core::pin::Pin; +#[cfg(test)] +use core::task::{Context, Poll}; -/// A generic trait which is able to spawn futures in the background. +/// A generic trait which is able to spawn futures to be polled in the background. +/// +/// When the spawned future completes, the returned [`Self::SpawnedFutureResult`] should resolve +/// with the output of the spawned future. +/// +/// Spawned futures must be polled independently in the background even if the returned +/// [`Self::SpawnedFutureResult`] is dropped without being polled. This matches the semantics of +/// `tokio::spawn`. pub trait FutureSpawner: MaybeSend + MaybeSync + 'static { + /// The error type of [`Self::SpawnedFutureResult`]. + type E; + /// The result of [`Self::spawn`], a future which completes when the spawned future completes. + type SpawnedFutureResult: Future> + Unpin; /// Spawns the given future as a background task. /// /// This method MUST NOT block on the given future immediately. - fn spawn + MaybeSend + 'static>(&self, future: T); + fn spawn + MaybeSend + 'static>( + &self, future: T, + ) -> Self::SpawnedFutureResult; } #[cfg(test)] @@ -37,6 +54,69 @@ pub(crate) struct FutureQueue(Mutex>>>); #[cfg(all(test, not(feature = "std")))] pub(crate) struct FutureQueue(RefCell>>>); +#[cfg(all(test, feature = "std"))] +pub struct FutureQueueCompletion(Arc>>); +#[cfg(all(test, not(feature = "std")))] +pub struct FutureQueueCompletion(Rc>>); + +#[cfg(all(test, feature = "std"))] +impl FutureQueueCompletion { + fn new() -> Self { + Self(Arc::new(Mutex::new(None))) + } + + fn complete(&self, o: O) { + *self.0.lock().unwrap() = Some(o); + } +} + +#[cfg(all(test, feature = "std"))] +impl Clone for FutureQueueCompletion { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +#[cfg(all(test, not(feature = "std")))] +impl FutureQueueCompletion { + fn new() -> Self { + Self(Rc::new(Mutex::new(None))) + } + + fn complete(&self, o: O) { + *self.0.lock().unwrap() = Some(o); + } +} + +#[cfg(all(test, not(feature = "std")))] +impl Clone for FutureQueueCompletion { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +#[cfg(all(test, feature = "std"))] +impl Future for FutureQueueCompletion { + type Output = Result; + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match Pin::into_inner(self).0.lock().unwrap().take() { + None => Poll::Pending, + Some(o) => Poll::Ready(Ok(o)), + } + } +} + +#[cfg(all(test, not(feature = "std")))] +impl Future for FutureQueueCompletion { + type Output = Result; + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match Pin::into_inner(self).0.get_mut().take() { + None => Poll::Pending, + Some(o) => Poll::Ready(Ok(o)), + } + } +} + #[cfg(test)] impl FutureQueue { pub(crate) fn new() -> Self { @@ -84,7 +164,16 @@ impl FutureQueue { #[cfg(test)] impl FutureSpawner for FutureQueue { - fn spawn + MaybeSend + 'static>(&self, future: T) { + type E = Infallible; + type SpawnedFutureResult = FutureQueueCompletion; + fn spawn + MaybeSend + 'static>( + &self, f: F, + ) -> FutureQueueCompletion { + let completion = FutureQueueCompletion::new(); + let compl_ref = completion.clone(); + let future = async move { + compl_ref.complete(f.await); + }; #[cfg(feature = "std")] { self.0.lock().unwrap().push(Box::pin(future)); @@ -93,6 +182,7 @@ impl FutureSpawner for FutureQueue { { self.0.borrow_mut().push(Box::pin(future)); } + completion } } @@ -100,7 +190,16 @@ impl FutureSpawner for FutureQueue { impl + MaybeSend + MaybeSync + 'static> FutureSpawner for D { - fn spawn + MaybeSend + 'static>(&self, future: T) { + type E = Infallible; + type SpawnedFutureResult = FutureQueueCompletion; + fn spawn + MaybeSend + 'static>( + &self, f: F, + ) -> FutureQueueCompletion { + let completion = FutureQueueCompletion::new(); + let compl_ref = completion.clone(); + let future = async move { + compl_ref.complete(f.await); + }; #[cfg(feature = "std")] { self.0.lock().unwrap().push(Box::pin(future)); @@ -109,5 +208,6 @@ impl + MaybeSend + MaybeSync + 'static { self.0.borrow_mut().push(Box::pin(future)); } + completion } } diff --git a/lightning/src/util/persist.rs b/lightning/src/util/persist.rs index 0b4ba190740..78947c4214f 100644 --- a/lightning/src/util/persist.rs +++ b/lightning/src/util/persist.rs @@ -16,6 +16,7 @@ use alloc::sync::Arc; use bitcoin::hashes::hex::FromHex; use bitcoin::{BlockHash, Txid}; +use core::convert::Infallible; use core::future::Future; use core::mem; use core::ops::Deref; @@ -34,7 +35,9 @@ use crate::chain::transaction::OutPoint; use crate::ln::types::ChannelId; use crate::sign::{ecdsa::EcdsaChannelSigner, EntropySource, SignerProvider}; use crate::sync::Mutex; -use crate::util::async_poll::{dummy_waker, MaybeSend, MaybeSync}; +use crate::util::async_poll::{ + dummy_waker, MaybeSend, MaybeSync, MultiResultFuturePoller, ResultFuture, TwoFutureJoiner, +}; use crate::util::logger::Logger; use crate::util::native_async::FutureSpawner; use crate::util::ser::{Readable, ReadableArgs, Writeable}; @@ -405,7 +408,11 @@ where struct PanicingSpawner; impl FutureSpawner for PanicingSpawner { - fn spawn + MaybeSend + 'static>(&self, _: T) { + type E = Infallible; + type SpawnedFutureResult = Box> + Unpin>; + fn spawn + MaybeSend + 'static>( + &self, _: T, + ) -> Self::SpawnedFutureResult { unreachable!(); } } @@ -486,15 +493,6 @@ fn poll_sync_future(future: F) -> F::Output { /// list channel monitors themselves and load channels individually using /// [`MonitorUpdatingPersister::read_channel_monitor_with_updates`]. /// -/// ## EXTREMELY IMPORTANT -/// -/// It is extremely important that your [`KVStoreSync::read`] implementation uses the -/// [`io::ErrorKind::NotFound`] variant correctly: that is, when a file is not found, and _only_ in -/// that circumstance (not when there is really a permissions error, for example). This is because -/// neither channel monitor reading function lists updates. Instead, either reads the monitor, and -/// using its stored `update_id`, synthesizes update storage keys, and tries them in sequence until -/// one is not found. All _other_ errors will be bubbled up in the function's [`Result`]. -/// /// # Pruning stale channel updates /// /// Stale updates are pruned when the consolidation threshold is reached according to `maximum_pending_updates`. @@ -562,10 +560,6 @@ where } /// Reads all stored channel monitors, along with any stored updates for them. - /// - /// It is extremely important that your [`KVStoreSync::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. pub fn read_all_channel_monitors_with_updates( &self, ) -> Result< @@ -577,10 +571,6 @@ where /// Read a single channel monitor, along with any stored updates for it. /// - /// It is extremely important that your [`KVStoreSync::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. - /// /// For `monitor_key`, channel storage keys can be the channel's funding [`OutPoint`], with an /// underscore `_` between txid and index for v1 channels. For example, given: /// @@ -771,9 +761,9 @@ where /// Reads all stored channel monitors, along with any stored updates for them. /// - /// It is extremely important that your [`KVStore::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. + /// While the reads themselves are performend in parallel, deserializing the + /// [`ChannelMonitor`]s is not. For large [`ChannelMonitor`]s actively used for forwarding, + /// this may substantially limit the parallelism of this method. pub async fn read_all_channel_monitors_with_updates( &self, ) -> Result< @@ -783,22 +773,69 @@ where let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; let monitor_list = self.0.kv_store.list(primary, secondary).await?; - let mut res = Vec::with_capacity(monitor_list.len()); + let mut futures = Vec::with_capacity(monitor_list.len()); for monitor_key in monitor_list { - let result = - self.0.maybe_read_channel_monitor_with_updates(monitor_key.as_str()).await?; - if let Some(read_res) = result { + futures.push(ResultFuture::Pending(Box::pin(async move { + self.0.maybe_read_channel_monitor_with_updates(monitor_key.as_str()).await + }))); + } + let future_results = MultiResultFuturePoller::new(futures).await; + let mut res = Vec::with_capacity(future_results.len()); + for result in future_results { + if let Some(read_res) = result? { res.push(read_res); } } Ok(res) } - /// Read a single channel monitor, along with any stored updates for it. + /// Reads all stored channel monitors, along with any stored updates for them, in parallel. /// - /// It is extremely important that your [`KVStoreSync::read`] implementation uses the - /// [`io::ErrorKind::NotFound`] variant correctly. For more information, please see the - /// documentation for [`MonitorUpdatingPersister`]. + /// Because deserializing large [`ChannelMonitor`]s from forwarding nodes is often CPU-bound, + /// this version of [`Self::read_all_channel_monitors_with_updates`] uses the [`FutureSpawner`] + /// to parallelize deserialization as well as the IO operations. + /// + /// Because [`FutureSpawner`] requires that the spawned future be `'static` (matching `tokio` + /// and other multi-threaded runtime requirements), this method requires that `self` be an + /// `Arc` that can live for `'static` and be sent and accessed across threads. + pub async fn read_all_channel_monitors_with_updates_parallel( + self: &Arc, + ) -> Result< + Vec<(BlockHash, ChannelMonitor<::EcdsaSigner>)>, + io::Error, + > where + K: MaybeSend + MaybeSync + 'static, + L: MaybeSend + MaybeSync + 'static, + ES: MaybeSend + MaybeSync + 'static, + SP: MaybeSend + MaybeSync + 'static, + BI: MaybeSend + MaybeSync + 'static, + FE: MaybeSend + MaybeSync + 'static, + ::EcdsaSigner: MaybeSend, + { + let primary = CHANNEL_MONITOR_PERSISTENCE_PRIMARY_NAMESPACE; + let secondary = CHANNEL_MONITOR_PERSISTENCE_SECONDARY_NAMESPACE; + let monitor_list = self.0.kv_store.list(primary, secondary).await?; + let mut futures = Vec::with_capacity(monitor_list.len()); + for monitor_key in monitor_list { + let us = Arc::clone(&self); + futures.push(ResultFuture::Pending(self.0.future_spawner.spawn(async move { + us.0.maybe_read_channel_monitor_with_updates(monitor_key.as_str()).await + }))); + } + let future_results = MultiResultFuturePoller::new(futures).await; + let mut res = Vec::with_capacity(future_results.len()); + for result in future_results { + match result { + Err(_) => return Err(io::Error::new(io::ErrorKind::Other, "Future was cancelled")), + Ok(Err(e)) => return Err(e), + Ok(Ok(Some(read_res))) => res.push(read_res), + Ok(Ok(None)) => {}, + } + } + Ok(res) + } + + /// Read a single channel monitor, along with any stored updates for it. /// /// For `monitor_key`, channel storage keys can be the channel's funding [`OutPoint`], with an /// underscore `_` between txid and index for v1 channels. For example, given: @@ -858,7 +895,7 @@ where let future = inner.persist_new_channel(monitor_name, monitor); let channel_id = monitor.channel_id(); let completion = (monitor.channel_id(), monitor.get_latest_update_id()); - self.0.future_spawner.spawn(async move { + let _ = self.0.future_spawner.spawn(async move { match future.await { Ok(()) => inner.async_completed_updates.lock().unwrap().push(completion), Err(e) => { @@ -886,7 +923,7 @@ where None }; let inner = Arc::clone(&self.0); - self.0.future_spawner.spawn(async move { + let _ = self.0.future_spawner.spawn(async move { match future.await { Ok(()) => if let Some(completion) = completion { inner.async_completed_updates.lock().unwrap().push(completion); @@ -903,7 +940,7 @@ where pub(crate) fn spawn_async_archive_persisted_channel(&self, monitor_name: MonitorName) { let inner = Arc::clone(&self.0); - self.0.future_spawner.spawn(async move { + let _ = self.0.future_spawner.spawn(async move { inner.archive_persisted_channel(monitor_name).await; }); } @@ -945,28 +982,29 @@ where io::Error, > { let monitor_name = MonitorName::from_str(monitor_key)?; - let read_res = self.maybe_read_monitor(&monitor_name, monitor_key).await?; - let (block_hash, monitor) = match read_res { + // TODO: After an MSRV bump we should be able to use the pin macro rather than Box::pin + let read_future = Box::pin(self.maybe_read_monitor(&monitor_name, monitor_key)); + let list_future = + Box::pin(self.kv_store.list(CHANNEL_MONITOR_UPDATE_PERSISTENCE_PRIMARY_NAMESPACE, monitor_key)); + let (read_res, list_res) = TwoFutureJoiner::new(read_future, list_future).await; + let (block_hash, monitor) = match read_res? { Some(res) => res, None => return Ok(None), }; - let mut current_update_id = monitor.get_latest_update_id(); - // TODO: Parallelize this loop by speculatively reading a batch of updates - loop { - current_update_id = match current_update_id.checked_add(1) { - Some(next_update_id) => next_update_id, - None => break, - }; - let update_name = UpdateName::from(current_update_id); - let update = match self.read_monitor_update(monitor_key, &update_name).await { - Ok(update) => update, - Err(err) if err.kind() == io::ErrorKind::NotFound => { - // We can't find any more updates, so we are done. - break; - }, - Err(err) => return Err(err), - }; - + let current_update_id = monitor.get_latest_update_id(); + let updates: Result, _> = + list_res?.into_iter().map(|name| UpdateName::new(name)).collect(); + let mut updates = updates?; + updates.sort_unstable(); + let updates_to_load = updates.iter().filter(|update| update.0 > current_update_id); + let mut update_futures = Vec::with_capacity(updates_to_load.clone().count()); + for update_name in updates_to_load { + update_futures.push(ResultFuture::Pending(Box::pin(async move { + (update_name, self.read_monitor_update(monitor_key, update_name).await) + }))); + } + for (update_name, update_res) in MultiResultFuturePoller::new(update_futures).await { + let update = update_res?; monitor .update_monitor(&update, &self.broadcaster, &self.fee_estimator, &self.logger) .map_err(|e| { @@ -1350,7 +1388,7 @@ impl core::fmt::Display for MonitorName { /// let monitor_name = "some_monitor_name"; /// let storage_key = format!("channel_monitor_updates/{}/{}", monitor_name, update_name.as_str()); /// ``` -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct UpdateName(pub u64, String); impl UpdateName {