
Commit d4e7727

Wait on all background tasks to finish (or abort)
Previously, we'd only wait for the background processor task to successfully finish. It turned out that this could lead to races when the other background tasks took too long to shut down. Here, we wait for a bit for all background tasks to shut down before moving on, aborting them if they don't finish in time.
1 parent 87a3ee6 commit d4e7727
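
Conceptually, the shutdown pattern this commit introduces boils down to the following standalone sketch. It is a hypothetical example, not the ldk-node code itself; it assumes a recent tokio with the `rt-multi-thread`, `macros`, and `time` features, and uses a local constant mirroring the `BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS` added in `src/config.rs`. Background work is collected in a `tokio::task::JoinSet`, and on shutdown each task is awaited for at most a bounded time before the remaining tasks are aborted.

```rust
use std::time::Duration;
use tokio::task::JoinSet;

// Hypothetical constant mirroring the one added in src/config.rs.
const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5;

#[tokio::main]
async fn main() {
    // Collect all background work in a JoinSet so it can be joined or aborted later.
    let mut background_tasks = JoinSet::new();
    let runtime_handle = tokio::runtime::Handle::current();

    // Spawn a few illustrative background tasks on the runtime handle.
    for i in 1..=3u64 {
        background_tasks.spawn_on(
            async move {
                tokio::time::sleep(Duration::from_millis(100 * i)).await;
            },
            &runtime_handle,
        );
    }

    // Shutdown: wait for each task to finish, but never longer than the timeout.
    loop {
        let timeout_fut = tokio::time::timeout(
            Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS),
            background_tasks.join_next_with_id(),
        );
        match timeout_fut.await {
            // One task finished cleanly; keep draining the set.
            Ok(Some(Ok((id, _)))) => println!("Stopped background task with id {}", id),
            // A task panicked or was cancelled: abort whatever is left and move on.
            Ok(Some(Err(e))) => {
                background_tasks.abort_all();
                println!("Stopping background task failed: {}", e);
                break;
            },
            // The set is empty: everything shut down in time.
            Ok(None) => {
                println!("Stopped all background tasks");
                break;
            },
            // Timed out waiting on a task: abort the rest instead of hanging.
            Err(e) => {
                background_tasks.abort_all();
                println!("Stopping background task timed out: {}", e);
                break;
            },
        }
    }
}
```

Collecting the handles in a `JoinSet` rather than detaching them with plain `spawn` calls gives shutdown a single place to join or abort everything, which is what lets it bound how long it waits.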

File tree

3 files changed: +96 −33 lines changed


src/builder.rs

Lines changed: 2 additions & 0 deletions
@@ -1667,13 +1667,15 @@ fn build_with_store_internal(
 
 	let (stop_sender, _) = tokio::sync::watch::channel(());
 	let background_processor_task = Mutex::new(None);
+	let background_tasks = Mutex::new(None);
 
 	let is_running = Arc::new(RwLock::new(false));
 
 	Ok(Node {
 		runtime,
 		stop_sender,
 		background_processor_task,
+		background_tasks,
 		config,
 		wallet,
 		chain_source,

src/config.rs

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,9 @@ pub(crate) const LDK_WALLET_SYNC_TIMEOUT_SECS: u64 = 10;
 // The timeout after which we give up waiting on LDK's event handler to exit on shutdown.
 pub(crate) const LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
 
+// The timeout after which we give up waiting on a background task to exit on shutdown.
+pub(crate) const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
+
 // The timeout after which we abort a fee rate cache update operation.
 pub(crate) const FEE_RATE_CACHE_UPDATE_TIMEOUT_SECS: u64 = 5;

src/lib.rs

Lines changed: 91 additions & 33 deletions
@@ -129,8 +129,8 @@ pub use builder::NodeBuilder as Builder;
 use chain::ChainSource;
 use config::{
 	default_user_config, may_announce_channel, ChannelConfig, Config,
-	LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL,
-	RGS_SYNC_INTERVAL,
+	BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS, LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS,
+	NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL,
 };
 use connection::ConnectionManager;
 use event::{EventHandler, EventQueue};

@@ -182,6 +182,7 @@ pub struct Node {
 	runtime: Arc<Runtime>,
 	stop_sender: tokio::sync::watch::Sender<()>,
 	background_processor_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
+	background_tasks: Mutex<Option<tokio::task::JoinSet<()>>>,
 	config: Arc<Config>,
 	wallet: Arc<Wallet>,
 	chain_source: Arc<ChainSource>,

@@ -224,6 +225,9 @@ impl Node {
 			return Err(Error::AlreadyRunning);
 		}
 
+		let mut background_tasks = tokio::task::JoinSet::new();
+		let runtime_handle = &self.runtime.handle();
+
 		log_info!(
 			self.logger,
 			"Starting up LDK Node with node ID {} on network: {}",

@@ -247,19 +251,27 @@ impl Node {
 		let sync_cman = Arc::clone(&self.channel_manager);
 		let sync_cmon = Arc::clone(&self.chain_monitor);
 		let sync_sweeper = Arc::clone(&self.output_sweeper);
-		self.runtime.spawn(async move {
-			chain_source
-				.continuously_sync_wallets(stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
-				.await;
-		});
+		background_tasks.spawn_on(
+			async move {
+				chain_source
+					.continuously_sync_wallets(
+						stop_sync_receiver,
+						sync_cman,
+						sync_cmon,
+						sync_sweeper,
+					)
+					.await;
+			},
+			runtime_handle,
+		);
 
 		if self.gossip_source.is_rgs() {
 			let gossip_source = Arc::clone(&self.gossip_source);
 			let gossip_sync_store = Arc::clone(&self.kv_store);
 			let gossip_sync_logger = Arc::clone(&self.logger);
 			let gossip_node_metrics = Arc::clone(&self.node_metrics);
 			let mut stop_gossip_sync = self.stop_sender.subscribe();
-			self.runtime.spawn(async move {
+			background_tasks.spawn_on(async move {
 				let mut interval = tokio::time::interval(RGS_SYNC_INTERVAL);
 				loop {
 					tokio::select! {

@@ -300,7 +312,7 @@ impl Node {
 						}
 					}
 				}
-			});
+			}, runtime_handle);
 		}
 
 		if let Some(listening_addresses) = &self.config.listening_addresses {

@@ -326,7 +338,7 @@ impl Node {
 				bind_addrs.extend(resolved_address);
 			}
 
-			self.runtime.spawn(async move {
+			background_tasks.spawn_on(async move {
 				{
 					let listener =
 						tokio::net::TcpListener::bind(&*bind_addrs).await

@@ -345,7 +357,7 @@ impl Node {
 						_ = stop_listen.changed() => {
 							log_debug!(
 								listening_logger,
-								"Stopping listening to inbound connections.",
+								"Stopping listening to inbound connections."
 							);
 							break;
 						}

@@ -364,7 +376,7 @@ impl Node {
 				}
 
 				listening_indicator.store(false, Ordering::Release);
-			});
+			}, runtime_handle);
 		}
 
 		// Regularly reconnect to persisted peers.

@@ -373,15 +385,15 @@ impl Node {
 		let connect_logger = Arc::clone(&self.logger);
 		let connect_peer_store = Arc::clone(&self.peer_store);
 		let mut stop_connect = self.stop_sender.subscribe();
-		self.runtime.spawn(async move {
+		background_tasks.spawn_on(async move {
 			let mut interval = tokio::time::interval(PEER_RECONNECTION_INTERVAL);
 			interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
 			loop {
 				tokio::select! {
 					_ = stop_connect.changed() => {
 						log_debug!(
 							connect_logger,
-							"Stopping reconnecting known peers.",
+							"Stopping reconnecting known peers."
 						);
 						return;
 					}

@@ -401,7 +413,7 @@ impl Node {
 					}
 				}
 			}
-		});
+		}, runtime_handle);
 
 		// Regularly broadcast node announcements.
 		let bcast_cm = Arc::clone(&self.channel_manager);

@@ -413,7 +425,7 @@ impl Node {
 		let mut stop_bcast = self.stop_sender.subscribe();
 		let node_alias = self.config.node_alias.clone();
 		if may_announce_channel(&self.config).is_ok() {
-			self.runtime.spawn(async move {
+			background_tasks.spawn_on(async move {
 				// We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
 				#[cfg(not(test))]
 				let mut interval = tokio::time::interval(Duration::from_secs(30));

@@ -484,14 +496,15 @@ impl Node {
 					}
 				}
 			}
-			});
+			}, runtime_handle);
 		}
 
 		let stop_tx_bcast = self.stop_sender.subscribe();
 		let chain_source = Arc::clone(&self.chain_source);
-		self.runtime.spawn(async move {
-			chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await
-		});
+		background_tasks.spawn_on(
+			async move { chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await },
+			runtime_handle,
+		);
 
 		let bump_tx_event_handler = Arc::new(BumpTransactionEventHandler::new(
 			Arc::clone(&self.tx_broadcaster),

@@ -576,22 +589,28 @@ impl Node {
 			let mut stop_liquidity_handler = self.stop_sender.subscribe();
 			let liquidity_handler = Arc::clone(&liquidity_source);
 			let liquidity_logger = Arc::clone(&self.logger);
-			self.runtime.spawn(async move {
-				loop {
-					tokio::select! {
-						_ = stop_liquidity_handler.changed() => {
-							log_debug!(
-								liquidity_logger,
-								"Stopping processing liquidity events.",
-							);
-							return;
+			background_tasks.spawn_on(
+				async move {
+					loop {
+						tokio::select! {
+							_ = stop_liquidity_handler.changed() => {
+								log_debug!(
+									liquidity_logger,
+									"Stopping processing liquidity events.",
+								);
+								return;
+							}
+							_ = liquidity_handler.handle_next_event() => {}
 						}
-						_ = liquidity_handler.handle_next_event() => {}
 					}
-				}
-			});
+				},
+				runtime_handle,
+			);
 		}
 
+		debug_assert!(self.background_tasks.lock().unwrap().is_none());
+		*self.background_tasks.lock().unwrap() = Some(background_tasks);
+
 		log_info!(self.logger, "Startup complete.");
 		*is_running_lock = true;
 		Ok(())

@@ -632,13 +651,52 @@ impl Node {
 		self.chain_source.stop();
 		log_debug!(self.logger, "Stopped chain sources.");
 
+		// Wait until all background tasks (mod LDK's background processor) are done.
+		let runtime_handle = self.runtime.handle();
+		if let Some(mut tasks) = self.background_tasks.lock().unwrap().take() {
+			tokio::task::block_in_place(move || {
+				runtime_handle.block_on(async {
+					loop {
+						let timeout_fut = tokio::time::timeout(
+							Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS),
+							tasks.join_next_with_id(),
+						);
+						match timeout_fut.await {
+							Ok(Some(Ok((id, _)))) => {
+								log_trace!(self.logger, "Stopped background task with id {}", id);
+							},
+							Ok(Some(Err(e))) => {
+								tasks.abort_all();
+								log_trace!(self.logger, "Stopping background task failed: {}", e);
+								break;
+							},
+							Ok(None) => {
+								log_debug!(self.logger, "Stopped all background tasks");
+								break;
+							},
+							Err(e) => {
+								tasks.abort_all();
+								log_error!(
+									self.logger,
+									"Stopping background task timed out: {}",
+									e
+								);
+								break;
+							},
+						}
+					}
+				})
+			});
+		}
+
 		// Wait until background processing stopped, at least until a timeout is reached.
 		if let Some(background_processor_task) =
 			self.background_processor_task.lock().unwrap().take()
 		{
+			let runtime_handle = self.runtime.handle();
 			let abort_handle = background_processor_task.abort_handle();
 			let timeout_res = tokio::task::block_in_place(move || {
-				self.runtime.block_on(async {
+				runtime_handle.block_on(async {
 					tokio::time::timeout(
 						Duration::from_secs(LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS),
 						background_processor_task,
