@@ -128,8 +128,8 @@ pub use builder::NodeBuilder as Builder;
128128use chain:: ChainSource ;
129129use config:: {
130130 default_user_config, may_announce_channel, ChannelConfig , Config ,
131- LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS , NODE_ANN_BCAST_INTERVAL , PEER_RECONNECTION_INTERVAL ,
132- RGS_SYNC_INTERVAL ,
131+ BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS , LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS ,
132+ NODE_ANN_BCAST_INTERVAL , PEER_RECONNECTION_INTERVAL , RGS_SYNC_INTERVAL ,
133133} ;
134134use connection:: ConnectionManager ;
135135use event:: { EventHandler , EventQueue } ;
@@ -180,6 +180,8 @@ pub struct Node {
180180 runtime : Arc < RwLock < Option < Arc < tokio:: runtime:: Runtime > > > > ,
181181 stop_sender : tokio:: sync:: watch:: Sender < ( ) > ,
182182 background_processor_task : Mutex < Option < tokio:: task:: JoinHandle < ( ) > > > ,
183+ background_tasks : Mutex < Option < tokio:: task:: JoinSet < ( ) > > > ,
184+ cancellable_background_tasks : Mutex < Option < tokio:: task:: JoinSet < ( ) > > > ,
183185 config : Arc < Config > ,
184186 wallet : Arc < Wallet > ,
185187 chain_source : Arc < ChainSource > ,
@@ -233,6 +235,10 @@ impl Node {
233235 return Err ( Error :: AlreadyRunning ) ;
234236 }
235237
238+ let mut background_tasks = tokio:: task:: JoinSet :: new ( ) ;
239+ let mut cancellable_background_tasks = tokio:: task:: JoinSet :: new ( ) ;
240+ let runtime_handle = runtime. handle ( ) ;
241+
236242 log_info ! (
237243 self . logger,
238244 "Starting up LDK Node with node ID {} on network: {}" ,
@@ -259,19 +265,27 @@ impl Node {
259265 let sync_cman = Arc :: clone ( & self . channel_manager ) ;
260266 let sync_cmon = Arc :: clone ( & self . chain_monitor ) ;
261267 let sync_sweeper = Arc :: clone ( & self . output_sweeper ) ;
262- runtime. spawn ( async move {
263- chain_source
264- . continuously_sync_wallets ( stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
265- . await ;
266- } ) ;
268+ background_tasks. spawn_on (
269+ async move {
270+ chain_source
271+ . continuously_sync_wallets (
272+ stop_sync_receiver,
273+ sync_cman,
274+ sync_cmon,
275+ sync_sweeper,
276+ )
277+ . await ;
278+ } ,
279+ runtime_handle,
280+ ) ;
267281
268282 if self . gossip_source . is_rgs ( ) {
269283 let gossip_source = Arc :: clone ( & self . gossip_source ) ;
270284 let gossip_sync_store = Arc :: clone ( & self . kv_store ) ;
271285 let gossip_sync_logger = Arc :: clone ( & self . logger ) ;
272286 let gossip_node_metrics = Arc :: clone ( & self . node_metrics ) ;
273287 let mut stop_gossip_sync = self . stop_sender . subscribe ( ) ;
274- runtime . spawn ( async move {
288+ cancellable_background_tasks . spawn_on ( async move {
275289 let mut interval = tokio:: time:: interval ( RGS_SYNC_INTERVAL ) ;
276290 loop {
277291 tokio:: select! {
@@ -312,7 +326,7 @@ impl Node {
312326 }
313327 }
314328 }
315- } ) ;
329+ } , runtime_handle ) ;
316330 }
317331
318332 if let Some ( listening_addresses) = & self . config . listening_addresses {
@@ -338,7 +352,7 @@ impl Node {
338352 bind_addrs. extend ( resolved_address) ;
339353 }
340354
341- runtime . spawn ( async move {
355+ cancellable_background_tasks . spawn_on ( async move {
342356 {
343357 let listener =
344358 tokio:: net:: TcpListener :: bind ( & * bind_addrs) . await
@@ -357,7 +371,7 @@ impl Node {
357371 _ = stop_listen. changed( ) => {
358372 log_debug!(
359373 listening_logger,
360- "Stopping listening to inbound connections." ,
374+ "Stopping listening to inbound connections."
361375 ) ;
362376 break ;
363377 }
@@ -376,7 +390,7 @@ impl Node {
376390 }
377391
378392 listening_indicator. store ( false , Ordering :: Release ) ;
379- } ) ;
393+ } , runtime_handle ) ;
380394 }
381395
382396 // Regularly reconnect to persisted peers.
@@ -385,15 +399,15 @@ impl Node {
385399 let connect_logger = Arc :: clone ( & self . logger ) ;
386400 let connect_peer_store = Arc :: clone ( & self . peer_store ) ;
387401 let mut stop_connect = self . stop_sender . subscribe ( ) ;
388- runtime . spawn ( async move {
402+ cancellable_background_tasks . spawn_on ( async move {
389403 let mut interval = tokio:: time:: interval ( PEER_RECONNECTION_INTERVAL ) ;
390404 interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
391405 loop {
392406 tokio:: select! {
393407 _ = stop_connect. changed( ) => {
394408 log_debug!(
395409 connect_logger,
396- "Stopping reconnecting known peers." ,
410+ "Stopping reconnecting known peers."
397411 ) ;
398412 return ;
399413 }
@@ -413,7 +427,7 @@ impl Node {
413427 }
414428 }
415429 }
416- } ) ;
430+ } , runtime_handle ) ;
417431
418432 // Regularly broadcast node announcements.
419433 let bcast_cm = Arc :: clone ( & self . channel_manager ) ;
@@ -425,7 +439,7 @@ impl Node {
425439 let mut stop_bcast = self . stop_sender . subscribe ( ) ;
426440 let node_alias = self . config . node_alias . clone ( ) ;
427441 if may_announce_channel ( & self . config ) . is_ok ( ) {
428- runtime . spawn ( async move {
442+ cancellable_background_tasks . spawn_on ( async move {
429443 // We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
430444 #[ cfg( not( test) ) ]
431445 let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 30 ) ) ;
@@ -496,7 +510,7 @@ impl Node {
496510 }
497511 }
498512 }
499- } ) ;
513+ } , runtime_handle ) ;
500514 }
501515
502516 let mut stop_tx_bcast = self . stop_sender . subscribe ( ) ;
@@ -605,24 +619,33 @@ impl Node {
605619 let mut stop_liquidity_handler = self . stop_sender . subscribe ( ) ;
606620 let liquidity_handler = Arc :: clone ( & liquidity_source) ;
607621 let liquidity_logger = Arc :: clone ( & self . logger ) ;
608- runtime. spawn ( async move {
609- loop {
610- tokio:: select! {
611- _ = stop_liquidity_handler. changed( ) => {
612- log_debug!(
613- liquidity_logger,
614- "Stopping processing liquidity events." ,
615- ) ;
616- return ;
622+ background_tasks. spawn_on (
623+ async move {
624+ loop {
625+ tokio:: select! {
626+ _ = stop_liquidity_handler. changed( ) => {
627+ log_debug!(
628+ liquidity_logger,
629+ "Stopping processing liquidity events." ,
630+ ) ;
631+ return ;
632+ }
633+ _ = liquidity_handler. handle_next_event( ) => { }
617634 }
618- _ = liquidity_handler. handle_next_event( ) => { }
619635 }
620- }
621- } ) ;
636+ } ,
637+ runtime_handle,
638+ ) ;
622639 }
623640
624641 * runtime_lock = Some ( runtime) ;
625642
643+ debug_assert ! ( self . background_tasks. lock( ) . unwrap( ) . is_none( ) ) ;
644+ * self . background_tasks . lock ( ) . unwrap ( ) = Some ( background_tasks) ;
645+
646+ debug_assert ! ( self . cancellable_background_tasks. lock( ) . unwrap( ) . is_none( ) ) ;
647+ * self . cancellable_background_tasks . lock ( ) . unwrap ( ) = Some ( cancellable_background_tasks) ;
648+
626649 log_info ! ( self . logger, "Startup complete." ) ;
627650 Ok ( ( ) )
628651 }
@@ -653,6 +676,17 @@ impl Node {
653676 } ,
654677 }
655678
679+ // Cancel cancellable background tasks
680+ if let Some ( mut tasks) = self . cancellable_background_tasks . lock ( ) . unwrap ( ) . take ( ) {
681+ let runtime_2 = Arc :: clone ( & runtime) ;
682+ tasks. abort_all ( ) ;
683+ tokio:: task:: block_in_place ( move || {
684+ runtime_2. block_on ( async { while let Some ( _) = tasks. join_next ( ) . await { } } )
685+ } ) ;
686+ } else {
687+ debug_assert ! ( false , "Expected some cancellable background tasks" ) ;
688+ } ;
689+
656690 // Disconnect all peers.
657691 self . peer_manager . disconnect_all_peers ( ) ;
658692 log_debug ! ( self . logger, "Disconnected all network peers." ) ;
@@ -661,6 +695,46 @@ impl Node {
661695 self . chain_source . stop ( ) ;
662696 log_debug ! ( self . logger, "Stopped chain sources." ) ;
663697
698+ // Wait until non-cancellable background tasks (mod LDK's background processor) are done.
699+ let runtime_3 = Arc :: clone ( & runtime) ;
700+ if let Some ( mut tasks) = self . background_tasks . lock ( ) . unwrap ( ) . take ( ) {
701+ tokio:: task:: block_in_place ( move || {
702+ runtime_3. block_on ( async {
703+ loop {
704+ let timeout_fut = tokio:: time:: timeout (
705+ Duration :: from_secs ( BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS ) ,
706+ tasks. join_next_with_id ( ) ,
707+ ) ;
708+ match timeout_fut. await {
709+ Ok ( Some ( Ok ( ( id, _) ) ) ) => {
710+ log_trace ! ( self . logger, "Stopped background task with id {}" , id) ;
711+ } ,
712+ Ok ( Some ( Err ( e) ) ) => {
713+ tasks. abort_all ( ) ;
714+ log_trace ! ( self . logger, "Stopping background task failed: {}" , e) ;
715+ break ;
716+ } ,
717+ Ok ( None ) => {
718+ log_debug ! ( self . logger, "Stopped all background tasks" ) ;
719+ break ;
720+ } ,
721+ Err ( e) => {
722+ tasks. abort_all ( ) ;
723+ log_error ! (
724+ self . logger,
725+ "Stopping background task timed out: {}" ,
726+ e
727+ ) ;
728+ break ;
729+ } ,
730+ }
731+ }
732+ } )
733+ } ) ;
734+ } else {
735+ debug_assert ! ( false , "Expected some background tasks" ) ;
736+ } ;
737+
664738 // Wait until background processing stopped, at least until a timeout is reached.
665739 if let Some ( background_processor_task) =
666740 self . background_processor_task . lock ( ) . unwrap ( ) . take ( )
@@ -694,7 +768,9 @@ impl Node {
694768 log_error ! ( self . logger, "Stopping event handling timed out: {}" , e) ;
695769 } ,
696770 }
697- }
771+ } else {
772+ debug_assert ! ( false , "Expected a background processing task" ) ;
773+ } ;
698774
699775 #[ cfg( tokio_unstable) ]
700776 {
0 commit comments