diff --git a/bindings/ldk_node.udl b/bindings/ldk_node.udl index 26480ca4b..f3560ec09 100644 --- a/bindings/ldk_node.udl +++ b/bindings/ldk_node.udl @@ -64,7 +64,7 @@ dictionary LogRecord { [Trait, WithForeign] interface LogWriter { - void log(LogRecord record); + void log(LogRecord record); }; interface Builder { @@ -161,8 +161,8 @@ interface Node { [Enum] interface Bolt11InvoiceDescription { - Hash(string hash); - Direct(string description); + Hash(string hash); + Direct(string description); }; interface Bolt11Payment { @@ -335,6 +335,7 @@ enum BuildError { "InvalidListeningAddresses", "InvalidAnnouncementAddresses", "InvalidNodeAlias", + "RuntimeSetupFailed", "ReadFailed", "WriteFailed", "StoragePathAccessFailed", diff --git a/src/builder.rs b/src/builder.rs index 1152f18c3..729cefe1b 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -28,6 +28,7 @@ use crate::liquidity::{ use crate::logger::{log_error, log_info, LdkLogger, LogLevel, LogWriter, Logger}; use crate::message_handler::NodeCustomMessageHandler; use crate::peer_store::PeerStore; +use crate::runtime::Runtime; use crate::tx_broadcaster::TransactionBroadcaster; use crate::types::{ ChainMonitor, ChannelManager, DynStore, GossipSync, Graph, KeysManager, MessageRouter, @@ -168,6 +169,8 @@ pub enum BuildError { InvalidAnnouncementAddresses, /// The provided alias is invalid. InvalidNodeAlias, + /// An attempt to setup a runtime has failed. + RuntimeSetupFailed, /// We failed to read data from the [`KVStore`]. /// /// [`KVStore`]: lightning::util::persist::KVStore @@ -205,6 +208,7 @@ impl fmt::Display for BuildError { Self::InvalidAnnouncementAddresses => { write!(f, "Given announcement addresses are invalid.") }, + Self::RuntimeSetupFailed => write!(f, "Failed to setup a runtime."), Self::ReadFailed => write!(f, "Failed to read from store."), Self::WriteFailed => write!(f, "Failed to write to store."), Self::StoragePathAccessFailed => write!(f, "Failed to access the given storage path."), @@ -236,6 +240,7 @@ pub struct NodeBuilder { gossip_source_config: Option, liquidity_source_config: Option, log_writer_config: Option, + runtime_handle: Option, } impl NodeBuilder { @@ -252,6 +257,7 @@ impl NodeBuilder { let gossip_source_config = None; let liquidity_source_config = None; let log_writer_config = None; + let runtime_handle = None; Self { config, entropy_source_config, @@ -259,9 +265,20 @@ impl NodeBuilder { gossip_source_config, liquidity_source_config, log_writer_config, + runtime_handle, } } + /// Configures the [`Node`] instance to (re-)use a specific `tokio` runtime. + /// + /// If not provided, the node will spawn its own runtime or reuse any outer runtime context it + /// can detect. + #[cfg_attr(feature = "uniffi", allow(dead_code))] + pub fn set_runtime(&mut self, runtime_handle: tokio::runtime::Handle) -> &mut Self { + self.runtime_handle = Some(runtime_handle); + self + } + /// Configures the [`Node`] instance to source its wallet entropy from a seed file on disk. /// /// If the given file does not exist a new random seed file will be generated and @@ -650,6 +667,15 @@ impl NodeBuilder { ) -> Result { let logger = setup_logger(&self.log_writer_config, &self.config)?; + let runtime = if let Some(handle) = self.runtime_handle.as_ref() { + Arc::new(Runtime::with_handle(handle.clone())) + } else { + Arc::new(Runtime::new().map_err(|e| { + log_error!(logger, "Failed to setup tokio runtime: {}", e); + BuildError::RuntimeSetupFailed + })?) + }; + let seed_bytes = seed_bytes_from_config( &self.config, self.entropy_source_config.as_ref(), @@ -678,6 +704,7 @@ impl NodeBuilder { self.gossip_source_config.as_ref(), self.liquidity_source_config.as_ref(), seed_bytes, + runtime, logger, Arc::new(vss_store), ) @@ -687,6 +714,15 @@ impl NodeBuilder { pub fn build_with_store(&self, kv_store: Arc) -> Result { let logger = setup_logger(&self.log_writer_config, &self.config)?; + let runtime = if let Some(handle) = self.runtime_handle.as_ref() { + Arc::new(Runtime::with_handle(handle.clone())) + } else { + Arc::new(Runtime::new().map_err(|e| { + log_error!(logger, "Failed to setup tokio runtime: {}", e); + BuildError::RuntimeSetupFailed + })?) + }; + let seed_bytes = seed_bytes_from_config( &self.config, self.entropy_source_config.as_ref(), @@ -700,6 +736,7 @@ impl NodeBuilder { self.gossip_source_config.as_ref(), self.liquidity_source_config.as_ref(), seed_bytes, + runtime, logger, kv_store, ) @@ -1049,7 +1086,7 @@ fn build_with_store_internal( config: Arc, chain_data_source_config: Option<&ChainDataSourceConfig>, gossip_source_config: Option<&GossipSourceConfig>, liquidity_source_config: Option<&LiquiditySourceConfig>, seed_bytes: [u8; 64], - logger: Arc, kv_store: Arc, + runtime: Arc, logger: Arc, kv_store: Arc, ) -> Result { optionally_install_rustls_cryptoprovider(); @@ -1241,8 +1278,6 @@ fn build_with_store_internal( }, }; - let runtime = Arc::new(RwLock::new(None)); - // Initialize the ChainMonitor let chain_monitor: Arc = Arc::new(chainmonitor::ChainMonitor::new( Some(Arc::clone(&chain_source)), @@ -1637,6 +1672,8 @@ fn build_with_store_internal( let background_tasks = Mutex::new(None); let cancellable_background_tasks = Mutex::new(None); + let is_running = Arc::new(RwLock::new(false)); + Ok(Node { runtime, stop_sender, @@ -1664,6 +1701,7 @@ fn build_with_store_internal( scorer, peer_store, payment_store, + is_running, is_listening, node_metrics, }) diff --git a/src/chain/electrum.rs b/src/chain/electrum.rs index abbb758dd..b6d37409b 100644 --- a/src/chain/electrum.rs +++ b/src/chain/electrum.rs @@ -18,6 +18,7 @@ use crate::fee_estimator::{ }; use crate::io::utils::write_node_metrics; use crate::logger::{log_bytes, log_error, log_info, log_trace, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::NodeMetrics; @@ -86,7 +87,7 @@ impl ElectrumChainSource { } } - pub(super) fn start(&self, runtime: Arc) -> Result<(), Error> { + pub(super) fn start(&self, runtime: Arc) -> Result<(), Error> { self.electrum_runtime_status.write().unwrap().start( self.server_url.clone(), Arc::clone(&runtime), @@ -339,7 +340,7 @@ impl ElectrumRuntimeStatus { } pub(super) fn start( - &mut self, server_url: String, runtime: Arc, config: Arc, + &mut self, server_url: String, runtime: Arc, config: Arc, logger: Arc, ) -> Result<(), Error> { match self { @@ -403,15 +404,14 @@ struct ElectrumRuntimeClient { electrum_client: Arc, bdk_electrum_client: Arc>, tx_sync: Arc>>, - runtime: Arc, + runtime: Arc, config: Arc, logger: Arc, } impl ElectrumRuntimeClient { fn new( - server_url: String, runtime: Arc, config: Arc, - logger: Arc, + server_url: String, runtime: Arc, config: Arc, logger: Arc, ) -> Result { let electrum_config = ElectrumConfigBuilder::new() .retry(ELECTRUM_CLIENT_NUM_RETRIES) @@ -544,7 +544,6 @@ impl ElectrumRuntimeClient { let spawn_fut = self.runtime.spawn_blocking(move || electrum_client.transaction_broadcast(&tx)); - let timeout_fut = tokio::time::timeout(Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), spawn_fut); diff --git a/src/chain/mod.rs b/src/chain/mod.rs index a4ab2c76b..f3a29e984 100644 --- a/src/chain/mod.rs +++ b/src/chain/mod.rs @@ -19,6 +19,7 @@ use crate::config::{ use crate::fee_estimator::OnchainFeeEstimator; use crate::io::utils::write_node_metrics; use crate::logger::{log_debug, log_info, log_trace, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{Broadcaster, ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::{Error, NodeMetrics}; @@ -185,7 +186,7 @@ impl ChainSource { Self { kind, tx_broadcaster, logger } } - pub(crate) fn start(&self, runtime: Arc) -> Result<(), Error> { + pub(crate) fn start(&self, runtime: Arc) -> Result<(), Error> { match &self.kind { ChainSourceKind::Electrum(electrum_chain_source) => { electrum_chain_source.start(runtime)? diff --git a/src/event.rs b/src/event.rs index 22848bec1..ae81f50e9 100644 --- a/src/event.rs +++ b/src/event.rs @@ -29,6 +29,8 @@ use crate::io::{ }; use crate::logger::{log_debug, log_error, log_info, LdkLogger}; +use crate::runtime::Runtime; + use lightning::events::bump_transaction::BumpTransactionEvent; use lightning::events::{ClosureReason, PaymentPurpose, ReplayEvent}; use lightning::events::{Event as LdkEvent, PaymentFailureReason}; @@ -53,7 +55,7 @@ use core::future::Future; use core::task::{Poll, Waker}; use std::collections::VecDeque; use std::ops::Deref; -use std::sync::{Arc, Condvar, Mutex, RwLock}; +use std::sync::{Arc, Condvar, Mutex}; use std::time::Duration; /// An event emitted by [`Node`], which should be handled by the user. @@ -451,7 +453,7 @@ where liquidity_source: Option>>>, payment_store: Arc, peer_store: Arc>, - runtime: Arc>>>, + runtime: Arc, logger: L, config: Arc, } @@ -466,8 +468,8 @@ where channel_manager: Arc, connection_manager: Arc>, output_sweeper: Arc, network_graph: Arc, liquidity_source: Option>>>, - payment_store: Arc, peer_store: Arc>, - runtime: Arc>>>, logger: L, config: Arc, + payment_store: Arc, peer_store: Arc>, runtime: Arc, + logger: L, config: Arc, ) -> Self { Self { event_queue, @@ -1049,17 +1051,14 @@ where let forwarding_channel_manager = self.channel_manager.clone(); let min = time_forwardable.as_millis() as u64; - let runtime_lock = self.runtime.read().unwrap(); - debug_assert!(runtime_lock.is_some()); + let future = async move { + let millis_to_sleep = thread_rng().gen_range(min..min * 5) as u64; + tokio::time::sleep(Duration::from_millis(millis_to_sleep)).await; - if let Some(runtime) = runtime_lock.as_ref() { - runtime.spawn(async move { - let millis_to_sleep = thread_rng().gen_range(min..min * 5) as u64; - tokio::time::sleep(Duration::from_millis(millis_to_sleep)).await; + forwarding_channel_manager.process_pending_htlc_forwards(); + }; - forwarding_channel_manager.process_pending_htlc_forwards(); - }); - } + self.runtime.spawn(future); }, LdkEvent::SpendableOutputs { outputs, channel_id } => { match self.output_sweeper.track_spendable_outputs(outputs, channel_id, true, None) { @@ -1421,31 +1420,27 @@ where debug_assert!(false, "We currently don't handle BOLT12 invoices manually, so this event should never be emitted."); }, LdkEvent::ConnectionNeeded { node_id, addresses } => { - let runtime_lock = self.runtime.read().unwrap(); - debug_assert!(runtime_lock.is_some()); - - if let Some(runtime) = runtime_lock.as_ref() { - let spawn_logger = self.logger.clone(); - let spawn_cm = Arc::clone(&self.connection_manager); - runtime.spawn(async move { - for addr in &addresses { - match spawn_cm.connect_peer_if_necessary(node_id, addr.clone()).await { - Ok(()) => { - return; - }, - Err(e) => { - log_error!( - spawn_logger, - "Failed to establish connection to peer {}@{}: {}", - node_id, - addr, - e - ); - }, - } + let spawn_logger = self.logger.clone(); + let spawn_cm = Arc::clone(&self.connection_manager); + let future = async move { + for addr in &addresses { + match spawn_cm.connect_peer_if_necessary(node_id, addr.clone()).await { + Ok(()) => { + return; + }, + Err(e) => { + log_error!( + spawn_logger, + "Failed to establish connection to peer {}@{}: {}", + node_id, + addr, + e + ); + }, } - }); - } + } + }; + self.runtime.spawn(future); }, LdkEvent::BumpTransaction(bte) => { match bte { diff --git a/src/gossip.rs b/src/gossip.rs index a8a6e3831..1185f0718 100644 --- a/src/gossip.rs +++ b/src/gossip.rs @@ -7,7 +7,8 @@ use crate::chain::ChainSource; use crate::config::RGS_SYNC_TIMEOUT_SECS; -use crate::logger::{log_error, log_trace, LdkLogger, Logger}; +use crate::logger::{log_trace, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{GossipSync, Graph, P2PGossipSync, PeerManager, RapidGossipSync, UtxoLookup}; use crate::Error; @@ -15,13 +16,12 @@ use lightning_block_sync::gossip::{FutureSpawner, GossipVerifier}; use std::future::Future; use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::Arc; use std::time::Duration; pub(crate) enum GossipSource { P2PNetwork { gossip_sync: Arc, - logger: Arc, }, RapidGossipSync { gossip_sync: Arc, @@ -38,7 +38,7 @@ impl GossipSource { None::>, Arc::clone(&logger), )); - Self::P2PNetwork { gossip_sync, logger } + Self::P2PNetwork { gossip_sync } } pub fn new_rgs( @@ -63,12 +63,12 @@ impl GossipSource { pub(crate) fn set_gossip_verifier( &self, chain_source: Arc, peer_manager: Arc, - runtime: Arc>>>, + runtime: Arc, ) { match self { - Self::P2PNetwork { gossip_sync, logger } => { + Self::P2PNetwork { gossip_sync } => { if let Some(utxo_source) = chain_source.as_utxo_source() { - let spawner = RuntimeSpawner::new(Arc::clone(&runtime), Arc::clone(&logger)); + let spawner = RuntimeSpawner::new(Arc::clone(&runtime)); let gossip_verifier = Arc::new(GossipVerifier::new( utxo_source, spawner, @@ -133,28 +133,17 @@ impl GossipSource { } pub(crate) struct RuntimeSpawner { - runtime: Arc>>>, - logger: Arc, + runtime: Arc, } impl RuntimeSpawner { - pub(crate) fn new( - runtime: Arc>>>, logger: Arc, - ) -> Self { - Self { runtime, logger } + pub(crate) fn new(runtime: Arc) -> Self { + Self { runtime } } } impl FutureSpawner for RuntimeSpawner { fn spawn + Send + 'static>(&self, future: T) { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { - log_error!(self.logger, "Tried spawing a future while the runtime wasn't available. This should never happen."); - debug_assert!(false, "Tried spawing a future while the runtime wasn't available. This should never happen."); - return; - } - - let runtime = rt_lock.as_ref().unwrap(); - runtime.spawn(future); + self.runtime.spawn(future); } } diff --git a/src/lib.rs b/src/lib.rs index a3cce0752..cc5e383a1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -94,6 +94,7 @@ pub mod logger; mod message_handler; pub mod payment; mod peer_store; +mod runtime; mod sweep; mod tx_broadcaster; mod types; @@ -105,6 +106,7 @@ pub use lightning; pub use lightning_invoice; pub use lightning_liquidity; pub use lightning_types; +pub use tokio; pub use vss_client; pub use balance::{BalanceDetails, LightningBalance, PendingSweepBalance}; @@ -141,6 +143,7 @@ use payment::{ UnifiedQrPayment, }; use peer_store::{PeerInfo, PeerStore}; +use runtime::Runtime; use types::{ Broadcaster, BumpTransactionEventHandler, ChainMonitor, ChannelManager, DynStore, Graph, KeysManager, OnionMessenger, PaymentStore, PeerManager, Router, Scorer, Sweeper, Wallet, @@ -176,7 +179,7 @@ uniffi::include_scaffolding!("ldk_node"); /// /// Needs to be initialized and instantiated through [`Builder::build`]. pub struct Node { - runtime: Arc>>>, + runtime: Arc, stop_sender: tokio::sync::watch::Sender<()>, background_processor_task: Mutex>>, background_tasks: Mutex>>, @@ -202,6 +205,7 @@ pub struct Node { scorer: Arc>, peer_store: Arc>>, payment_store: Arc, + is_running: Arc>, is_listening: Arc, node_metrics: Arc>, } @@ -210,33 +214,21 @@ impl Node { /// Starts the necessary background tasks, such as handling events coming from user input, /// LDK/BDK, and the peer-to-peer network. /// - /// After this returns, the [`Node`] instance can be controlled via the provided API methods in - /// a thread-safe manner. - pub fn start(&self) -> Result<(), Error> { - let runtime = - Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()); - self.start_with_runtime(runtime) - } - - /// Starts the necessary background tasks (such as handling events coming from user input, - /// LDK/BDK, and the peer-to-peer network) on the the given `runtime`. - /// - /// This allows to have LDK Node reuse an outer pre-existing runtime, e.g., to avoid stacking Tokio - /// runtime contexts. + /// This will try to auto-detect an outer pre-existing runtime, e.g., to avoid stacking Tokio + /// runtime contexts. Note we require the outer runtime to be of the `multithreaded` flavor. /// /// After this returns, the [`Node`] instance can be controlled via the provided API methods in /// a thread-safe manner. - pub fn start_with_runtime(&self, runtime: Arc) -> Result<(), Error> { + pub fn start(&self) -> Result<(), Error> { // Acquire a run lock and hold it until we're setup. - let mut runtime_lock = self.runtime.write().unwrap(); - if runtime_lock.is_some() { - // We're already running. + let mut is_running_lock = self.is_running.write().unwrap(); + if *is_running_lock { return Err(Error::AlreadyRunning); } let mut background_tasks = tokio::task::JoinSet::new(); let mut cancellable_background_tasks = tokio::task::JoinSet::new(); - let runtime_handle = runtime.handle(); + let runtime_handle = self.runtime.handle(); log_info!( self.logger, @@ -246,17 +238,14 @@ impl Node { ); // Start up any runtime-dependant chain sources (e.g. Electrum) - self.chain_source.start(Arc::clone(&runtime)).map_err(|e| { + self.chain_source.start(Arc::clone(&self.runtime)).map_err(|e| { log_error!(self.logger, "Failed to start chain syncing: {}", e); e })?; // Block to ensure we update our fee rate cache once on startup let chain_source = Arc::clone(&self.chain_source); - let runtime_ref = &runtime; - tokio::task::block_in_place(move || { - runtime_ref.block_on(async move { chain_source.update_fee_rate_estimates().await }) - })?; + self.runtime.block_on(async move { chain_source.update_fee_rate_estimates().await })?; // Spawn background task continuously syncing onchain, lightning, and fee rate cache. let stop_sync_receiver = self.stop_sender.subscribe(); @@ -574,7 +563,7 @@ impl Node { }) }; - let handle = runtime.spawn(async move { + let handle = self.runtime.spawn(async move { process_events_async( background_persister, |e| background_event_handler.handle_event(e), @@ -621,8 +610,6 @@ impl Node { ); } - *runtime_lock = Some(runtime); - debug_assert!(self.background_tasks.lock().unwrap().is_none()); *self.background_tasks.lock().unwrap() = Some(background_tasks); @@ -630,6 +617,7 @@ impl Node { *self.cancellable_background_tasks.lock().unwrap() = Some(cancellable_background_tasks); log_info!(self.logger, "Startup complete."); + *is_running_lock = true; Ok(()) } @@ -637,9 +625,10 @@ impl Node { /// /// After this returns most API methods will return [`Error::NotRunning`]. pub fn stop(&self) -> Result<(), Error> { - let runtime = self.runtime.write().unwrap().take().ok_or(Error::NotRunning)?; - #[cfg(tokio_unstable)] - let metrics_runtime = Arc::clone(&runtime); + let mut is_running_lock = self.is_running.write().unwrap(); + if !*is_running_lock { + return Err(Error::NotRunning); + } log_info!(self.logger, "Shutting down LDK Node with node ID {}...", self.node_id()); @@ -661,10 +650,10 @@ impl Node { // Cancel cancellable background tasks if let Some(mut tasks) = self.cancellable_background_tasks.lock().unwrap().take() { - let runtime_2 = Arc::clone(&runtime); + let runtime_handle = self.runtime.handle(); tasks.abort_all(); tokio::task::block_in_place(move || { - runtime_2.block_on(async { while let Some(_) = tasks.join_next().await {} }) + runtime_handle.block_on(async { while let Some(_) = tasks.join_next().await {} }) }); } else { debug_assert!(false, "Expected some cancellable background tasks"); @@ -679,10 +668,10 @@ impl Node { log_debug!(self.logger, "Stopped chain sources."); // Wait until non-cancellable background tasks (mod LDK's background processor) are done. - let runtime_3 = Arc::clone(&runtime); + let runtime_handle = self.runtime.handle(); if let Some(mut tasks) = self.background_tasks.lock().unwrap().take() { tokio::task::block_in_place(move || { - runtime_3.block_on(async { + runtime_handle.block_on(async { loop { let timeout_fut = tokio::time::timeout( Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS), @@ -724,7 +713,7 @@ impl Node { { let abort_handle = background_processor_task.abort_handle(); let timeout_res = tokio::task::block_in_place(move || { - runtime.block_on(async { + self.runtime.block_on(async { tokio::time::timeout( Duration::from_secs(LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS), background_processor_task, @@ -757,20 +746,22 @@ impl Node { #[cfg(tokio_unstable)] { + let runtime_handle = self.runtime.handle(); log_trace!( self.logger, "Active runtime tasks left prior to shutdown: {}", - metrics_runtime.metrics().active_tasks_count() + runtime_handle.metrics().active_tasks_count() ); } log_info!(self.logger, "Shutdown complete."); + *is_running_lock = false; Ok(()) } /// Returns the status of the [`Node`]. pub fn status(&self) -> NodeStatus { - let is_running = self.runtime.read().unwrap().is_some(); + let is_running = *self.is_running.read().unwrap(); let is_listening = self.is_listening.load(Ordering::Acquire); let current_best_block = self.channel_manager.current_best_block().into(); let locked_node_metrics = self.node_metrics.read().unwrap(); @@ -891,6 +882,7 @@ impl Node { Arc::clone(&self.payment_store), Arc::clone(&self.peer_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -908,6 +900,7 @@ impl Node { Arc::clone(&self.payment_store), Arc::clone(&self.peer_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -918,9 +911,9 @@ impl Node { #[cfg(not(feature = "uniffi"))] pub fn bolt12_payment(&self) -> Bolt12Payment { Bolt12Payment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.payment_store), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -931,9 +924,9 @@ impl Node { #[cfg(feature = "uniffi")] pub fn bolt12_payment(&self) -> Arc { Arc::new(Bolt12Payment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.payment_store), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -942,11 +935,11 @@ impl Node { #[cfg(not(feature = "uniffi"))] pub fn spontaneous_payment(&self) -> SpontaneousPayment { SpontaneousPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.keys_manager), Arc::clone(&self.payment_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -955,11 +948,11 @@ impl Node { #[cfg(feature = "uniffi")] pub fn spontaneous_payment(&self) -> Arc { Arc::new(SpontaneousPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.channel_manager), Arc::clone(&self.keys_manager), Arc::clone(&self.payment_store), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -968,10 +961,10 @@ impl Node { #[cfg(not(feature = "uniffi"))] pub fn onchain_payment(&self) -> OnchainPayment { OnchainPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.wallet), Arc::clone(&self.channel_manager), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), ) } @@ -980,10 +973,10 @@ impl Node { #[cfg(feature = "uniffi")] pub fn onchain_payment(&self) -> Arc { Arc::new(OnchainPayment::new( - Arc::clone(&self.runtime), Arc::clone(&self.wallet), Arc::clone(&self.channel_manager), Arc::clone(&self.config), + Arc::clone(&self.is_running), Arc::clone(&self.logger), )) } @@ -1061,11 +1054,9 @@ impl Node { pub fn connect( &self, node_id: PublicKey, address: SocketAddress, persist: bool, ) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } - let runtime = rt_lock.as_ref().unwrap(); let peer_info = PeerInfo { node_id, address }; @@ -1075,10 +1066,8 @@ impl Node { // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; log_info!(self.logger, "Connected to peer {}@{}. ", peer_info.node_id, peer_info.address); @@ -1095,8 +1084,7 @@ impl Node { /// Will also remove the peer from the peer store, i.e., after this has been called we won't /// try to reconnect on restart. pub fn disconnect(&self, counterparty_node_id: PublicKey) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -1118,11 +1106,9 @@ impl Node { push_to_counterparty_msat: Option, channel_config: Option, announce_for_forwarding: bool, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } - let runtime = rt_lock.as_ref().unwrap(); let peer_info = PeerInfo { node_id, address }; @@ -1146,10 +1132,8 @@ impl Node { // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; // Fail if we have less than the channel value + anchor reserve available (if applicable). @@ -1298,8 +1282,7 @@ impl Node { /// /// [`EsploraSyncConfig::background_sync_config`]: crate::config::EsploraSyncConfig::background_sync_config pub fn sync_wallets(&self) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -1307,24 +1290,16 @@ impl Node { let sync_cman = Arc::clone(&self.channel_manager); let sync_cmon = Arc::clone(&self.chain_monitor); let sync_sweeper = Arc::clone(&self.output_sweeper); - tokio::task::block_in_place(move || { - tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap().block_on( - async move { - if chain_source.is_transaction_based() { - chain_source.update_fee_rate_estimates().await?; - chain_source - .sync_lightning_wallet(sync_cman, sync_cmon, sync_sweeper) - .await?; - chain_source.sync_onchain_wallet().await?; - } else { - chain_source.update_fee_rate_estimates().await?; - chain_source - .poll_and_update_listeners(sync_cman, sync_cmon, sync_sweeper) - .await?; - } - Ok(()) - }, - ) + self.runtime.block_on(async move { + if chain_source.is_transaction_based() { + chain_source.update_fee_rate_estimates().await?; + chain_source.sync_lightning_wallet(sync_cman, sync_cmon, sync_sweeper).await?; + chain_source.sync_onchain_wallet().await?; + } else { + chain_source.update_fee_rate_estimates().await?; + chain_source.poll_and_update_listeners(sync_cman, sync_cmon, sync_sweeper).await?; + } + Ok(()) }) } diff --git a/src/liquidity.rs b/src/liquidity.rs index a4516edd0..9b103ee82 100644 --- a/src/liquidity.rs +++ b/src/liquidity.rs @@ -10,6 +10,7 @@ use crate::chain::ChainSource; use crate::connection::ConnectionManager; use crate::logger::{log_debug, log_error, log_info, LdkLogger, Logger}; +use crate::runtime::Runtime; use crate::types::{ChannelManager, KeysManager, LiquidityManager, PeerManager, Wallet}; use crate::{total_anchor_channels_reserve_sats, Config, Error}; @@ -1388,7 +1389,7 @@ pub(crate) struct LSPS2BuyResponse { /// [`Bolt11Payment::receive_via_jit_channel`]: crate::payment::Bolt11Payment::receive_via_jit_channel #[derive(Clone)] pub struct LSPS1Liquidity { - runtime: Arc>>>, + runtime: Arc, wallet: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, @@ -1397,7 +1398,7 @@ pub struct LSPS1Liquidity { impl LSPS1Liquidity { pub(crate) fn new( - runtime: Arc>>>, wallet: Arc, + runtime: Arc, wallet: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, logger: Arc, ) -> Self { @@ -1418,19 +1419,14 @@ impl LSPS1Liquidity { let (lsp_node_id, lsp_address) = liquidity_source.get_lsps1_lsp_details().ok_or(Error::LiquiditySourceUnavailable)?; - let rt_lock = self.runtime.read().unwrap(); - let runtime = rt_lock.as_ref().unwrap(); - let con_node_id = lsp_node_id; let con_addr = lsp_address.clone(); let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; log_info!(self.logger, "Connected to LSP {}@{}. ", lsp_node_id, lsp_address); @@ -1438,18 +1434,16 @@ impl LSPS1Liquidity { let refund_address = self.wallet.get_new_address()?; let liquidity_source = Arc::clone(&liquidity_source); - let response = tokio::task::block_in_place(move || { - runtime.block_on(async move { - liquidity_source - .lsps1_request_channel( - lsp_balance_sat, - client_balance_sat, - channel_expiry_blocks, - announce_channel, - refund_address, - ) - .await - }) + let response = self.runtime.block_on(async move { + liquidity_source + .lsps1_request_channel( + lsp_balance_sat, + client_balance_sat, + channel_expiry_blocks, + announce_channel, + refund_address, + ) + .await })?; Ok(response) @@ -1463,27 +1457,20 @@ impl LSPS1Liquidity { let (lsp_node_id, lsp_address) = liquidity_source.get_lsps1_lsp_details().ok_or(Error::LiquiditySourceUnavailable)?; - let rt_lock = self.runtime.read().unwrap(); - let runtime = rt_lock.as_ref().unwrap(); - let con_node_id = lsp_node_id; let con_addr = lsp_address.clone(); let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; let liquidity_source = Arc::clone(&liquidity_source); - let response = tokio::task::block_in_place(move || { - runtime - .block_on(async move { liquidity_source.lsps1_check_order_status(order_id).await }) - })?; - + let response = self + .runtime + .block_on(async move { liquidity_source.lsps1_check_order_status(order_id).await })?; Ok(response) } } diff --git a/src/payment/bolt11.rs b/src/payment/bolt11.rs index 817a428bd..389c818c8 100644 --- a/src/payment/bolt11.rs +++ b/src/payment/bolt11.rs @@ -22,6 +22,7 @@ use crate::payment::store::{ }; use crate::payment::SendingParameters; use crate::peer_store::{PeerInfo, PeerStore}; +use crate::runtime::Runtime; use crate::types::{ChannelManager, PaymentStore}; use lightning::ln::bolt11_payment; @@ -57,24 +58,24 @@ type Bolt11InvoiceDescription = crate::ffi::Bolt11InvoiceDescription; /// [BOLT 11]: https://github.com/lightning/bolts/blob/master/11-payment-encoding.md /// [`Node::bolt11_payment`]: crate::Node::bolt11_payment pub struct Bolt11Payment { - runtime: Arc>>>, + runtime: Arc, channel_manager: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, payment_store: Arc, peer_store: Arc>>, config: Arc, + is_running: Arc>, logger: Arc, } impl Bolt11Payment { pub(crate) fn new( - runtime: Arc>>>, - channel_manager: Arc, + runtime: Arc, channel_manager: Arc, connection_manager: Arc>>, liquidity_source: Option>>>, payment_store: Arc, peer_store: Arc>>, - config: Arc, logger: Arc, + config: Arc, is_running: Arc>, logger: Arc, ) -> Self { Self { runtime, @@ -84,6 +85,7 @@ impl Bolt11Payment { payment_store, peer_store, config, + is_running, logger, } } @@ -95,12 +97,12 @@ impl Bolt11Payment { pub fn send( &self, invoice: &Bolt11Invoice, sending_parameters: Option, ) -> Result { - let invoice = maybe_deref(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let invoice = maybe_deref(invoice); + let (payment_hash, recipient_onion, mut route_params) = bolt11_payment::payment_parameters_from_invoice(&invoice).map_err(|_| { log_error!(self.logger, "Failed to send payment due to the given invoice being \"zero-amount\". Please use send_using_amount instead."); Error::InvalidInvoice @@ -204,12 +206,12 @@ impl Bolt11Payment { &self, invoice: &Bolt11Invoice, amount_msat: u64, sending_parameters: Option, ) -> Result { - let invoice = maybe_deref(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let invoice = maybe_deref(invoice); + if let Some(invoice_amount_msat) = invoice.amount_milli_satoshis() { if amount_msat < invoice_amount_msat { log_error!( @@ -619,9 +621,6 @@ impl Bolt11Payment { let (node_id, address) = liquidity_source.get_lsps2_lsp_details().ok_or(Error::LiquiditySourceUnavailable)?; - let rt_lock = self.runtime.read().unwrap(); - let runtime = rt_lock.as_ref().unwrap(); - let peer_info = PeerInfo { node_id, address }; let con_node_id = peer_info.node_id; @@ -630,39 +629,35 @@ impl Bolt11Payment { // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. - tokio::task::block_in_place(move || { - runtime.block_on(async move { - con_cm.connect_peer_if_necessary(con_node_id, con_addr).await - }) + self.runtime.block_on(async move { + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await })?; log_info!(self.logger, "Connected to LSP {}@{}. ", peer_info.node_id, peer_info.address); let liquidity_source = Arc::clone(&liquidity_source); let (invoice, lsp_total_opening_fee, lsp_prop_opening_fee) = - tokio::task::block_in_place(move || { - runtime.block_on(async move { - if let Some(amount_msat) = amount_msat { - liquidity_source - .lsps2_receive_to_jit_channel( - amount_msat, - description, - expiry_secs, - max_total_lsp_fee_limit_msat, - ) - .await - .map(|(invoice, total_fee)| (invoice, Some(total_fee), None)) - } else { - liquidity_source - .lsps2_receive_variable_amount_to_jit_channel( - description, - expiry_secs, - max_proportional_lsp_fee_limit_ppm_msat, - ) - .await - .map(|(invoice, prop_fee)| (invoice, None, Some(prop_fee))) - } - }) + self.runtime.block_on(async move { + if let Some(amount_msat) = amount_msat { + liquidity_source + .lsps2_receive_to_jit_channel( + amount_msat, + description, + expiry_secs, + max_total_lsp_fee_limit_msat, + ) + .await + .map(|(invoice, total_fee)| (invoice, Some(total_fee), None)) + } else { + liquidity_source + .lsps2_receive_variable_amount_to_jit_channel( + description, + expiry_secs, + max_proportional_lsp_fee_limit_ppm_msat, + ) + .await + .map(|(invoice, prop_fee)| (invoice, None, Some(prop_fee))) + } })?; // Register payment in payment store. @@ -712,12 +707,12 @@ impl Bolt11Payment { /// amount times [`Config::probing_liquidity_limit_multiplier`] won't be used to send /// pre-flight probes. pub fn send_probes(&self, invoice: &Bolt11Invoice) -> Result<(), Error> { - let invoice = maybe_deref(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let invoice = maybe_deref(invoice); + let (_payment_hash, _recipient_onion, route_params) = bolt11_payment::payment_parameters_from_invoice(&invoice).map_err(|_| { log_error!(self.logger, "Failed to send probes due to the given invoice being \"zero-amount\". Please use send_probes_using_amount instead."); Error::InvalidInvoice @@ -745,12 +740,12 @@ impl Bolt11Payment { pub fn send_probes_using_amount( &self, invoice: &Bolt11Invoice, amount_msat: u64, ) -> Result<(), Error> { - let invoice = maybe_deref(invoice); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let invoice = maybe_deref(invoice); + let (_payment_hash, _recipient_onion, route_params) = if let Some(invoice_amount_msat) = invoice.amount_milli_satoshis() { diff --git a/src/payment/bolt12.rs b/src/payment/bolt12.rs index b9efa3241..8e10b9f4f 100644 --- a/src/payment/bolt12.rs +++ b/src/payment/bolt12.rs @@ -49,19 +49,18 @@ type Refund = Arc; /// [BOLT 12]: https://github.com/lightning/bolts/blob/master/12-offer-encoding.md /// [`Node::bolt12_payment`]: crate::Node::bolt12_payment pub struct Bolt12Payment { - runtime: Arc>>>, channel_manager: Arc, payment_store: Arc, + is_running: Arc>, logger: Arc, } impl Bolt12Payment { pub(crate) fn new( - runtime: Arc>>>, channel_manager: Arc, payment_store: Arc, - logger: Arc, + is_running: Arc>, logger: Arc, ) -> Self { - Self { runtime, channel_manager, payment_store, logger } + Self { channel_manager, payment_store, is_running, logger } } /// Send a payment given an offer. @@ -73,11 +72,12 @@ impl Bolt12Payment { pub fn send( &self, offer: &Offer, quantity: Option, payer_note: Option, ) -> Result { - let offer = maybe_deref(offer); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + + let offer = maybe_deref(offer); + let mut random_bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut random_bytes); let payment_id = PaymentId(random_bytes); @@ -175,12 +175,12 @@ impl Bolt12Payment { pub fn send_using_amount( &self, offer: &Offer, amount_msat: u64, quantity: Option, payer_note: Option, ) -> Result { - let offer = maybe_deref(offer); - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } + let offer = maybe_deref(offer); + let mut random_bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut random_bytes); let payment_id = PaymentId(random_bytes); @@ -346,6 +346,10 @@ impl Bolt12Payment { /// [`Refund`]: lightning::offers::refund::Refund /// [`Bolt12Invoice`]: lightning::offers::invoice::Bolt12Invoice pub fn request_refund_payment(&self, refund: &Refund) -> Result { + if !*self.is_running.read().unwrap() { + return Err(Error::NotRunning); + } + let refund = maybe_deref(refund); let invoice = self.channel_manager.request_refund_payment(&refund).map_err(|e| { log_error!(self.logger, "Failed to request refund payment: {:?}", e); diff --git a/src/payment/onchain.rs b/src/payment/onchain.rs index 046d66c69..2614e55ce 100644 --- a/src/payment/onchain.rs +++ b/src/payment/onchain.rs @@ -41,19 +41,19 @@ macro_rules! maybe_map_fee_rate_opt { /// /// [`Node::onchain_payment`]: crate::Node::onchain_payment pub struct OnchainPayment { - runtime: Arc>>>, wallet: Arc, channel_manager: Arc, config: Arc, + is_running: Arc>, logger: Arc, } impl OnchainPayment { pub(crate) fn new( - runtime: Arc>>>, wallet: Arc, - channel_manager: Arc, config: Arc, logger: Arc, + wallet: Arc, channel_manager: Arc, config: Arc, + is_running: Arc>, logger: Arc, ) -> Self { - Self { runtime, wallet, channel_manager, config, logger } + Self { wallet, channel_manager, config, is_running, logger } } /// Retrieve a new on-chain/funding address. @@ -75,8 +75,7 @@ impl OnchainPayment { pub fn send_to_address( &self, address: &bitcoin::Address, amount_sats: u64, fee_rate: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -106,8 +105,7 @@ impl OnchainPayment { pub fn send_all_to_address( &self, address: &bitcoin::Address, retain_reserves: bool, fee_rate: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } diff --git a/src/payment/spontaneous.rs b/src/payment/spontaneous.rs index a7e7876d7..3e48fd090 100644 --- a/src/payment/spontaneous.rs +++ b/src/payment/spontaneous.rs @@ -33,21 +33,21 @@ const LDK_DEFAULT_FINAL_CLTV_EXPIRY_DELTA: u32 = 144; /// /// [`Node::spontaneous_payment`]: crate::Node::spontaneous_payment pub struct SpontaneousPayment { - runtime: Arc>>>, channel_manager: Arc, keys_manager: Arc, payment_store: Arc, config: Arc, + is_running: Arc>, logger: Arc, } impl SpontaneousPayment { pub(crate) fn new( - runtime: Arc>>>, channel_manager: Arc, keys_manager: Arc, - payment_store: Arc, config: Arc, logger: Arc, + payment_store: Arc, config: Arc, is_running: Arc>, + logger: Arc, ) -> Self { - Self { runtime, channel_manager, keys_manager, payment_store, config, logger } + Self { channel_manager, keys_manager, payment_store, config, is_running, logger } } /// Send a spontaneous aka. "keysend", payment. @@ -88,8 +88,7 @@ impl SpontaneousPayment { &self, amount_msat: u64, node_id: PublicKey, sending_parameters: Option, custom_tlvs: Option>, preimage: Option, ) -> Result { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } @@ -198,8 +197,7 @@ impl SpontaneousPayment { /// /// [`Bolt11Payment::send_probes`]: crate::payment::Bolt11Payment pub fn send_probes(&self, amount_msat: u64, node_id: PublicKey) -> Result<(), Error> { - let rt_lock = self.runtime.read().unwrap(); - if rt_lock.is_none() { + if !*self.is_running.read().unwrap() { return Err(Error::NotRunning); } diff --git a/src/runtime.rs b/src/runtime.rs new file mode 100644 index 000000000..4c1241165 --- /dev/null +++ b/src/runtime.rs @@ -0,0 +1,72 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use tokio::task::JoinHandle; + +use std::future::Future; + +pub(crate) struct Runtime { + mode: RuntimeMode, +} + +impl Runtime { + pub fn new() -> Result { + let mode = match tokio::runtime::Handle::try_current() { + Ok(handle) => RuntimeMode::Handle(handle), + Err(_) => { + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + RuntimeMode::Owned(rt) + }, + }; + Ok(Self { mode }) + } + + pub fn with_handle(handle: tokio::runtime::Handle) -> Self { + let mode = RuntimeMode::Handle(handle); + Self { mode } + } + + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let handle = self.handle(); + handle.spawn(future) + } + + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let handle = self.handle(); + handle.spawn_blocking(func) + } + + pub fn block_on(&self, future: F) -> F::Output { + // While we generally decided not to overthink via which call graph users would enter our + // runtime context, we'd still try to reuse whatever current context would be present + // during `block_on`, as this is the context `block_in_place` would operate on. So we try + // to detect the outer context here, and otherwise use whatever was set during + // initialization. + let handle = tokio::runtime::Handle::try_current().unwrap_or(self.handle().clone()); + tokio::task::block_in_place(move || handle.block_on(future)) + } + + pub fn handle(&self) -> &tokio::runtime::Handle { + match &self.mode { + RuntimeMode::Owned(rt) => rt.handle(), + RuntimeMode::Handle(handle) => handle, + } + } +} + +enum RuntimeMode { + Owned(tokio::runtime::Runtime), + Handle(tokio::runtime::Handle), +}