@@ -36,6 +36,7 @@ use bitcoin::{BlockHash, ScriptBuf, Transaction, Txid};
3636use crate :: sync:: Arc ;
3737use core:: future:: Future ;
3838use core:: ops:: Deref ;
39+ use core:: sync:: atomic:: { AtomicBool , Ordering } ;
3940use core:: task;
4041
4142use super :: async_poll:: dummy_waker;
@@ -350,7 +351,8 @@ where
350351 L :: Target : Logger ,
351352 O :: Target : OutputSpender ,
352353{
353- sweeper_state : Mutex < RuntimeSweeperState > ,
354+ sweeper_state : Mutex < SweeperState > ,
355+ pending_sweep : AtomicBool ,
354356 broadcaster : B ,
355357 fee_estimator : E ,
356358 chain_data_source : Option < F > ,
@@ -380,12 +382,10 @@ where
380382 output_spender : O , change_destination_source : D , kv_store : K , logger : L ,
381383 ) -> Self {
382384 let outputs = Vec :: new ( ) ;
383- let sweeper_state = Mutex :: new ( RuntimeSweeperState {
384- persistent : SweeperState { outputs, best_block } ,
385- sweep_pending : false ,
386- } ) ;
385+ let sweeper_state = Mutex :: new ( SweeperState { outputs, best_block } ) ;
387386 Self {
388387 sweeper_state,
388+ pending_sweep : AtomicBool :: new ( false ) ,
389389 broadcaster,
390390 fee_estimator,
391391 chain_data_source,
@@ -427,7 +427,7 @@ where
427427 return Ok ( ( ) ) ;
428428 }
429429
430- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
430+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
431431 for descriptor in relevant_descriptors {
432432 let output_info = TrackedSpendableOutput {
433433 descriptor,
@@ -444,20 +444,20 @@ where
444444
445445 state_lock. outputs . push ( output_info) ;
446446 }
447- self . persist_state ( & state_lock) . map_err ( |e| {
447+ self . persist_state ( & * state_lock) . map_err ( |e| {
448448 log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
449449 } )
450450 }
451451
452452 /// Returns a list of the currently tracked spendable outputs.
453453 pub fn tracked_spendable_outputs ( & self ) -> Vec < TrackedSpendableOutput > {
454- self . sweeper_state . lock ( ) . unwrap ( ) . persistent . outputs . clone ( )
454+ self . sweeper_state . lock ( ) . unwrap ( ) . outputs . clone ( )
455455 }
456456
457457 /// Gets the latest best block which was connected either via the [`Listen`] or
458458 /// [`Confirm`] interfaces.
459459 pub fn current_best_block ( & self ) -> BestBlock {
460- self . sweeper_state . lock ( ) . unwrap ( ) . persistent . best_block
460+ self . sweeper_state . lock ( ) . unwrap ( ) . best_block
461461 }
462462
463463 /// Regenerates and broadcasts the spending transaction for any outputs that are pending
@@ -481,24 +481,29 @@ where
481481 true
482482 } ;
483483
484+ // Prevent concurrent sweeps.
485+ if self . pending_sweep . load ( Ordering :: Relaxed ) {
486+ return Ok ( ( ) ) ;
487+ }
488+
484489 // See if there is anything to sweep before requesting a change address.
485490 {
486- let mut sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
491+ let sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
487492
488- // Prevent concurrent sweeping.
489- if sweeper_state. sweep_pending {
490- return Ok ( ( ) ) ;
491- }
492-
493- let cur_height = sweeper_state. persistent . best_block . height ;
494- let has_respends =
495- sweeper_state. persistent . outputs . iter ( ) . any ( |o| filter_fn ( o, cur_height) ) ;
493+ let cur_height = sweeper_state. best_block . height ;
494+ let has_respends = sweeper_state. outputs . iter ( ) . any ( |o| filter_fn ( o, cur_height) ) ;
496495 if !has_respends {
497496 return Ok ( ( ) ) ;
498497 }
498+ }
499499
500- // There is something to sweep. Block concurrent sweeps.
501- sweeper_state. sweep_pending = true ;
500+ // Mark sweep pending, if no other thread did so already.
501+ if self
502+ . pending_sweep
503+ . compare_exchange ( false , true , Ordering :: Acquire , Ordering :: Relaxed )
504+ . is_err ( )
505+ {
506+ return Ok ( ( ) ) ;
502507 }
503508
504509 // Request a new change address outside of the mutex to avoid the mutex crossing await.
@@ -509,10 +514,7 @@ where
509514 {
510515 let mut runtime_sweeper_state = self . sweeper_state . lock ( ) . unwrap ( ) ;
511516
512- // Always allow a new sweep after this spend, also in the error case.
513- runtime_sweeper_state. sweep_pending = false ;
514-
515- let sweeper_state = & mut runtime_sweeper_state. persistent ;
517+ let sweeper_state = & mut runtime_sweeper_state;
516518
517519 let change_destination_script = change_destination_script_result?;
518520
@@ -527,6 +529,8 @@ where
527529 . collect ( ) ;
528530
529531 if respend_descriptors. is_empty ( ) {
532+ self . pending_sweep . store ( false , Ordering :: Release ) ;
533+
530534 // It could be that a tx confirmed and there is now nothing to sweep anymore.
531535 return Ok ( ( ) ) ;
532536 }
@@ -545,6 +549,8 @@ where
545549 spending_tx
546550 } ,
547551 Err ( e) => {
552+ self . pending_sweep . store ( false , Ordering :: Release ) ;
553+
548554 log_error ! ( self . logger, "Error spending outputs: {:?}" , e) ;
549555 return Ok ( ( ) ) ;
550556 } ,
@@ -570,6 +576,8 @@ where
570576 self . broadcaster . broadcast_transactions ( & [ & spending_tx] ) ;
571577 }
572578
579+ self . pending_sweep . store ( false , Ordering :: Release ) ;
580+
573581 Ok ( ( ) )
574582 }
575583
@@ -668,22 +676,22 @@ where
668676 fn filtered_block_connected (
669677 & self , header : & Header , txdata : & chain:: transaction:: TransactionData , height : u32 ,
670678 ) {
671- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
679+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
672680 assert_eq ! ( state_lock. best_block. block_hash, header. prev_blockhash,
673681 "Blocks must be connected in chain-order - the connected header must build on the last connected header" ) ;
674682 assert_eq ! ( state_lock. best_block. height, height - 1 ,
675683 "Blocks must be connected in chain-order - the connected block height must be one greater than the previous height" ) ;
676684
677- self . transactions_confirmed_internal ( state_lock, header, txdata, height) ;
678- self . best_block_updated_internal ( state_lock, header, height) ;
685+ self . transactions_confirmed_internal ( & mut * state_lock, header, txdata, height) ;
686+ self . best_block_updated_internal ( & mut * state_lock, header, height) ;
679687
680- let _ = self . persist_state ( & state_lock) . map_err ( |e| {
688+ let _ = self . persist_state ( & * state_lock) . map_err ( |e| {
681689 log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
682690 } ) ;
683691 }
684692
685693 fn block_disconnected ( & self , header : & Header , height : u32 ) {
686- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
694+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
687695
688696 let new_height = height - 1 ;
689697 let block_hash = header. block_hash ( ) ;
@@ -721,15 +729,15 @@ where
721729 fn transactions_confirmed (
722730 & self , header : & Header , txdata : & chain:: transaction:: TransactionData , height : u32 ,
723731 ) {
724- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
725- self . transactions_confirmed_internal ( state_lock, header, txdata, height) ;
726- self . persist_state ( state_lock) . unwrap_or_else ( |e| {
732+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
733+ self . transactions_confirmed_internal ( & mut * state_lock, header, txdata, height) ;
734+ self . persist_state ( & * state_lock) . unwrap_or_else ( |e| {
727735 log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
728736 } ) ;
729737 }
730738
731739 fn transaction_unconfirmed ( & self , txid : & Txid ) {
732- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
740+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
733741
734742 // Get what height was unconfirmed.
735743 let unconf_height = state_lock
@@ -746,22 +754,22 @@ where
746754 . filter ( |o| o. status . confirmation_height ( ) >= Some ( unconf_height) )
747755 . for_each ( |o| o. status . unconfirmed ( ) ) ;
748756
749- self . persist_state ( state_lock) . unwrap_or_else ( |e| {
757+ self . persist_state ( & * state_lock) . unwrap_or_else ( |e| {
750758 log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
751759 } ) ;
752760 }
753761 }
754762
755763 fn best_block_updated ( & self , header : & Header , height : u32 ) {
756- let state_lock = & mut self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
757- self . best_block_updated_internal ( state_lock, header, height) ;
758- let _ = self . persist_state ( state_lock) . map_err ( |e| {
764+ let mut state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
765+ self . best_block_updated_internal ( & mut * state_lock, header, height) ;
766+ let _ = self . persist_state ( & * state_lock) . map_err ( |e| {
759767 log_error ! ( self . logger, "Error persisting OutputSweeper: {:?}" , e) ;
760768 } ) ;
761769 }
762770
763771 fn get_relevant_txids ( & self ) -> Vec < ( Txid , u32 , Option < BlockHash > ) > {
764- let state_lock = & self . sweeper_state . lock ( ) . unwrap ( ) . persistent ;
772+ let state_lock = self . sweeper_state . lock ( ) . unwrap ( ) ;
765773 state_lock
766774 . outputs
767775 . iter ( )
@@ -782,11 +790,6 @@ where
782790 }
783791}
784792
785- struct RuntimeSweeperState {
786- persistent : SweeperState ,
787- sweep_pending : bool ,
788- }
789-
790793#[ derive( Debug , Clone ) ]
791794struct SweeperState {
792795 outputs : Vec < TrackedSpendableOutput > ,
@@ -849,10 +852,10 @@ where
849852 }
850853 }
851854
852- let sweeper_state =
853- Mutex :: new ( RuntimeSweeperState { persistent : state, sweep_pending : false } ) ;
855+ let sweeper_state = Mutex :: new ( state) ;
854856 Ok ( Self {
855857 sweeper_state,
858+ pending_sweep : AtomicBool :: new ( false ) ,
856859 broadcaster,
857860 fee_estimator,
858861 chain_data_source,
@@ -898,12 +901,12 @@ where
898901 }
899902 }
900903
901- let sweeper_state =
902- Mutex :: new ( RuntimeSweeperState { persistent : state, sweep_pending : false } ) ;
904+ let sweeper_state = Mutex :: new ( state) ;
903905 Ok ( (
904906 best_block,
905907 OutputSweeper {
906908 sweeper_state,
909+ pending_sweep : AtomicBool :: new ( false ) ,
907910 broadcaster,
908911 fee_estimator,
909912 chain_data_source,
0 commit comments