diff --git a/src/builder.rs b/src/builder.rs index 85ec70d18..5ead3783b 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1632,11 +1632,15 @@ fn build_with_store_internal( let (stop_sender, _) = tokio::sync::watch::channel(()); let background_processor_task = Mutex::new(None); + let background_tasks = Mutex::new(None); + let cancellable_background_tasks = Mutex::new(None); Ok(Node { runtime, stop_sender, background_processor_task, + background_tasks, + cancellable_background_tasks, config, wallet, chain_source, diff --git a/src/chain/bitcoind.rs b/src/chain/bitcoind.rs index d7d325460..a120f8253 100644 --- a/src/chain/bitcoind.rs +++ b/src/chain/bitcoind.rs @@ -16,7 +16,7 @@ use crate::fee_estimator::{ }; use crate::io::utils::write_node_metrics; use crate::logger::{log_bytes, log_error, log_info, log_trace, LdkLogger, Logger}; -use crate::types::{Broadcaster, ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; +use crate::types::{ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::{Error, NodeMetrics}; use lightning::chain::chaininterface::ConfirmationTarget as LdkConfirmationTarget; @@ -54,7 +54,6 @@ pub(super) struct BitcoindChainSource { onchain_wallet: Arc, wallet_polling_status: Mutex, fee_estimator: Arc, - tx_broadcaster: Arc, kv_store: Arc, config: Arc, logger: Arc, @@ -65,8 +64,8 @@ impl BitcoindChainSource { pub(crate) fn new_rpc( rpc_host: String, rpc_port: u16, rpc_user: String, rpc_password: String, onchain_wallet: Arc, fee_estimator: Arc, - tx_broadcaster: Arc, kv_store: Arc, config: Arc, - logger: Arc, node_metrics: Arc>, + kv_store: Arc, config: Arc, logger: Arc, + node_metrics: Arc>, ) -> Self { let api_client = Arc::new(BitcoindClient::new_rpc( rpc_host.clone(), @@ -85,7 +84,6 @@ impl BitcoindChainSource { onchain_wallet, wallet_polling_status, fee_estimator, - tx_broadcaster, kv_store, config, logger: Arc::clone(&logger), @@ -96,9 +94,8 @@ impl BitcoindChainSource { pub(crate) fn new_rest( rpc_host: String, rpc_port: u16, rpc_user: String, rpc_password: String, onchain_wallet: Arc, fee_estimator: Arc, - tx_broadcaster: Arc, kv_store: Arc, config: Arc, - rest_client_config: BitcoindRestClientConfig, logger: Arc, - node_metrics: Arc>, + kv_store: Arc, config: Arc, rest_client_config: BitcoindRestClientConfig, + logger: Arc, node_metrics: Arc>, ) -> Self { let api_client = Arc::new(BitcoindClient::new_rest( rest_client_config.rest_host, @@ -120,7 +117,6 @@ impl BitcoindChainSource { wallet_polling_status, onchain_wallet, fee_estimator, - tx_broadcaster, kv_store, config, logger: Arc::clone(&logger), @@ -530,53 +526,45 @@ impl BitcoindChainSource { Ok(()) } - pub(crate) async fn process_broadcast_queue(&self) { + pub(crate) async fn process_broadcast_package(&self, package: Vec) { // While it's a bit unclear when we'd be able to lean on Bitcoin Core >v28 // features, we should eventually switch to use `submitpackage` via the // `rust-bitcoind-json-rpc` crate rather than just broadcasting individual // transactions. - let mut receiver = self.tx_broadcaster.get_broadcast_queue().await; - while let Some(next_package) = receiver.recv().await { - for tx in &next_package { - let txid = tx.compute_txid(); - let timeout_fut = tokio::time::timeout( - Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), - self.api_client.broadcast_transaction(tx), - ); - match timeout_fut.await { - Ok(res) => match res { - Ok(id) => { - debug_assert_eq!(id, txid); - log_trace!(self.logger, "Successfully broadcast transaction {}", txid); - }, - Err(e) => { - log_error!( - self.logger, - "Failed to broadcast transaction {}: {}", - txid, - e - ); - log_trace!( - self.logger, - "Failed broadcast transaction bytes: {}", - log_bytes!(tx.encode()) - ); - }, + for tx in &package { + let txid = tx.compute_txid(); + let timeout_fut = tokio::time::timeout( + Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), + self.api_client.broadcast_transaction(tx), + ); + match timeout_fut.await { + Ok(res) => match res { + Ok(id) => { + debug_assert_eq!(id, txid); + log_trace!(self.logger, "Successfully broadcast transaction {}", txid); }, Err(e) => { - log_error!( - self.logger, - "Failed to broadcast transaction due to timeout {}: {}", - txid, - e - ); + log_error!(self.logger, "Failed to broadcast transaction {}: {}", txid, e); log_trace!( self.logger, "Failed broadcast transaction bytes: {}", log_bytes!(tx.encode()) ); }, - } + }, + Err(e) => { + log_error!( + self.logger, + "Failed to broadcast transaction due to timeout {}: {}", + txid, + e + ); + log_trace!( + self.logger, + "Failed broadcast transaction bytes: {}", + log_bytes!(tx.encode()) + ); + }, } } } diff --git a/src/chain/electrum.rs b/src/chain/electrum.rs index 6193c67b3..abbb758dd 100644 --- a/src/chain/electrum.rs +++ b/src/chain/electrum.rs @@ -18,7 +18,7 @@ use crate::fee_estimator::{ }; use crate::io::utils::write_node_metrics; use crate::logger::{log_bytes, log_error, log_info, log_trace, LdkLogger, Logger}; -use crate::types::{Broadcaster, ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; +use crate::types::{ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::NodeMetrics; use lightning::chain::{Confirm, Filter, WatchedOutput}; @@ -56,7 +56,6 @@ pub(super) struct ElectrumChainSource { onchain_wallet_sync_status: Mutex, lightning_wallet_sync_status: Mutex, fee_estimator: Arc, - tx_broadcaster: Arc, kv_store: Arc, config: Arc, logger: Arc, @@ -66,9 +65,8 @@ pub(super) struct ElectrumChainSource { impl ElectrumChainSource { pub(super) fn new( server_url: String, sync_config: ElectrumSyncConfig, onchain_wallet: Arc, - fee_estimator: Arc, tx_broadcaster: Arc, - kv_store: Arc, config: Arc, logger: Arc, - node_metrics: Arc>, + fee_estimator: Arc, kv_store: Arc, config: Arc, + logger: Arc, node_metrics: Arc>, ) -> Self { let electrum_runtime_status = RwLock::new(ElectrumRuntimeStatus::new()); let onchain_wallet_sync_status = Mutex::new(WalletSyncStatus::Completed); @@ -81,7 +79,6 @@ impl ElectrumChainSource { onchain_wallet_sync_status, lightning_wallet_sync_status, fee_estimator, - tx_broadcaster, kv_store, config, logger: Arc::clone(&logger), @@ -302,7 +299,7 @@ impl ElectrumChainSource { Ok(()) } - pub(crate) async fn process_broadcast_queue(&self) { + pub(crate) async fn process_broadcast_package(&self, package: Vec) { let electrum_client: Arc = if let Some(client) = self.electrum_runtime_status.read().unwrap().client().as_ref() { Arc::clone(client) @@ -311,11 +308,8 @@ impl ElectrumChainSource { return; }; - let mut receiver = self.tx_broadcaster.get_broadcast_queue().await; - while let Some(next_package) = receiver.recv().await { - for tx in next_package { - electrum_client.broadcast(tx).await; - } + for tx in package { + electrum_client.broadcast(tx).await; } } } diff --git a/src/chain/esplora.rs b/src/chain/esplora.rs index 5932426b7..a8806a413 100644 --- a/src/chain/esplora.rs +++ b/src/chain/esplora.rs @@ -18,7 +18,7 @@ use crate::fee_estimator::{ }; use crate::io::utils::write_node_metrics; use crate::logger::{log_bytes, log_error, log_info, log_trace, LdkLogger, Logger}; -use crate::types::{Broadcaster, ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; +use crate::types::{ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::{Error, NodeMetrics}; use lightning::chain::{Confirm, Filter, WatchedOutput}; @@ -30,7 +30,7 @@ use bdk_esplora::EsploraAsyncExt; use esplora_client::AsyncClient as EsploraAsyncClient; -use bitcoin::{FeeRate, Network, Script, Txid}; +use bitcoin::{FeeRate, Network, Script, Transaction, Txid}; use std::collections::HashMap; use std::sync::{Arc, Mutex, RwLock}; @@ -44,7 +44,6 @@ pub(super) struct EsploraChainSource { tx_sync: Arc>>, lightning_wallet_sync_status: Mutex, fee_estimator: Arc, - tx_broadcaster: Arc, kv_store: Arc, config: Arc, logger: Arc, @@ -55,8 +54,8 @@ impl EsploraChainSource { pub(crate) fn new( server_url: String, headers: HashMap, sync_config: EsploraSyncConfig, onchain_wallet: Arc, fee_estimator: Arc, - tx_broadcaster: Arc, kv_store: Arc, config: Arc, - logger: Arc, node_metrics: Arc>, + kv_store: Arc, config: Arc, logger: Arc, + node_metrics: Arc>, ) -> Self { // FIXME / TODO: We introduced this to make `bdk_esplora` work separately without updating // `lightning-transaction-sync`. We should revert this as part of of the upgrade to LDK 0.2. @@ -90,7 +89,6 @@ impl EsploraChainSource { tx_sync, lightning_wallet_sync_status, fee_estimator, - tx_broadcaster, kv_store, config, logger, @@ -372,76 +370,73 @@ impl EsploraChainSource { Ok(()) } - pub(crate) async fn process_broadcast_queue(&self) { - let mut receiver = self.tx_broadcaster.get_broadcast_queue().await; - while let Some(next_package) = receiver.recv().await { - for tx in &next_package { - let txid = tx.compute_txid(); - let timeout_fut = tokio::time::timeout( - Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), - self.esplora_client.broadcast(tx), - ); - match timeout_fut.await { - Ok(res) => match res { - Ok(()) => { - log_trace!(self.logger, "Successfully broadcast transaction {}", txid); - }, - Err(e) => match e { - esplora_client::Error::HttpResponse { status, message } => { - if status == 400 { - // Log 400 at lesser level, as this often just means bitcoind already knows the - // transaction. - // FIXME: We can further differentiate here based on the error - // message which will be available with rust-esplora-client 0.7 and - // later. - log_trace!( - self.logger, - "Failed to broadcast due to HTTP connection error: {}", - message - ); - } else { - log_error!( - self.logger, - "Failed to broadcast due to HTTP connection error: {} - {}", - status, - message - ); - } + pub(crate) async fn process_broadcast_package(&self, package: Vec) { + for tx in &package { + let txid = tx.compute_txid(); + let timeout_fut = tokio::time::timeout( + Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), + self.esplora_client.broadcast(tx), + ); + match timeout_fut.await { + Ok(res) => match res { + Ok(()) => { + log_trace!(self.logger, "Successfully broadcast transaction {}", txid); + }, + Err(e) => match e { + esplora_client::Error::HttpResponse { status, message } => { + if status == 400 { + // Log 400 at lesser level, as this often just means bitcoind already knows the + // transaction. + // FIXME: We can further differentiate here based on the error + // message which will be available with rust-esplora-client 0.7 and + // later. log_trace!( self.logger, - "Failed broadcast transaction bytes: {}", - log_bytes!(tx.encode()) + "Failed to broadcast due to HTTP connection error: {}", + message ); - }, - _ => { + } else { log_error!( self.logger, - "Failed to broadcast transaction {}: {}", - txid, - e - ); - log_trace!( - self.logger, - "Failed broadcast transaction bytes: {}", - log_bytes!(tx.encode()) + "Failed to broadcast due to HTTP connection error: {} - {}", + status, + message ); - }, + } + log_trace!( + self.logger, + "Failed broadcast transaction bytes: {}", + log_bytes!(tx.encode()) + ); + }, + _ => { + log_error!( + self.logger, + "Failed to broadcast transaction {}: {}", + txid, + e + ); + log_trace!( + self.logger, + "Failed broadcast transaction bytes: {}", + log_bytes!(tx.encode()) + ); }, }, - Err(e) => { - log_error!( - self.logger, - "Failed to broadcast transaction due to timeout {}: {}", - txid, - e - ); - log_trace!( - self.logger, - "Failed broadcast transaction bytes: {}", - log_bytes!(tx.encode()) - ); - }, - } + }, + Err(e) => { + log_error!( + self.logger, + "Failed to broadcast transaction due to timeout {}: {}", + txid, + e + ); + log_trace!( + self.logger, + "Failed broadcast transaction bytes: {}", + log_bytes!(tx.encode()) + ); + }, } } } diff --git a/src/chain/mod.rs b/src/chain/mod.rs index 91cce1fe3..a4ab2c76b 100644 --- a/src/chain/mod.rs +++ b/src/chain/mod.rs @@ -18,7 +18,7 @@ use crate::config::{ }; use crate::fee_estimator::OnchainFeeEstimator; use crate::io::utils::write_node_metrics; -use crate::logger::{log_info, log_trace, LdkLogger, Logger}; +use crate::logger::{log_debug, log_info, log_trace, LdkLogger, Logger}; use crate::types::{Broadcaster, ChainMonitor, ChannelManager, DynStore, Sweeper, Wallet}; use crate::{Error, NodeMetrics}; @@ -87,6 +87,7 @@ impl WalletSyncStatus { pub(crate) struct ChainSource { kind: ChainSourceKind, + tx_broadcaster: Arc, logger: Arc, } @@ -109,14 +110,13 @@ impl ChainSource { sync_config, onchain_wallet, fee_estimator, - tx_broadcaster, kv_store, config, Arc::clone(&logger), node_metrics, ); let kind = ChainSourceKind::Esplora(esplora_chain_source); - Self { kind, logger } + Self { kind, tx_broadcaster, logger } } pub(crate) fn new_electrum( @@ -130,14 +130,13 @@ impl ChainSource { sync_config, onchain_wallet, fee_estimator, - tx_broadcaster, kv_store, config, Arc::clone(&logger), node_metrics, ); let kind = ChainSourceKind::Electrum(electrum_chain_source); - Self { kind, logger } + Self { kind, tx_broadcaster, logger } } pub(crate) fn new_bitcoind_rpc( @@ -153,14 +152,13 @@ impl ChainSource { rpc_password, onchain_wallet, fee_estimator, - tx_broadcaster, kv_store, config, Arc::clone(&logger), node_metrics, ); let kind = ChainSourceKind::Bitcoind(bitcoind_chain_source); - Self { kind, logger } + Self { kind, tx_broadcaster, logger } } pub(crate) fn new_bitcoind_rest( @@ -177,7 +175,6 @@ impl ChainSource { rpc_password, onchain_wallet, fee_estimator, - tx_broadcaster, kv_store, config, rest_client_config, @@ -185,7 +182,7 @@ impl ChainSource { node_metrics, ); let kind = ChainSourceKind::Bitcoind(bitcoind_chain_source); - Self { kind, logger } + Self { kind, tx_broadcaster, logger } } pub(crate) fn start(&self, runtime: Arc) -> Result<(), Error> { @@ -428,17 +425,34 @@ impl ChainSource { } } - pub(crate) async fn process_broadcast_queue(&self) { - match &self.kind { - ChainSourceKind::Esplora(esplora_chain_source) => { - esplora_chain_source.process_broadcast_queue().await - }, - ChainSourceKind::Electrum(electrum_chain_source) => { - electrum_chain_source.process_broadcast_queue().await - }, - ChainSourceKind::Bitcoind(bitcoind_chain_source) => { - bitcoind_chain_source.process_broadcast_queue().await - }, + pub(crate) async fn continuously_process_broadcast_queue( + &self, mut stop_tx_bcast_receiver: tokio::sync::watch::Receiver<()>, + ) { + let mut receiver = self.tx_broadcaster.get_broadcast_queue().await; + loop { + let tx_bcast_logger = Arc::clone(&self.logger); + tokio::select! { + _ = stop_tx_bcast_receiver.changed() => { + log_debug!( + tx_bcast_logger, + "Stopping broadcasting transactions.", + ); + return; + } + Some(next_package) = receiver.recv() => { + match &self.kind { + ChainSourceKind::Esplora(esplora_chain_source) => { + esplora_chain_source.process_broadcast_package(next_package).await + }, + ChainSourceKind::Electrum(electrum_chain_source) => { + electrum_chain_source.process_broadcast_package(next_package).await + }, + ChainSourceKind::Bitcoind(bitcoind_chain_source) => { + bitcoind_chain_source.process_broadcast_package(next_package).await + }, + } + } + } } } } diff --git a/src/config.rs b/src/config.rs index a5048e64f..02df8bbc7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -79,6 +79,9 @@ pub(crate) const LDK_WALLET_SYNC_TIMEOUT_SECS: u64 = 10; // The timeout after which we give up waiting on LDK's event handler to exit on shutdown. pub(crate) const LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS: u64 = 30; +// The timeout after which we give up waiting on a background task to exit on shutdown. +pub(crate) const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5; + // The timeout after which we abort a fee rate cache update operation. pub(crate) const FEE_RATE_CACHE_UPDATE_TIMEOUT_SECS: u64 = 5; diff --git a/src/lib.rs b/src/lib.rs index 89a17ab03..a3cce0752 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -127,8 +127,8 @@ pub use builder::NodeBuilder as Builder; use chain::ChainSource; use config::{ default_user_config, may_announce_channel, ChannelConfig, Config, - LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, - RGS_SYNC_INTERVAL, + BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS, LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS, + NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL, }; use connection::ConnectionManager; use event::{EventHandler, EventQueue}; @@ -179,6 +179,8 @@ pub struct Node { runtime: Arc>>>, stop_sender: tokio::sync::watch::Sender<()>, background_processor_task: Mutex>>, + background_tasks: Mutex>>, + cancellable_background_tasks: Mutex>>, config: Arc, wallet: Arc, chain_source: Arc, @@ -232,6 +234,10 @@ impl Node { return Err(Error::AlreadyRunning); } + let mut background_tasks = tokio::task::JoinSet::new(); + let mut cancellable_background_tasks = tokio::task::JoinSet::new(); + let runtime_handle = runtime.handle(); + log_info!( self.logger, "Starting up LDK Node with node ID {} on network: {}", @@ -258,11 +264,19 @@ impl Node { let sync_cman = Arc::clone(&self.channel_manager); let sync_cmon = Arc::clone(&self.chain_monitor); let sync_sweeper = Arc::clone(&self.output_sweeper); - runtime.spawn(async move { - chain_source - .continuously_sync_wallets(stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper) - .await; - }); + background_tasks.spawn_on( + async move { + chain_source + .continuously_sync_wallets( + stop_sync_receiver, + sync_cman, + sync_cmon, + sync_sweeper, + ) + .await; + }, + runtime_handle, + ); if self.gossip_source.is_rgs() { let gossip_source = Arc::clone(&self.gossip_source); @@ -270,7 +284,7 @@ impl Node { let gossip_sync_logger = Arc::clone(&self.logger); let gossip_node_metrics = Arc::clone(&self.node_metrics); let mut stop_gossip_sync = self.stop_sender.subscribe(); - runtime.spawn(async move { + cancellable_background_tasks.spawn_on(async move { let mut interval = tokio::time::interval(RGS_SYNC_INTERVAL); loop { tokio::select! { @@ -311,7 +325,7 @@ impl Node { } } } - }); + }, runtime_handle); } if let Some(listening_addresses) = &self.config.listening_addresses { @@ -337,7 +351,7 @@ impl Node { bind_addrs.extend(resolved_address); } - runtime.spawn(async move { + cancellable_background_tasks.spawn_on(async move { { let listener = tokio::net::TcpListener::bind(&*bind_addrs).await @@ -356,7 +370,7 @@ impl Node { _ = stop_listen.changed() => { log_debug!( listening_logger, - "Stopping listening to inbound connections.", + "Stopping listening to inbound connections." ); break; } @@ -375,7 +389,7 @@ impl Node { } listening_indicator.store(false, Ordering::Release); - }); + }, runtime_handle); } // Regularly reconnect to persisted peers. @@ -384,7 +398,7 @@ impl Node { let connect_logger = Arc::clone(&self.logger); let connect_peer_store = Arc::clone(&self.peer_store); let mut stop_connect = self.stop_sender.subscribe(); - runtime.spawn(async move { + cancellable_background_tasks.spawn_on(async move { let mut interval = tokio::time::interval(PEER_RECONNECTION_INTERVAL); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { @@ -392,7 +406,7 @@ impl Node { _ = stop_connect.changed() => { log_debug!( connect_logger, - "Stopping reconnecting known peers.", + "Stopping reconnecting known peers." ); return; } @@ -412,7 +426,7 @@ impl Node { } } } - }); + }, runtime_handle); // Regularly broadcast node announcements. let bcast_cm = Arc::clone(&self.channel_manager); @@ -424,7 +438,7 @@ impl Node { let mut stop_bcast = self.stop_sender.subscribe(); let node_alias = self.config.node_alias.clone(); if may_announce_channel(&self.config).is_ok() { - runtime.spawn(async move { + cancellable_background_tasks.spawn_on(async move { // We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away. #[cfg(not(test))] let mut interval = tokio::time::interval(Duration::from_secs(30)); @@ -495,31 +509,15 @@ impl Node { } } } - }); + }, runtime_handle); } - let mut stop_tx_bcast = self.stop_sender.subscribe(); + let stop_tx_bcast = self.stop_sender.subscribe(); let chain_source = Arc::clone(&self.chain_source); - let tx_bcast_logger = Arc::clone(&self.logger); - runtime.spawn(async move { - // Every second we try to clear our broadcasting queue. - let mut interval = tokio::time::interval(Duration::from_secs(1)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - loop { - tokio::select! { - _ = stop_tx_bcast.changed() => { - log_debug!( - tx_bcast_logger, - "Stopping broadcasting transactions.", - ); - return; - } - _ = interval.tick() => { - chain_source.process_broadcast_queue().await; - } - } - } - }); + cancellable_background_tasks.spawn_on( + async move { chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await }, + runtime_handle, + ); let bump_tx_event_handler = Arc::new(BumpTransactionEventHandler::new( Arc::clone(&self.tx_broadcaster), @@ -604,24 +602,33 @@ impl Node { let mut stop_liquidity_handler = self.stop_sender.subscribe(); let liquidity_handler = Arc::clone(&liquidity_source); let liquidity_logger = Arc::clone(&self.logger); - runtime.spawn(async move { - loop { - tokio::select! { - _ = stop_liquidity_handler.changed() => { - log_debug!( - liquidity_logger, - "Stopping processing liquidity events.", - ); - return; + background_tasks.spawn_on( + async move { + loop { + tokio::select! { + _ = stop_liquidity_handler.changed() => { + log_debug!( + liquidity_logger, + "Stopping processing liquidity events.", + ); + return; + } + _ = liquidity_handler.handle_next_event() => {} } - _ = liquidity_handler.handle_next_event() => {} } - } - }); + }, + runtime_handle, + ); } *runtime_lock = Some(runtime); + debug_assert!(self.background_tasks.lock().unwrap().is_none()); + *self.background_tasks.lock().unwrap() = Some(background_tasks); + + debug_assert!(self.cancellable_background_tasks.lock().unwrap().is_none()); + *self.cancellable_background_tasks.lock().unwrap() = Some(cancellable_background_tasks); + log_info!(self.logger, "Startup complete."); Ok(()) } @@ -652,6 +659,17 @@ impl Node { }, } + // Cancel cancellable background tasks + if let Some(mut tasks) = self.cancellable_background_tasks.lock().unwrap().take() { + let runtime_2 = Arc::clone(&runtime); + tasks.abort_all(); + tokio::task::block_in_place(move || { + runtime_2.block_on(async { while let Some(_) = tasks.join_next().await {} }) + }); + } else { + debug_assert!(false, "Expected some cancellable background tasks"); + }; + // Disconnect all peers. self.peer_manager.disconnect_all_peers(); log_debug!(self.logger, "Disconnected all network peers."); @@ -660,6 +678,46 @@ impl Node { self.chain_source.stop(); log_debug!(self.logger, "Stopped chain sources."); + // Wait until non-cancellable background tasks (mod LDK's background processor) are done. + let runtime_3 = Arc::clone(&runtime); + if let Some(mut tasks) = self.background_tasks.lock().unwrap().take() { + tokio::task::block_in_place(move || { + runtime_3.block_on(async { + loop { + let timeout_fut = tokio::time::timeout( + Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS), + tasks.join_next_with_id(), + ); + match timeout_fut.await { + Ok(Some(Ok((id, _)))) => { + log_trace!(self.logger, "Stopped background task with id {}", id); + }, + Ok(Some(Err(e))) => { + tasks.abort_all(); + log_trace!(self.logger, "Stopping background task failed: {}", e); + break; + }, + Ok(None) => { + log_debug!(self.logger, "Stopped all background tasks"); + break; + }, + Err(e) => { + tasks.abort_all(); + log_error!( + self.logger, + "Stopping background task timed out: {}", + e + ); + break; + }, + } + } + }) + }); + } else { + debug_assert!(false, "Expected some background tasks"); + }; + // Wait until background processing stopped, at least until a timeout is reached. if let Some(background_processor_task) = self.background_processor_task.lock().unwrap().take() @@ -693,7 +751,9 @@ impl Node { log_error!(self.logger, "Stopping event handling timed out: {}", e); }, } - } + } else { + debug_assert!(false, "Expected a background processing task"); + }; #[cfg(tokio_unstable)] { diff --git a/tests/integration_tests_rust.rs b/tests/integration_tests_rust.rs index 57742e09e..ad3867429 100644 --- a/tests/integration_tests_rust.rs +++ b/tests/integration_tests_rust.rs @@ -1457,3 +1457,14 @@ fn spontaneous_send_with_custom_preimage() { panic!("Expected receiver to have spontaneous PaymentKind with preimage"); } } + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn drop_in_async_context() { + let (_bitcoind, electrsd) = setup_bitcoind_and_electrsd(); + let chain_source = TestChainSource::Esplora(&electrsd); + let seed_bytes = vec![42u8; 64]; + + let config = random_config(true); + let node = setup_node(&chain_source, config, Some(seed_bytes)); + node.stop().unwrap(); +}