@@ -128,8 +128,8 @@ pub use builder::NodeBuilder as Builder;
128128use chain:: ChainSource ;
129129use config:: {
130130 default_user_config, may_announce_channel, ChannelConfig , Config ,
131- LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS , NODE_ANN_BCAST_INTERVAL , PEER_RECONNECTION_INTERVAL ,
132- RGS_SYNC_INTERVAL ,
131+ BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS , LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS ,
132+ NODE_ANN_BCAST_INTERVAL , PEER_RECONNECTION_INTERVAL , RGS_SYNC_INTERVAL ,
133133} ;
134134use connection:: ConnectionManager ;
135135use event:: { EventHandler , EventQueue } ;
@@ -180,6 +180,8 @@ pub struct Node {
180180 runtime : Arc < RwLock < Option < Arc < tokio:: runtime:: Runtime > > > > ,
181181 stop_sender : tokio:: sync:: watch:: Sender < ( ) > ,
182182 background_processor_task : Mutex < Option < tokio:: task:: JoinHandle < ( ) > > > ,
183+ background_tasks : Mutex < Option < tokio:: task:: JoinSet < ( ) > > > ,
184+ cancellable_background_tasks : Mutex < Option < tokio:: task:: JoinSet < ( ) > > > ,
183185 config : Arc < Config > ,
184186 wallet : Arc < Wallet > ,
185187 chain_source : Arc < ChainSource > ,
@@ -233,6 +235,10 @@ impl Node {
233235 return Err ( Error :: AlreadyRunning ) ;
234236 }
235237
238+ let mut background_tasks = tokio:: task:: JoinSet :: new ( ) ;
239+ let mut cancellable_background_tasks = tokio:: task:: JoinSet :: new ( ) ;
240+ let runtime_handle = runtime. handle ( ) ;
241+
236242 log_info ! (
237243 self . logger,
238244 "Starting up LDK Node with node ID {} on network: {}" ,
@@ -259,19 +265,27 @@ impl Node {
259265 let sync_cman = Arc :: clone ( & self . channel_manager ) ;
260266 let sync_cmon = Arc :: clone ( & self . chain_monitor ) ;
261267 let sync_sweeper = Arc :: clone ( & self . output_sweeper ) ;
262- runtime. spawn ( async move {
263- chain_source
264- . continuously_sync_wallets ( stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
265- . await ;
266- } ) ;
268+ background_tasks. spawn_on (
269+ async move {
270+ chain_source
271+ . continuously_sync_wallets (
272+ stop_sync_receiver,
273+ sync_cman,
274+ sync_cmon,
275+ sync_sweeper,
276+ )
277+ . await ;
278+ } ,
279+ runtime_handle,
280+ ) ;
267281
268282 if self . gossip_source . is_rgs ( ) {
269283 let gossip_source = Arc :: clone ( & self . gossip_source ) ;
270284 let gossip_sync_store = Arc :: clone ( & self . kv_store ) ;
271285 let gossip_sync_logger = Arc :: clone ( & self . logger ) ;
272286 let gossip_node_metrics = Arc :: clone ( & self . node_metrics ) ;
273287 let mut stop_gossip_sync = self . stop_sender . subscribe ( ) ;
274- runtime . spawn ( async move {
288+ background_tasks . spawn_on ( async move {
275289 let mut interval = tokio:: time:: interval ( RGS_SYNC_INTERVAL ) ;
276290 loop {
277291 tokio:: select! {
@@ -312,7 +326,7 @@ impl Node {
312326 }
313327 }
314328 }
315- } ) ;
329+ } , runtime_handle ) ;
316330 }
317331
318332 if let Some ( listening_addresses) = & self . config . listening_addresses {
@@ -338,7 +352,7 @@ impl Node {
338352 bind_addrs. extend ( resolved_address) ;
339353 }
340354
341- runtime . spawn ( async move {
355+ cancellable_background_tasks . spawn_on ( async move {
342356 {
343357 let listener =
344358 tokio:: net:: TcpListener :: bind ( & * bind_addrs) . await
@@ -357,7 +371,7 @@ impl Node {
357371 _ = stop_listen. changed( ) => {
358372 log_debug!(
359373 listening_logger,
360- "Stopping listening to inbound connections." ,
374+ "Stopping listening to inbound connections."
361375 ) ;
362376 break ;
363377 }
@@ -376,7 +390,7 @@ impl Node {
376390 }
377391
378392 listening_indicator. store ( false , Ordering :: Release ) ;
379- } ) ;
393+ } , runtime_handle ) ;
380394 }
381395
382396 // Regularly reconnect to persisted peers.
@@ -385,15 +399,15 @@ impl Node {
385399 let connect_logger = Arc :: clone ( & self . logger ) ;
386400 let connect_peer_store = Arc :: clone ( & self . peer_store ) ;
387401 let mut stop_connect = self . stop_sender . subscribe ( ) ;
388- runtime . spawn ( async move {
402+ cancellable_background_tasks . spawn_on ( async move {
389403 let mut interval = tokio:: time:: interval ( PEER_RECONNECTION_INTERVAL ) ;
390404 interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
391405 loop {
392406 tokio:: select! {
393407 _ = stop_connect. changed( ) => {
394408 log_debug!(
395409 connect_logger,
396- "Stopping reconnecting known peers." ,
410+ "Stopping reconnecting known peers."
397411 ) ;
398412 return ;
399413 }
@@ -413,7 +427,7 @@ impl Node {
413427 }
414428 }
415429 }
416- } ) ;
430+ } , runtime_handle ) ;
417431
418432 // Regularly broadcast node announcements.
419433 let bcast_cm = Arc :: clone ( & self . channel_manager ) ;
@@ -425,7 +439,7 @@ impl Node {
425439 let mut stop_bcast = self . stop_sender . subscribe ( ) ;
426440 let node_alias = self . config . node_alias . clone ( ) ;
427441 if may_announce_channel ( & self . config ) . is_ok ( ) {
428- runtime . spawn ( async move {
442+ background_tasks . spawn_on ( async move {
429443 // We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
430444 #[ cfg( not( test) ) ]
431445 let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 30 ) ) ;
@@ -496,7 +510,7 @@ impl Node {
496510 }
497511 }
498512 }
499- } ) ;
513+ } , runtime_handle ) ;
500514 }
501515
502516 let mut stop_tx_bcast = self . stop_sender . subscribe ( ) ;
@@ -605,24 +619,33 @@ impl Node {
605619 let mut stop_liquidity_handler = self . stop_sender . subscribe ( ) ;
606620 let liquidity_handler = Arc :: clone ( & liquidity_source) ;
607621 let liquidity_logger = Arc :: clone ( & self . logger ) ;
608- runtime. spawn ( async move {
609- loop {
610- tokio:: select! {
611- _ = stop_liquidity_handler. changed( ) => {
612- log_debug!(
613- liquidity_logger,
614- "Stopping processing liquidity events." ,
615- ) ;
616- return ;
622+ background_tasks. spawn_on (
623+ async move {
624+ loop {
625+ tokio:: select! {
626+ _ = stop_liquidity_handler. changed( ) => {
627+ log_debug!(
628+ liquidity_logger,
629+ "Stopping processing liquidity events." ,
630+ ) ;
631+ return ;
632+ }
633+ _ = liquidity_handler. handle_next_event( ) => { }
617634 }
618- _ = liquidity_handler. handle_next_event( ) => { }
619635 }
620- }
621- } ) ;
636+ } ,
637+ runtime_handle,
638+ ) ;
622639 }
623640
624641 * runtime_lock = Some ( runtime) ;
625642
643+ debug_assert ! ( self . background_tasks. lock( ) . unwrap( ) . is_none( ) ) ;
644+ * self . background_tasks . lock ( ) . unwrap ( ) = Some ( background_tasks) ;
645+
646+ debug_assert ! ( self . cancellable_background_tasks. lock( ) . unwrap( ) . is_none( ) ) ;
647+ * self . cancellable_background_tasks . lock ( ) . unwrap ( ) = Some ( cancellable_background_tasks) ;
648+
626649 log_info ! ( self . logger, "Startup complete." ) ;
627650 Ok ( ( ) )
628651 }
@@ -661,6 +684,53 @@ impl Node {
661684 self . chain_source . stop ( ) ;
662685 log_debug ! ( self . logger, "Stopped chain sources." ) ;
663686
687+ // Cancel cancellable background tasks
688+ if let Some ( mut tasks) = self . cancellable_background_tasks . lock ( ) . unwrap ( ) . take ( ) {
689+ tasks. abort_all ( ) ;
690+ } else {
691+ debug_assert ! ( false , "Expected some cancellable background tasks" ) ;
692+ } ;
693+
694+ // Wait until non-cancellable background tasks (mod LDK's background processor) are done.
695+ let runtime_2 = Arc :: clone ( & runtime) ;
696+ if let Some ( mut tasks) = self . background_tasks . lock ( ) . unwrap ( ) . take ( ) {
697+ tokio:: task:: block_in_place ( move || {
698+ runtime_2. block_on ( async {
699+ loop {
700+ let timeout_fut = tokio:: time:: timeout (
701+ Duration :: from_secs ( BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS ) ,
702+ tasks. join_next_with_id ( ) ,
703+ ) ;
704+ match timeout_fut. await {
705+ Ok ( Some ( Ok ( ( id, _) ) ) ) => {
706+ log_trace ! ( self . logger, "Stopped background task with id {}" , id) ;
707+ } ,
708+ Ok ( Some ( Err ( e) ) ) => {
709+ tasks. abort_all ( ) ;
710+ log_trace ! ( self . logger, "Stopping background task failed: {}" , e) ;
711+ break ;
712+ } ,
713+ Ok ( None ) => {
714+ log_debug ! ( self . logger, "Stopped all background tasks" ) ;
715+ break ;
716+ } ,
717+ Err ( e) => {
718+ tasks. abort_all ( ) ;
719+ log_error ! (
720+ self . logger,
721+ "Stopping background task timed out: {}" ,
722+ e
723+ ) ;
724+ break ;
725+ } ,
726+ }
727+ }
728+ } )
729+ } ) ;
730+ } else {
731+ debug_assert ! ( false , "Expected some background tasks" ) ;
732+ } ;
733+
664734 // Wait until background processing stopped, at least until a timeout is reached.
665735 if let Some ( background_processor_task) =
666736 self . background_processor_task . lock ( ) . unwrap ( ) . take ( )
@@ -694,7 +764,9 @@ impl Node {
694764 log_error ! ( self . logger, "Stopping event handling timed out: {}" , e) ;
695765 } ,
696766 }
697- }
767+ } else {
768+ debug_assert ! ( false , "Expected a background processing task" ) ;
769+ } ;
698770
699771 #[ cfg( tokio_unstable) ]
700772 {
0 commit comments