Skip to content

Commit a0d7ad4

Browse files
committed
Wait on all background tasks to finish (or abort)
Previously, we'd only wait for the background processor tasks to successfully finish. It turned out that this could lead to races when the other background tasks took too long to shut down. Here, we attempt to wait for a bit for all background tasks to shut down before moving on.
1 parent 8ce139a commit a0d7ad4

File tree

3 files changed

+113
-30
lines changed

3 files changed

+113
-30
lines changed

src/builder.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,11 +1496,15 @@ fn build_with_store_internal(
14961496

14971497
let (stop_sender, _) = tokio::sync::watch::channel(());
14981498
let background_processor_task = Mutex::new(None);
1499+
let background_tasks = Mutex::new(None);
1500+
let cancellable_background_tasks = Mutex::new(None);
14991501

15001502
Ok(Node {
15011503
runtime,
15021504
stop_sender,
15031505
background_processor_task,
1506+
background_tasks,
1507+
cancellable_background_tasks,
15041508
config,
15051509
wallet,
15061510
chain_source,

src/config.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ pub(crate) const LDK_WALLET_SYNC_TIMEOUT_SECS: u64 = 10;
7373
// The timeout after which we give up waiting on LDK's event handler to exit on shutdown.
7474
pub(crate) const LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
7575

76+
// The timeout after which we give up waiting on a background task to exit on shutdown.
77+
pub(crate) const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
78+
7679
// The timeout after which we abort a fee rate cache update operation.
7780
pub(crate) const FEE_RATE_CACHE_UPDATE_TIMEOUT_SECS: u64 = 5;
7881

src/lib.rs

Lines changed: 106 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ pub use builder::NodeBuilder as Builder;
128128
use chain::ChainSource;
129129
use config::{
130130
default_user_config, may_announce_channel, ChannelConfig, Config,
131-
LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL,
132-
RGS_SYNC_INTERVAL,
131+
BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS, LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS,
132+
NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL,
133133
};
134134
use connection::ConnectionManager;
135135
use event::{EventHandler, EventQueue};
@@ -180,6 +180,8 @@ pub struct Node {
180180
runtime: Arc<RwLock<Option<Arc<tokio::runtime::Runtime>>>>,
181181
stop_sender: tokio::sync::watch::Sender<()>,
182182
background_processor_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
183+
background_tasks: Mutex<Option<tokio::task::JoinSet<()>>>,
184+
cancellable_background_tasks: Mutex<Option<tokio::task::JoinSet<()>>>,
183185
config: Arc<Config>,
184186
wallet: Arc<Wallet>,
185187
chain_source: Arc<ChainSource>,
@@ -233,6 +235,10 @@ impl Node {
233235
return Err(Error::AlreadyRunning);
234236
}
235237

238+
let mut background_tasks = tokio::task::JoinSet::new();
239+
let mut cancellable_background_tasks = tokio::task::JoinSet::new();
240+
let runtime_handle = runtime.handle();
241+
236242
log_info!(
237243
self.logger,
238244
"Starting up LDK Node with node ID {} on network: {}",
@@ -259,19 +265,27 @@ impl Node {
259265
let sync_cman = Arc::clone(&self.channel_manager);
260266
let sync_cmon = Arc::clone(&self.chain_monitor);
261267
let sync_sweeper = Arc::clone(&self.output_sweeper);
262-
runtime.spawn(async move {
263-
chain_source
264-
.continuously_sync_wallets(stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
265-
.await;
266-
});
268+
background_tasks.spawn_on(
269+
async move {
270+
chain_source
271+
.continuously_sync_wallets(
272+
stop_sync_receiver,
273+
sync_cman,
274+
sync_cmon,
275+
sync_sweeper,
276+
)
277+
.await;
278+
},
279+
runtime_handle,
280+
);
267281

268282
if self.gossip_source.is_rgs() {
269283
let gossip_source = Arc::clone(&self.gossip_source);
270284
let gossip_sync_store = Arc::clone(&self.kv_store);
271285
let gossip_sync_logger = Arc::clone(&self.logger);
272286
let gossip_node_metrics = Arc::clone(&self.node_metrics);
273287
let mut stop_gossip_sync = self.stop_sender.subscribe();
274-
runtime.spawn(async move {
288+
cancellable_background_tasks.spawn_on(async move {
275289
let mut interval = tokio::time::interval(RGS_SYNC_INTERVAL);
276290
loop {
277291
tokio::select! {
@@ -312,7 +326,7 @@ impl Node {
312326
}
313327
}
314328
}
315-
});
329+
}, runtime_handle);
316330
}
317331

318332
if let Some(listening_addresses) = &self.config.listening_addresses {
@@ -338,7 +352,7 @@ impl Node {
338352
bind_addrs.extend(resolved_address);
339353
}
340354

341-
runtime.spawn(async move {
355+
cancellable_background_tasks.spawn_on(async move {
342356
{
343357
let listener =
344358
tokio::net::TcpListener::bind(&*bind_addrs).await
@@ -357,7 +371,7 @@ impl Node {
357371
_ = stop_listen.changed() => {
358372
log_debug!(
359373
listening_logger,
360-
"Stopping listening to inbound connections.",
374+
"Stopping listening to inbound connections."
361375
);
362376
break;
363377
}
@@ -376,7 +390,7 @@ impl Node {
376390
}
377391

378392
listening_indicator.store(false, Ordering::Release);
379-
});
393+
}, runtime_handle);
380394
}
381395

382396
// Regularly reconnect to persisted peers.
@@ -385,15 +399,15 @@ impl Node {
385399
let connect_logger = Arc::clone(&self.logger);
386400
let connect_peer_store = Arc::clone(&self.peer_store);
387401
let mut stop_connect = self.stop_sender.subscribe();
388-
runtime.spawn(async move {
402+
cancellable_background_tasks.spawn_on(async move {
389403
let mut interval = tokio::time::interval(PEER_RECONNECTION_INTERVAL);
390404
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
391405
loop {
392406
tokio::select! {
393407
_ = stop_connect.changed() => {
394408
log_debug!(
395409
connect_logger,
396-
"Stopping reconnecting known peers.",
410+
"Stopping reconnecting known peers."
397411
);
398412
return;
399413
}
@@ -413,7 +427,7 @@ impl Node {
413427
}
414428
}
415429
}
416-
});
430+
}, runtime_handle);
417431

418432
// Regularly broadcast node announcements.
419433
let bcast_cm = Arc::clone(&self.channel_manager);
@@ -425,7 +439,7 @@ impl Node {
425439
let mut stop_bcast = self.stop_sender.subscribe();
426440
let node_alias = self.config.node_alias.clone();
427441
if may_announce_channel(&self.config).is_ok() {
428-
runtime.spawn(async move {
442+
cancellable_background_tasks.spawn_on(async move {
429443
// We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
430444
#[cfg(not(test))]
431445
let mut interval = tokio::time::interval(Duration::from_secs(30));
@@ -496,7 +510,7 @@ impl Node {
496510
}
497511
}
498512
}
499-
});
513+
}, runtime_handle);
500514
}
501515

502516
let mut stop_tx_bcast = self.stop_sender.subscribe();
@@ -605,24 +619,33 @@ impl Node {
605619
let mut stop_liquidity_handler = self.stop_sender.subscribe();
606620
let liquidity_handler = Arc::clone(&liquidity_source);
607621
let liquidity_logger = Arc::clone(&self.logger);
608-
runtime.spawn(async move {
609-
loop {
610-
tokio::select! {
611-
_ = stop_liquidity_handler.changed() => {
612-
log_debug!(
613-
liquidity_logger,
614-
"Stopping processing liquidity events.",
615-
);
616-
return;
622+
background_tasks.spawn_on(
623+
async move {
624+
loop {
625+
tokio::select! {
626+
_ = stop_liquidity_handler.changed() => {
627+
log_debug!(
628+
liquidity_logger,
629+
"Stopping processing liquidity events.",
630+
);
631+
return;
632+
}
633+
_ = liquidity_handler.handle_next_event() => {}
617634
}
618-
_ = liquidity_handler.handle_next_event() => {}
619635
}
620-
}
621-
});
636+
},
637+
runtime_handle,
638+
);
622639
}
623640

624641
*runtime_lock = Some(runtime);
625642

643+
debug_assert!(self.background_tasks.lock().unwrap().is_none());
644+
*self.background_tasks.lock().unwrap() = Some(background_tasks);
645+
646+
debug_assert!(self.cancellable_background_tasks.lock().unwrap().is_none());
647+
*self.cancellable_background_tasks.lock().unwrap() = Some(cancellable_background_tasks);
648+
626649
log_info!(self.logger, "Startup complete.");
627650
Ok(())
628651
}
@@ -653,6 +676,17 @@ impl Node {
653676
},
654677
}
655678

679+
// Cancel cancellable background tasks
680+
if let Some(mut tasks) = self.cancellable_background_tasks.lock().unwrap().take() {
681+
let runtime_2 = Arc::clone(&runtime);
682+
tasks.abort_all();
683+
tokio::task::block_in_place(move || {
684+
runtime_2.block_on(async { while let Some(_) = tasks.join_next().await {} })
685+
});
686+
} else {
687+
debug_assert!(false, "Expected some cancellable background tasks");
688+
};
689+
656690
// Disconnect all peers.
657691
self.peer_manager.disconnect_all_peers();
658692
log_debug!(self.logger, "Disconnected all network peers.");
@@ -661,6 +695,46 @@ impl Node {
661695
self.chain_source.stop();
662696
log_debug!(self.logger, "Stopped chain sources.");
663697

698+
// Wait until non-cancellable background tasks (mod LDK's background processor) are done.
699+
let runtime_3 = Arc::clone(&runtime);
700+
if let Some(mut tasks) = self.background_tasks.lock().unwrap().take() {
701+
tokio::task::block_in_place(move || {
702+
runtime_3.block_on(async {
703+
loop {
704+
let timeout_fut = tokio::time::timeout(
705+
Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS),
706+
tasks.join_next_with_id(),
707+
);
708+
match timeout_fut.await {
709+
Ok(Some(Ok((id, _)))) => {
710+
log_trace!(self.logger, "Stopped background task with id {}", id);
711+
},
712+
Ok(Some(Err(e))) => {
713+
tasks.abort_all();
714+
log_trace!(self.logger, "Stopping background task failed: {}", e);
715+
break;
716+
},
717+
Ok(None) => {
718+
log_debug!(self.logger, "Stopped all background tasks");
719+
break;
720+
},
721+
Err(e) => {
722+
tasks.abort_all();
723+
log_error!(
724+
self.logger,
725+
"Stopping background task timed out: {}",
726+
e
727+
);
728+
break;
729+
},
730+
}
731+
}
732+
})
733+
});
734+
} else {
735+
debug_assert!(false, "Expected some background tasks");
736+
};
737+
664738
// Wait until background processing stopped, at least until a timeout is reached.
665739
if let Some(background_processor_task) =
666740
self.background_processor_task.lock().unwrap().take()
@@ -694,7 +768,9 @@ impl Node {
694768
log_error!(self.logger, "Stopping event handling timed out: {}", e);
695769
},
696770
}
697-
}
771+
} else {
772+
debug_assert!(false, "Expected a background processing task");
773+
};
698774

699775
#[cfg(tokio_unstable)]
700776
{

0 commit comments

Comments
 (0)