
Commit 7ec7f29

Wait on all background tasks to finish (or abort)
Previously, we'd only wait for the background processor task to finish successfully. It turned out that this could lead to races when other background tasks took too long to shut down. Here, we wait a bit for all background tasks to shut down before moving on, aborting any that outlast the timeout.
1 parent 6b9d899 commit 7ec7f29
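
The change hinges on collecting every background task in a `tokio::task::JoinSet` and draining that set on shutdown, giving each task a bounded window to exit before aborting whatever remains. A minimal, self-contained sketch of that drain-with-timeout pattern (the 5-second window mirrors the new `BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS` constant below; the task bodies and log lines are illustrative, not ldk-node code):

```rust
use std::time::Duration;
use tokio::task::JoinSet;

const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5;

#[tokio::main]
async fn main() {
    let mut tasks: JoinSet<()> = JoinSet::new();
    // Stand-ins for the node's background loops (wallet sync, gossip sync, ...).
    tasks.spawn(async { tokio::time::sleep(Duration::from_millis(10)).await });
    tasks.spawn(async { tokio::time::sleep(Duration::from_millis(20)).await });

    // Shutdown: wait up to the timeout for each task to finish; on failure or
    // timeout, abort everything still running and move on.
    loop {
        let timeout_fut = tokio::time::timeout(
            Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS),
            tasks.join_next_with_id(),
        );
        match timeout_fut.await {
            // A task exited cleanly within the window.
            Ok(Some(Ok((id, _)))) => println!("stopped background task {id}"),
            // A task panicked or was cancelled: give up on the rest.
            Ok(Some(Err(e))) => {
                tasks.abort_all();
                eprintln!("background task failed: {e}");
                break;
            },
            // The set is empty: all tasks are done.
            Ok(None) => break,
            // Timed out waiting on the next task: abort the stragglers.
            Err(_) => {
                tasks.abort_all();
                break;
            },
        }
    }
}
```

Using `join_next_with_id()` rather than `join_next()` lets the caller log which task exited; once the timeout fires, `abort_all()` cancels the stragglers rather than letting shutdown block indefinitely.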

File tree

3 files changed (+94, −32 lines)

src/builder.rs

Lines changed: 2 additions & 0 deletions
```diff
@@ -1632,11 +1632,13 @@ fn build_with_store_internal(
 
     let (stop_sender, _) = tokio::sync::watch::channel(());
     let background_processor_task = Mutex::new(None);
+    let background_tasks = Mutex::new(None);
 
     Ok(Node {
         runtime,
         stop_sender,
         background_processor_task,
+        background_tasks,
         config,
         wallet,
         chain_source,
```

src/config.rs

Lines changed: 3 additions & 0 deletions
```diff
@@ -79,6 +79,9 @@ pub(crate) const LDK_WALLET_SYNC_TIMEOUT_SECS: u64 = 10;
 // The timeout after which we give up waiting on LDK's event handler to exit on shutdown.
 pub(crate) const LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS: u64 = 30;
 
+// The timeout after which we give up waiting on a background task to exit on shutdown.
+pub(crate) const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5;
+
 // The timeout after which we abort a fee rate cache update operation.
 pub(crate) const FEE_RATE_CACHE_UPDATE_TIMEOUT_SECS: u64 = 5;
```

src/lib.rs

Lines changed: 89 additions & 32 deletions
```diff
@@ -127,8 +127,8 @@ pub use builder::NodeBuilder as Builder;
 use chain::ChainSource;
 use config::{
     default_user_config, may_announce_channel, ChannelConfig, Config,
-    LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL,
-    RGS_SYNC_INTERVAL,
+    BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS, LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS,
+    NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL,
 };
 use connection::ConnectionManager;
 use event::{EventHandler, EventQueue};
@@ -179,6 +179,7 @@ pub struct Node {
     runtime: Arc<RwLock<Option<Arc<tokio::runtime::Runtime>>>>,
     stop_sender: tokio::sync::watch::Sender<()>,
     background_processor_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
+    background_tasks: Mutex<Option<tokio::task::JoinSet<()>>>,
     config: Arc<Config>,
     wallet: Arc<Wallet>,
     chain_source: Arc<ChainSource>,
@@ -232,6 +233,9 @@ impl Node {
            return Err(Error::AlreadyRunning);
        }
 
+       let mut background_tasks = tokio::task::JoinSet::new();
+       let runtime_handle = runtime.handle();
+
        log_info!(
            self.logger,
            "Starting up LDK Node with node ID {} on network: {}",
@@ -258,19 +262,27 @@ impl Node {
        let sync_cman = Arc::clone(&self.channel_manager);
        let sync_cmon = Arc::clone(&self.chain_monitor);
        let sync_sweeper = Arc::clone(&self.output_sweeper);
-       runtime.spawn(async move {
-           chain_source
-               .continuously_sync_wallets(stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
-               .await;
-       });
+       background_tasks.spawn_on(
+           async move {
+               chain_source
+                   .continuously_sync_wallets(
+                       stop_sync_receiver,
+                       sync_cman,
+                       sync_cmon,
+                       sync_sweeper,
+                   )
+                   .await;
+           },
+           runtime_handle,
+       );
 
        if self.gossip_source.is_rgs() {
            let gossip_source = Arc::clone(&self.gossip_source);
            let gossip_sync_store = Arc::clone(&self.kv_store);
            let gossip_sync_logger = Arc::clone(&self.logger);
            let gossip_node_metrics = Arc::clone(&self.node_metrics);
            let mut stop_gossip_sync = self.stop_sender.subscribe();
-           runtime.spawn(async move {
+           background_tasks.spawn_on(async move {
                let mut interval = tokio::time::interval(RGS_SYNC_INTERVAL);
                loop {
                    tokio::select! {
@@ -311,7 +323,7 @@ impl Node {
                    }
                }
            }
-           });
+           }, runtime_handle);
        }
 
        if let Some(listening_addresses) = &self.config.listening_addresses {
```
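
Note the pattern in the hunks above: each `runtime.spawn(...)` becomes `background_tasks.spawn_on(..., runtime_handle)`. The future still runs on the node's runtime, but its join handle is retained in the `JoinSet`, so shutdown can later join on (or abort) everything started here. A small standalone sketch of the `spawn_on` API (the task body and variable names are illustrative, not ldk-node code):

```rust
use std::time::Duration;
use tokio::runtime::Runtime;
use tokio::task::JoinSet;

fn main() {
    // A dedicated runtime owned by the caller, as in Node::start.
    let runtime = Runtime::new().unwrap();
    let runtime_handle = runtime.handle();

    let mut background_tasks: JoinSet<()> = JoinSet::new();

    // spawn_on schedules the task on the given runtime handle while keeping
    // it tracked in the set, so all tasks can be drained in one place later.
    background_tasks.spawn_on(
        async {
            // Stand-in for e.g. a continuously_sync_wallets(..) loop.
            tokio::time::sleep(Duration::from_millis(10)).await;
        },
        runtime_handle,
    );

    // Drain the set to completion (no shutdown timeout here, for brevity).
    runtime.block_on(async {
        while let Some(res) = background_tasks.join_next().await {
            res.expect("background task panicked");
        }
    });
}
```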
```diff
@@ -337,7 +349,7 @@ impl Node {
                bind_addrs.extend(resolved_address);
            }
 
-           runtime.spawn(async move {
+           background_tasks.spawn_on(async move {
                {
                    let listener =
                        tokio::net::TcpListener::bind(&*bind_addrs).await
@@ -356,7 +368,7 @@ impl Node {
                        _ = stop_listen.changed() => {
                            log_debug!(
                                listening_logger,
-                               "Stopping listening to inbound connections.",
+                               "Stopping listening to inbound connections."
                            );
                            break;
                        }
@@ -375,7 +387,7 @@ impl Node {
                }
 
                listening_indicator.store(false, Ordering::Release);
-           });
+           }, runtime_handle);
        }
 
        // Regularly reconnect to persisted peers.
@@ -384,15 +396,15 @@ impl Node {
        let connect_logger = Arc::clone(&self.logger);
        let connect_peer_store = Arc::clone(&self.peer_store);
        let mut stop_connect = self.stop_sender.subscribe();
-       runtime.spawn(async move {
+       background_tasks.spawn_on(async move {
            let mut interval = tokio::time::interval(PEER_RECONNECTION_INTERVAL);
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                tokio::select! {
                    _ = stop_connect.changed() => {
                        log_debug!(
                            connect_logger,
-                           "Stopping reconnecting known peers.",
+                           "Stopping reconnecting known peers."
                        );
                        return;
                    }
@@ -412,7 +424,7 @@ impl Node {
                    }
                }
            }
-       });
+       }, runtime_handle);
 
        // Regularly broadcast node announcements.
        let bcast_cm = Arc::clone(&self.channel_manager);
@@ -424,7 +436,7 @@ impl Node {
        let mut stop_bcast = self.stop_sender.subscribe();
        let node_alias = self.config.node_alias.clone();
        if may_announce_channel(&self.config).is_ok() {
-           runtime.spawn(async move {
+           background_tasks.spawn_on(async move {
                // We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
                #[cfg(not(test))]
                let mut interval = tokio::time::interval(Duration::from_secs(30));
@@ -495,14 +507,15 @@ impl Node {
                    }
                }
            }
-           });
+           }, runtime_handle);
        }
 
        let stop_tx_bcast = self.stop_sender.subscribe();
        let chain_source = Arc::clone(&self.chain_source);
-       runtime.spawn(async move {
-           chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await
-       });
+       background_tasks.spawn_on(
+           async move { chain_source.continuously_process_broadcast_queue(stop_tx_bcast).await },
+           runtime_handle,
+       );
 
        let bump_tx_event_handler = Arc::new(BumpTransactionEventHandler::new(
            Arc::clone(&self.tx_broadcaster),
```
```diff
@@ -587,24 +600,30 @@ impl Node {
            let mut stop_liquidity_handler = self.stop_sender.subscribe();
            let liquidity_handler = Arc::clone(&liquidity_source);
            let liquidity_logger = Arc::clone(&self.logger);
-           runtime.spawn(async move {
-               loop {
-                   tokio::select! {
-                       _ = stop_liquidity_handler.changed() => {
-                           log_debug!(
-                               liquidity_logger,
-                               "Stopping processing liquidity events.",
-                           );
-                           return;
+           background_tasks.spawn_on(
+               async move {
+                   loop {
+                       tokio::select! {
+                           _ = stop_liquidity_handler.changed() => {
+                               log_debug!(
+                                   liquidity_logger,
+                                   "Stopping processing liquidity events.",
+                               );
+                               return;
+                           }
+                           _ = liquidity_handler.handle_next_event() => {}
                        }
-                       _ = liquidity_handler.handle_next_event() => {}
                    }
-               }
-           });
+               },
+               runtime_handle,
+           );
        }
 
        *runtime_lock = Some(runtime);
 
+       debug_assert!(self.background_tasks.lock().unwrap().is_none());
+       *self.background_tasks.lock().unwrap() = Some(background_tasks);
+
        log_info!(self.logger, "Startup complete.");
        Ok(())
    }
```
```diff
@@ -643,6 +662,44 @@ impl Node {
        self.chain_source.stop();
        log_debug!(self.logger, "Stopped chain sources.");
 
+       // Wait until all background tasks (mod LDK's background processor) are done.
+       let runtime_2 = Arc::clone(&runtime);
+       if let Some(mut tasks) = self.background_tasks.lock().unwrap().take() {
+           tokio::task::block_in_place(move || {
+               runtime_2.block_on(async {
+                   loop {
+                       let timeout_fut = tokio::time::timeout(
+                           Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS),
+                           tasks.join_next_with_id(),
+                       );
+                       match timeout_fut.await {
+                           Ok(Some(Ok((id, _)))) => {
+                               log_trace!(self.logger, "Stopped background task with id {}", id);
+                           },
+                           Ok(Some(Err(e))) => {
+                               tasks.abort_all();
+                               log_trace!(self.logger, "Stopping background task failed: {}", e);
+                               break;
+                           },
+                           Ok(None) => {
+                               log_debug!(self.logger, "Stopped all background tasks");
+                               break;
+                           },
+                           Err(e) => {
+                               tasks.abort_all();
+                               log_error!(
+                                   self.logger,
+                                   "Stopping background task timed out: {}",
+                                   e
+                               );
+                               break;
+                           },
+                       }
+                   }
+               })
+           });
+       }
+
        // Wait until background processing stopped, at least until a timeout is reached.
        if let Some(background_processor_task) =
            self.background_processor_task.lock().unwrap().take()
```
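
One subtlety in the new shutdown path above: draining the `JoinSet` means blocking on async work, and a bare `Runtime::block_on` panics when invoked from a thread that is already executing on the runtime. Wrapping the drain in `tokio::task::block_in_place` tells a multi-thread runtime to move the current worker's other queued tasks elsewhere before blocking, after which blocking on the runtime is permitted. A hedged sketch of that pairing, independent of ldk-node (the sleep stands in for the drain loop):

```rust
use std::time::Duration;
use tokio::runtime::Handle;
use tokio::task;

// Runs async shutdown work to completion from sync code. When called on a
// multi-thread runtime worker, block_in_place hands the worker's other tasks
// to its siblings, making the inner block_on safe.
fn shutdown_blocking() {
    task::block_in_place(move || {
        Handle::current().block_on(async move {
            // ... drain the JoinSet with per-task timeouts, as in the diff above ...
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
    });
}

#[tokio::main]
async fn main() {
    // Simulate Node::stop() being called from a task already on the runtime.
    tokio::spawn(async { shutdown_blocking() }).await.unwrap();
}
```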
