Skip to content

Commit 2bf7920

Browse files
committed
Wait on all background tasks to finish (or abort)
Previously, we'd only wait for the background processor tasks to successfully finish. It turned out that this could lead to races when the other background tasks took too long to shut down. Here, we attempt to wait for a bit for all background tasks to shut down before moving on.
1 parent 6b9d899 commit 2bf7920

File tree

3 files changed

+113
-33
lines changed

3 files changed

+113
-33
lines changed

src/builder.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,11 +1632,15 @@ fn build_with_store_internal(
16321632

16331633
let (stop_sender, _) = tokio::sync::watch::channel(());
16341634
let background_processor_task = Mutex::new(None);
1635+
let background_tasks = Mutex::new(None);
1636+
let cancellable_background_tasks = Mutex::new(None);
16351637

16361638
Ok(Node {
16371639
runtime,
16381640
stop_sender,
16391641
background_processor_task,
1642+
background_tasks,
1643+
cancellable_background_tasks,
16401644
config,
16411645
wallet,
16421646
chain_source,

src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ pub(crate) const LDK_WALLET_SYNC_TIMEOUT_SECS: u64 = 10;
7979
// The timeout after which we give up waiting on LDK's event handler to exit on shutdown.
8080
pub(crate) const LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
8181

82+
// The timeout after which we give up waiting on a background task to exit on shutdown.
83+
pub(crate) const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
84+
8285
// The timeout after which we abort a fee rate cache update operation.
8386
pub(crate) const FEE_RATE_CACHE_UPDATE_TIMEOUT_SECS: u64 = 5;
8487

src/lib.rs

Lines changed: 106 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ pub use builder::NodeBuilder as Builder;
127127
use chain::ChainSource;
128128
use config::{
129129
default_user_config, may_announce_channel, ChannelConfig, Config,
130-
LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL,
131-
RGS_SYNC_INTERVAL,
130+
BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS, LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS,
131+
NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL,
132132
};
133133
use connection::ConnectionManager;
134134
use event::{EventHandler, EventQueue};
@@ -179,6 +179,8 @@ pub struct Node {
179179
runtime: Arc<RwLock<Option<Arc<tokio::runtime::Runtime>>>>,
180180
stop_sender: tokio::sync::watch::Sender<()>,
181181
background_processor_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
182+
background_tasks: Mutex<Option<tokio::task::JoinSet<()>>>,
183+
cancellable_background_tasks: Mutex<Option<tokio::task::JoinSet<()>>>,
182184
config: Arc<Config>,
183185
wallet: Arc<Wallet>,
184186
chain_source: Arc<ChainSource>,
@@ -232,6 +234,10 @@ impl Node {
232234
return Err(Error::AlreadyRunning);
233235
}
234236

237+
let mut background_tasks = tokio::task::JoinSet::new();
238+
let mut cancellable_background_tasks = tokio::task::JoinSet::new();
239+
let runtime_handle = runtime.handle();
240+
235241
log_info!(
236242
self.logger,
237243
"Starting up LDK Node with node ID {} on network: {}",
@@ -258,19 +264,27 @@ impl Node {
258264
let sync_cman = Arc::clone(&self.channel_manager);
259265
let sync_cmon = Arc::clone(&self.chain_monitor);
260266
let sync_sweeper = Arc::clone(&self.output_sweeper);
261-
runtime.spawn(async move {
262-
chain_source
263-
.continuously_sync_wallets(stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
264-
.await;
265-
});
267+
background_tasks.spawn_on(
268+
async move {
269+
chain_source
270+
.continuously_sync_wallets(
271+
stop_sync_receiver,
272+
sync_cman,
273+
sync_cmon,
274+
sync_sweeper,
275+
)
276+
.await;
277+
},
278+
runtime_handle,
279+
);
266280

267281
if self.gossip_source.is_rgs() {
268282
let gossip_source = Arc::clone(&self.gossip_source);
269283
let gossip_sync_store = Arc::clone(&self.kv_store);
270284
let gossip_sync_logger = Arc::clone(&self.logger);
271285
let gossip_node_metrics = Arc::clone(&self.node_metrics);
272286
let mut stop_gossip_sync = self.stop_sender.subscribe();
273-
runtime.spawn(async move {
287+
background_tasks.spawn_on(async move {
274288
let mut interval = tokio::time::interval(RGS_SYNC_INTERVAL);
275289
loop {
276290
tokio::select! {
@@ -311,7 +325,7 @@ impl Node {
311325
}
312326
}
313327
}
314-
});
328+
}, runtime_handle);
315329
}
316330

317331
if let Some(listening_addresses) = &self.config.listening_addresses {
@@ -337,7 +351,7 @@ impl Node {
337351
bind_addrs.extend(resolved_address);
338352
}
339353

340-
runtime.spawn(async move {
354+
cancellable_background_tasks.spawn_on(async move {
341355
{
342356
let listener =
343357
tokio::net::TcpListener::bind(&*bind_addrs).await
@@ -356,7 +370,7 @@ impl Node {
356370
_ = stop_listen.changed() => {
357371
log_debug!(
358372
listening_logger,
359-
"Stopping listening to inbound connections.",
373+
"Stopping listening to inbound connections."
360374
);
361375
break;
362376
}
@@ -375,7 +389,7 @@ impl Node {
375389
}
376390

377391
listening_indicator.store(false, Ordering::Release);
378-
});
392+
}, runtime_handle);
379393
}
380394

381395
// Regularly reconnect to persisted peers.
@@ -384,15 +398,15 @@ impl Node {
384398
let connect_logger = Arc::clone(&self.logger);
385399
let connect_peer_store = Arc::clone(&self.peer_store);
386400
let mut stop_connect = self.stop_sender.subscribe();
387-
runtime.spawn(async move {
401+
cancellable_background_tasks.spawn_on(async move {
388402
let mut interval = tokio::time::interval(PEER_RECONNECTION_INTERVAL);
389403
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
390404
loop {
391405
tokio::select! {
392406
_ = stop_connect.changed() => {
393407
log_debug!(
394408
connect_logger,
395-
"Stopping reconnecting known peers.",
409+
"Stopping reconnecting known peers."
396410
);
397411
return;
398412
}
@@ -412,7 +426,7 @@ impl Node {
412426
}
413427
}
414428
}
415-
});
429+
}, runtime_handle);
416430

417431
// Regularly broadcast node announcements.
418432
let bcast_cm = Arc::clone(&self.channel_manager);
@@ -424,7 +438,7 @@ impl Node {
424438
let mut stop_bcast = self.stop_sender.subscribe();
425439
let node_alias = self.config.node_alias.clone();
426440
if may_announce_channel(&self.config).is_ok() {
427-
runtime.spawn(async move {
441+
background_tasks.spawn_on(async move {
428442
// We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
429443
#[cfg(not(test))]
430444
let mut interval = tokio::time::interval(Duration::from_secs(30));
@@ -495,14 +509,15 @@ impl Node {
495509
}
496510
}
497511
}
498-
});
512+
}, runtime_handle);
499513
}
500514

501515
let stop_tx_bcast = self.stop_sender.subscribe();
502516
let chain_source = Arc::clone(&self.chain_source);
503-
runtime.spawn(async move {
504-
chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await
505-
});
517+
cancellable_background_tasks.spawn_on(
518+
async move { chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await },
519+
runtime_handle,
520+
);
506521

507522
let bump_tx_event_handler = Arc::new(BumpTransactionEventHandler::new(
508523
Arc::clone(&self.tx_broadcaster),
@@ -587,24 +602,33 @@ impl Node {
587602
let mut stop_liquidity_handler = self.stop_sender.subscribe();
588603
let liquidity_handler = Arc::clone(&liquidity_source);
589604
let liquidity_logger = Arc::clone(&self.logger);
590-
runtime.spawn(async move {
591-
loop {
592-
tokio::select! {
593-
_ = stop_liquidity_handler.changed() => {
594-
log_debug!(
595-
liquidity_logger,
596-
"Stopping processing liquidity events.",
597-
);
598-
return;
605+
background_tasks.spawn_on(
606+
async move {
607+
loop {
608+
tokio::select! {
609+
_ = stop_liquidity_handler.changed() => {
610+
log_debug!(
611+
liquidity_logger,
612+
"Stopping processing liquidity events.",
613+
);
614+
return;
615+
}
616+
_ = liquidity_handler.handle_next_event() => {}
599617
}
600-
_ = liquidity_handler.handle_next_event() => {}
601618
}
602-
}
603-
});
619+
},
620+
runtime_handle,
621+
);
604622
}
605623

606624
*runtime_lock = Some(runtime);
607625

626+
debug_assert!(self.background_tasks.lock().unwrap().is_none());
627+
*self.background_tasks.lock().unwrap() = Some(background_tasks);
628+
629+
debug_assert!(self.cancellable_background_tasks.lock().unwrap().is_none());
630+
*self.cancellable_background_tasks.lock().unwrap() = Some(cancellable_background_tasks);
631+
608632
log_info!(self.logger, "Startup complete.");
609633
Ok(())
610634
}
@@ -643,6 +667,53 @@ impl Node {
643667
self.chain_source.stop();
644668
log_debug!(self.logger, "Stopped chain sources.");
645669

670+
// Cancel cancellable background tasks
671+
if let Some(mut tasks) = self.cancellable_background_tasks.lock().unwrap().take() {
672+
tasks.abort_all();
673+
} else {
674+
debug_assert!(false, "Expected some cancellable background tasks");
675+
};
676+
677+
// Wait until non-cancellable background tasks (mod LDK's background processor) are done.
678+
let runtime_2 = Arc::clone(&runtime);
679+
if let Some(mut tasks) = self.background_tasks.lock().unwrap().take() {
680+
tokio::task::block_in_place(move || {
681+
runtime_2.block_on(async {
682+
loop {
683+
let timeout_fut = tokio::time::timeout(
684+
Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS),
685+
tasks.join_next_with_id(),
686+
);
687+
match timeout_fut.await {
688+
Ok(Some(Ok((id, _)))) => {
689+
log_trace!(self.logger, "Stopped background task with id {}", id);
690+
},
691+
Ok(Some(Err(e))) => {
692+
tasks.abort_all();
693+
log_trace!(self.logger, "Stopping background task failed: {}", e);
694+
break;
695+
},
696+
Ok(None) => {
697+
log_debug!(self.logger, "Stopped all background tasks");
698+
break;
699+
},
700+
Err(e) => {
701+
tasks.abort_all();
702+
log_error!(
703+
self.logger,
704+
"Stopping background task timed out: {}",
705+
e
706+
);
707+
break;
708+
},
709+
}
710+
}
711+
})
712+
});
713+
} else {
714+
debug_assert!(false, "Expected some background tasks");
715+
};
716+
646717
// Wait until background processing stopped, at least until a timeout is reached.
647718
if let Some(background_processor_task) =
648719
self.background_processor_task.lock().unwrap().take()
@@ -676,7 +747,9 @@ impl Node {
676747
log_error!(self.logger, "Stopping event handling timed out: {}", e);
677748
},
678749
}
679-
}
750+
} else {
751+
debug_assert!(false, "Expected a background processing task");
752+
};
680753

681754
#[cfg(tokio_unstable)]
682755
{

0 commit comments

Comments
 (0)