@@ -89,6 +89,7 @@ use core::future::Future;
89
89
use core:: mem;
90
90
use core:: pin:: Pin ;
91
91
use core:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
92
+ use core:: task:: { Context , Poll , Waker } ;
92
93
use core:: time:: Duration ;
93
94
94
95
use bitcoin:: psbt:: Psbt ;
@@ -856,15 +857,93 @@ impl<Signer: sign::ecdsa::EcdsaChannelSigner> Persist<Signer> for TestPersister
856
857
}
857
858
}
858
859
860
+ // A simple multi-producer-single-consumer one-shot channel
861
+ type OneShotChannelState = Arc < Mutex < ( Option < Result < ( ) , io:: Error > > , Option < Waker > ) > > ;
862
+ struct OneShotChannel ( OneShotChannelState ) ;
863
+ impl Future for OneShotChannel {
864
+ type Output = Result < ( ) , io:: Error > ;
865
+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , io:: Error > > {
866
+ let mut state = self . 0 . lock ( ) . unwrap ( ) ;
867
+ // If the future is complete, take() the result and return it,
868
+ state. 0 . take ( ) . map ( |res| Poll :: Ready ( res) ) . unwrap_or_else ( || {
869
+ // otherwise, store the waker so that the future will be poll()ed again when the result
870
+ // is ready.
871
+ state. 1 = Some ( cx. waker ( ) . clone ( ) ) ;
872
+ Poll :: Pending
873
+ } )
874
+ }
875
+ }
876
+
877
+ /// An in-memory KVStore for testing.
878
+ ///
879
+ /// Sync writes always complete immediately while async writes always block until manually
880
+ /// completed with [`Self::complete_async_writes_through`] or [`Self::complete_all_async_writes`].
881
+ ///
882
+ /// Removes always complete immediately.
859
883
pub struct TestStore {
884
+ pending_async_writes : Mutex < HashMap < String , Vec < ( usize , OneShotChannelState , Vec < u8 > ) > > > ,
860
885
persisted_bytes : Mutex < HashMap < String , HashMap < String , Vec < u8 > > > > ,
861
886
read_only : bool ,
862
887
}
863
888
864
889
impl TestStore {
865
890
pub fn new ( read_only : bool ) -> Self {
891
+ let pending_async_writes = Mutex :: new ( new_hash_map ( ) ) ;
866
892
let persisted_bytes = Mutex :: new ( new_hash_map ( ) ) ;
867
- Self { persisted_bytes, read_only }
893
+ Self { pending_async_writes, persisted_bytes, read_only }
894
+ }
895
+
896
+ pub fn list_pending_async_writes (
897
+ & self , primary_namespace : & str , secondary_namespace : & str , key : & str ,
898
+ ) -> Vec < usize > {
899
+ let key = format ! ( "{primary_namespace}/{secondary_namespace}/{key}" ) ;
900
+ let writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
901
+ writes_lock
902
+ . get ( & key)
903
+ . map ( |v| v. iter ( ) . map ( |( id, _, _) | * id) . collect ( ) )
904
+ . unwrap_or ( Vec :: new ( ) )
905
+ }
906
+
907
+ /// Completes all pending async writes for the given namespace and key, up to and through the
908
+ /// given `write_id` (which can be fetched from [`Self::list_pending_async_writes`]).
909
+ pub fn complete_async_writes_through (
910
+ & self , primary_namespace : & str , secondary_namespace : & str , key : & str , write_id : usize ,
911
+ ) {
912
+ let prefix = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
913
+ let key = format ! ( "{primary_namespace}/{secondary_namespace}/{key}" ) ;
914
+
915
+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
916
+ let mut writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
917
+
918
+ let pending_writes = writes_lock. get_mut ( & key) . expect ( "No pending writes for given key" ) ;
919
+ pending_writes. retain ( |( id, res, data) | {
920
+ if * id <= write_id {
921
+ let namespace = persisted_lock. entry ( prefix. clone ( ) ) . or_insert ( new_hash_map ( ) ) ;
922
+ * namespace. entry ( key. to_string ( ) ) . or_default ( ) = data. clone ( ) ;
923
+ let mut future_state = res. lock ( ) . unwrap ( ) ;
924
+ future_state. 0 = Some ( Ok ( ( ) ) ) ;
925
+ if let Some ( waker) = future_state. 1 . take ( ) {
926
+ waker. wake ( ) ;
927
+ }
928
+ false
929
+ } else {
930
+ true
931
+ }
932
+ } ) ;
933
+ }
934
+
935
+ /// Completes all pending async writes on all namespaces and keys.
936
+ pub fn complete_all_async_writes ( & self ) {
937
+ let pending_writes: Vec < String > =
938
+ self . pending_async_writes . lock ( ) . unwrap ( ) . keys ( ) . cloned ( ) . collect ( ) ;
939
+ for key in pending_writes {
940
+ let mut levels = key. split ( "/" ) ;
941
+ let primary = levels. next ( ) . unwrap ( ) ;
942
+ let secondary = levels. next ( ) . unwrap ( ) ;
943
+ let key = levels. next ( ) . unwrap ( ) ;
944
+ assert ! ( levels. next( ) . is_none( ) ) ;
945
+ self . complete_async_writes_through ( primary, secondary, key, usize:: MAX ) ;
946
+ }
868
947
}
869
948
870
949
fn read_internal (
@@ -885,23 +964,6 @@ impl TestStore {
885
964
}
886
965
}
887
966
888
- fn write_internal (
889
- & self , primary_namespace : & str , secondary_namespace : & str , key : & str , buf : Vec < u8 > ,
890
- ) -> io:: Result < ( ) > {
891
- if self . read_only {
892
- return Err ( io:: Error :: new (
893
- io:: ErrorKind :: PermissionDenied ,
894
- "Cannot modify read-only store" ,
895
- ) ) ;
896
- }
897
- let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
898
-
899
- let prefixed = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
900
- let outer_e = persisted_lock. entry ( prefixed) . or_insert ( new_hash_map ( ) ) ;
901
- outer_e. insert ( key. to_string ( ) , buf) ;
902
- Ok ( ( ) )
903
- }
904
-
905
967
fn remove_internal (
906
968
& self , primary_namespace : & str , secondary_namespace : & str , key : & str , _lazy : bool ,
907
969
) -> io:: Result < ( ) > {
@@ -913,12 +975,23 @@ impl TestStore {
913
975
}
914
976
915
977
let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
978
+ let mut async_writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
916
979
917
980
let prefixed = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
918
981
if let Some ( outer_ref) = persisted_lock. get_mut ( & prefixed) {
919
982
outer_ref. remove ( & key. to_string ( ) ) ;
920
983
}
921
984
985
+ if let Some ( pending_writes) = async_writes_lock. remove ( & format ! ( "{prefixed}/{key}" ) ) {
986
+ for ( _, future, _) in pending_writes {
987
+ let mut future_lock = future. lock ( ) . unwrap ( ) ;
988
+ future_lock. 0 = Some ( Ok ( ( ) ) ) ;
989
+ if let Some ( waker) = future_lock. 1 . take ( ) {
990
+ waker. wake ( ) ;
991
+ }
992
+ }
993
+ }
994
+
922
995
Ok ( ( ) )
923
996
}
924
997
@@ -945,8 +1018,15 @@ impl KVStore for TestStore {
945
1018
fn write (
946
1019
& self , primary_namespace : & str , secondary_namespace : & str , key : & str , buf : Vec < u8 > ,
947
1020
) -> Pin < Box < dyn Future < Output = Result < ( ) , io:: Error > > + ' static + Send > > {
948
- let res = self . write_internal ( & primary_namespace, & secondary_namespace, & key, buf) ;
949
- Box :: pin ( async move { res } )
1021
+ let path = format ! ( "{primary_namespace}/{secondary_namespace}/{key}" ) ;
1022
+ let future = Arc :: new ( Mutex :: new ( ( None , None ) ) ) ;
1023
+
1024
+ let mut async_writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
1025
+ let pending_writes = async_writes_lock. entry ( path) . or_insert ( Vec :: new ( ) ) ;
1026
+ let new_id = pending_writes. last ( ) . map ( |( id, _, _) | id + 1 ) . unwrap_or ( 0 ) ;
1027
+ pending_writes. push ( ( new_id, Arc :: clone ( & future) , buf) ) ;
1028
+
1029
+ Box :: pin ( OneShotChannel ( future) )
950
1030
}
951
1031
fn remove (
952
1032
& self , primary_namespace : & str , secondary_namespace : & str , key : & str , lazy : bool ,
@@ -972,7 +1052,30 @@ impl KVStoreSync for TestStore {
972
1052
fn write (
973
1053
& self , primary_namespace : & str , secondary_namespace : & str , key : & str , buf : Vec < u8 > ,
974
1054
) -> io:: Result < ( ) > {
975
- self . write_internal ( primary_namespace, secondary_namespace, key, buf)
1055
+ if self . read_only {
1056
+ return Err ( io:: Error :: new (
1057
+ io:: ErrorKind :: PermissionDenied ,
1058
+ "Cannot modify read-only store" ,
1059
+ ) ) ;
1060
+ }
1061
+ let mut persisted_lock = self . persisted_bytes . lock ( ) . unwrap ( ) ;
1062
+ let mut async_writes_lock = self . pending_async_writes . lock ( ) . unwrap ( ) ;
1063
+
1064
+ let prefixed = format ! ( "{primary_namespace}/{secondary_namespace}" ) ;
1065
+ let async_writes_pending = async_writes_lock. remove ( & format ! ( "{prefixed}/{key}" ) ) ;
1066
+ let outer_e = persisted_lock. entry ( prefixed) . or_insert ( new_hash_map ( ) ) ;
1067
+ outer_e. insert ( key. to_string ( ) , buf) ;
1068
+
1069
+ if let Some ( pending_writes) = async_writes_pending {
1070
+ for ( _, future, _) in pending_writes {
1071
+ let mut future_lock = future. lock ( ) . unwrap ( ) ;
1072
+ future_lock. 0 = Some ( Ok ( ( ) ) ) ;
1073
+ if let Some ( waker) = future_lock. 1 . take ( ) {
1074
+ waker. wake ( ) ;
1075
+ }
1076
+ }
1077
+ }
1078
+ Ok ( ( ) )
976
1079
}
977
1080
978
1081
fn remove (
0 commit comments