@@ -32,6 +32,7 @@ use crate::util::config::UserConfig;
32
32
use crate :: util:: enforcing_trait_impls:: { EnforcingSigner , EnforcementState } ;
33
33
use crate :: util:: logger:: { Logger , Level , Record } ;
34
34
use crate :: util:: ser:: { Readable , ReadableArgs , Writer , Writeable } ;
35
+ use crate :: util:: persist:: KVStore ;
35
36
36
37
use bitcoin:: EcdsaSighashType ;
37
38
use bitcoin:: blockdata:: constants:: ChainHash ;
@@ -56,7 +57,7 @@ use crate::prelude::*;
56
57
use core:: cell:: RefCell ;
57
58
use core:: ops:: DerefMut ;
58
59
use core:: time:: Duration ;
59
- use crate :: sync:: { Mutex , Arc } ;
60
+ use crate :: sync:: { Mutex , Arc , RwLock } ;
60
61
use core:: sync:: atomic:: { AtomicBool , AtomicUsize , Ordering } ;
61
62
use core:: mem;
62
63
use bitcoin:: bech32:: u5;
@@ -316,6 +317,98 @@ impl<Signer: sign::WriteableEcdsaChannelSigner> chainmonitor::Persist<Signer> fo
316
317
}
317
318
}
318
319
320
+ pub ( crate ) struct TestStore {
321
+ persisted_bytes : RwLock < HashMap < String , HashMap < String , Arc < RwLock < Vec < u8 > > > > > > ,
322
+ did_persist : Arc < AtomicBool > ,
323
+ }
324
+
325
+ impl TestStore {
326
+ pub fn new ( ) -> Self {
327
+ let persisted_bytes = RwLock :: new ( HashMap :: new ( ) ) ;
328
+ let did_persist = Arc :: new ( AtomicBool :: new ( false ) ) ;
329
+ Self { persisted_bytes, did_persist }
330
+ }
331
+
332
+ pub fn get_persisted_bytes ( & self , namespace : & str , key : & str ) -> Option < Vec < u8 > > {
333
+ if let Some ( outer_ref) = self . persisted_bytes . read ( ) . unwrap ( ) . get ( namespace) {
334
+ if let Some ( inner_ref) = outer_ref. get ( key) {
335
+ let locked = inner_ref. read ( ) . unwrap ( ) ;
336
+ return Some ( ( * locked) . clone ( ) ) ;
337
+ }
338
+ }
339
+ None
340
+ }
341
+
342
+ pub fn get_and_clear_did_persist ( & self ) -> bool {
343
+ self . did_persist . swap ( false , Ordering :: Relaxed )
344
+ }
345
+ }
346
+
347
+ impl KVStore for TestStore {
348
+ type Reader = TestReader ;
349
+
350
+ fn read ( & self , namespace : & str , key : & str ) -> std:: io:: Result < Self :: Reader > {
351
+ if let Some ( outer_ref) = self . persisted_bytes . read ( ) . unwrap ( ) . get ( namespace) {
352
+ if let Some ( inner_ref) = outer_ref. get ( key) {
353
+ Ok ( TestReader :: new ( Arc :: clone ( inner_ref) ) )
354
+ } else {
355
+ let msg = format ! ( "Key not found: {}" , key) ;
356
+ Err ( std:: io:: Error :: new ( std:: io:: ErrorKind :: NotFound , msg) )
357
+ }
358
+ } else {
359
+ let msg = format ! ( "Namespace not found: {}" , namespace) ;
360
+ Err ( std:: io:: Error :: new ( std:: io:: ErrorKind :: NotFound , msg) )
361
+ }
362
+ }
363
+
364
+ fn write ( & self , namespace : & str , key : & str , buf : & [ u8 ] ) -> std:: io:: Result < ( ) > {
365
+ let mut guard = self . persisted_bytes . write ( ) . unwrap ( ) ;
366
+ let outer_e = guard. entry ( namespace. to_string ( ) ) . or_insert ( HashMap :: new ( ) ) ;
367
+ let inner_e = outer_e. entry ( key. to_string ( ) ) . or_insert ( Arc :: new ( RwLock :: new ( Vec :: new ( ) ) ) ) ;
368
+
369
+ let mut guard = inner_e. write ( ) . unwrap ( ) ;
370
+ guard. write_all ( buf) ?;
371
+ self . did_persist . store ( true , Ordering :: SeqCst ) ;
372
+ Ok ( ( ) )
373
+ }
374
+
375
+ fn remove ( & self , namespace : & str , key : & str ) -> std:: io:: Result < ( ) > {
376
+ match self . persisted_bytes . write ( ) . unwrap ( ) . entry ( namespace. to_string ( ) ) {
377
+ hash_map:: Entry :: Occupied ( mut e) => {
378
+ self . did_persist . store ( true , Ordering :: SeqCst ) ;
379
+ e. get_mut ( ) . remove ( & key. to_string ( ) ) ;
380
+ Ok ( ( ) )
381
+ }
382
+ hash_map:: Entry :: Vacant ( _) => Ok ( ( ) ) ,
383
+ }
384
+ }
385
+
386
+ fn list ( & self , namespace : & str ) -> std:: io:: Result < Vec < String > > {
387
+ match self . persisted_bytes . write ( ) . unwrap ( ) . entry ( namespace. to_string ( ) ) {
388
+ hash_map:: Entry :: Occupied ( e) => Ok ( e. get ( ) . keys ( ) . cloned ( ) . collect ( ) ) ,
389
+ hash_map:: Entry :: Vacant ( _) => Ok ( Vec :: new ( ) ) ,
390
+ }
391
+ }
392
+ }
393
+
394
+ pub struct TestReader {
395
+ entry_ref : Arc < RwLock < Vec < u8 > > > ,
396
+ }
397
+
398
+ impl TestReader {
399
+ pub fn new ( entry_ref : Arc < RwLock < Vec < u8 > > > ) -> Self {
400
+ Self { entry_ref }
401
+ }
402
+ }
403
+
404
+ impl io:: Read for TestReader {
405
+ fn read ( & mut self , buf : & mut [ u8 ] ) -> std:: io:: Result < usize > {
406
+ let bytes = self . entry_ref . read ( ) . unwrap ( ) . clone ( ) ;
407
+ let mut reader = io:: Cursor :: new ( bytes) ;
408
+ reader. read ( buf)
409
+ }
410
+ }
411
+
319
412
pub struct TestBroadcaster {
320
413
pub txn_broadcasted : Mutex < Vec < Transaction > > ,
321
414
pub blocks : Arc < Mutex < Vec < ( Block , u32 ) > > > ,
0 commit comments