1
+ use std:: sync:: Arc ;
2
+
1
3
use sqlite:: { Connection , Value } ;
2
4
5
+ use async_trait:: async_trait;
6
+
3
7
use mithril_common:: {
4
8
entities:: {
5
9
Beacon , Certificate , CertificateMetadata , Epoch , HexEncodedAgregateVerificationKey ,
@@ -9,9 +13,11 @@ use mithril_common::{
9
13
EntityCursor , HydrationError , Projection , Provider , SourceAlias , SqLiteEntity ,
10
14
WhereCondition ,
11
15
} ,
16
+ store:: adapter:: { AdapterError , StoreAdapter } ,
12
17
} ;
13
18
14
19
use mithril_common:: StdError ;
20
+ use tokio:: sync:: Mutex ;
15
21
16
22
/// Certificate record is the representation of a stored certificate.
17
23
#[ derive( Debug , PartialEq , Clone ) ]
@@ -288,6 +294,14 @@ impl<'client> CertificateRecordProvider<'client> {
288
294
289
295
Ok ( certificate_record)
290
296
}
297
+
298
+ /// Get all CertificateRecords.
299
+ pub fn get_all ( & self ) -> Result < EntityCursor < CertificateRecord > , StdError > {
300
+ let filters = WhereCondition :: default ( ) ;
301
+ let certificate_record = self . find ( filters) ?;
302
+
303
+ Ok ( certificate_record)
304
+ }
291
305
}
292
306
293
307
impl < ' client > Provider < ' client > for CertificateRecordProvider < ' client > {
@@ -377,6 +391,81 @@ impl<'conn> Provider<'conn> for InsertCertificateRecordProvider<'conn> {
377
391
}
378
392
}
379
393
394
+ /// Service to deal with certificate (read & write).
395
+ pub struct CertificateStoreAdapter {
396
+ connection : Arc < Mutex < Connection > > ,
397
+ }
398
+
399
+ impl CertificateStoreAdapter {
400
+ /// Create a new CertificateStoreAdapter service
401
+ pub fn new ( connection : Arc < Mutex < Connection > > ) -> Self {
402
+ Self { connection }
403
+ }
404
+ }
405
+
406
+ #[ async_trait]
407
+ impl StoreAdapter for CertificateStoreAdapter {
408
+ type Key = String ;
409
+ type Record = Certificate ;
410
+
411
+ async fn store_record (
412
+ & mut self ,
413
+ _key : & Self :: Key ,
414
+ record : & Self :: Record ,
415
+ ) -> Result < ( ) , AdapterError > {
416
+ let connection = & * self . connection . lock ( ) . await ;
417
+ let provider = InsertCertificateRecordProvider :: new ( connection) ;
418
+ let _certificate_record = provider
419
+ . persist ( record. to_owned ( ) . into ( ) )
420
+ . map_err ( |e| AdapterError :: GeneralError ( format ! ( "{e}" ) ) ) ?;
421
+
422
+ Ok ( ( ) )
423
+ }
424
+
425
+ async fn get_record ( & self , key : & Self :: Key ) -> Result < Option < Self :: Record > , AdapterError > {
426
+ let connection = & * self . connection . lock ( ) . await ;
427
+ let provider = CertificateRecordProvider :: new ( connection) ;
428
+ let mut cursor = provider
429
+ . get_by_certificate_id ( key. to_string ( ) )
430
+ . map_err ( |e| AdapterError :: GeneralError ( format ! ( "{e}" ) ) ) ?;
431
+ let certificate = cursor
432
+ . next ( )
433
+ . map ( |certificate_record| certificate_record. into ( ) ) ;
434
+
435
+ Ok ( certificate)
436
+ }
437
+
438
+ async fn record_exists ( & self , key : & Self :: Key ) -> Result < bool , AdapterError > {
439
+ Ok ( self . get_record ( key) . await ?. is_some ( ) )
440
+ }
441
+
442
+ async fn get_last_n_records (
443
+ & self ,
444
+ how_many : usize ,
445
+ ) -> Result < Vec < ( Self :: Key , Self :: Record ) > , AdapterError > {
446
+ Ok ( self
447
+ . get_iter ( )
448
+ . await ?
449
+ . take ( how_many)
450
+ . map ( |c| ( c. hash . to_owned ( ) , c) )
451
+ . collect ( ) )
452
+ }
453
+
454
+ async fn remove ( & mut self , _key : & Self :: Key ) -> Result < Option < Self :: Record > , AdapterError > {
455
+ unimplemented ! ( )
456
+ }
457
+
458
+ async fn get_iter ( & self ) -> Result < Box < dyn Iterator < Item = Self :: Record > + ' _ > , AdapterError > {
459
+ let connection = & * self . connection . lock ( ) . await ;
460
+ let provider = CertificateRecordProvider :: new ( connection) ;
461
+ let cursor = provider
462
+ . get_all ( )
463
+ . map_err ( |e| AdapterError :: GeneralError ( format ! ( "{e}" ) ) ) ?;
464
+ let certificates: Vec < Certificate > = cursor. map ( |c| c. into ( ) ) . collect ( ) ;
465
+ Ok ( Box :: new ( certificates. into_iter ( ) ) )
466
+ }
467
+ }
468
+
380
469
#[ cfg( test) ]
381
470
mod tests {
382
471
use crate :: database:: migration:: get_migrations;
@@ -573,7 +662,6 @@ mod tests {
573
662
574
663
let certificate_records: Vec < CertificateRecord > =
575
664
provider. get_by_epoch ( & Epoch ( 1 ) ) . unwrap ( ) . collect ( ) ;
576
-
577
665
let expected_certificate_records: Vec < CertificateRecord > = certificates
578
666
. iter ( )
579
667
. filter_map ( |c| ( c. beacon . epoch == Epoch ( 1 ) ) . then_some ( c. to_owned ( ) . into ( ) ) )
@@ -582,7 +670,6 @@ mod tests {
582
670
583
671
let certificate_records: Vec < CertificateRecord > =
584
672
provider. get_by_epoch ( & Epoch ( 3 ) ) . unwrap ( ) . collect ( ) ;
585
-
586
673
let expected_certificate_records: Vec < CertificateRecord > = certificates
587
674
. iter ( )
588
675
. filter_map ( |c| ( c. beacon . epoch == Epoch ( 3 ) ) . then_some ( c. to_owned ( ) . into ( ) ) )
@@ -591,6 +678,11 @@ mod tests {
591
678
592
679
let cursor = provider. get_by_epoch ( & Epoch ( 5 ) ) . unwrap ( ) ;
593
680
assert_eq ! ( 0 , cursor. count( ) ) ;
681
+
682
+ let certificate_records: Vec < CertificateRecord > = provider. get_all ( ) . unwrap ( ) . collect ( ) ;
683
+ let expected_certificate_records: Vec < CertificateRecord > =
684
+ certificates. iter ( ) . map ( |c| c. to_owned ( ) . into ( ) ) . collect ( ) ;
685
+ assert_eq ! ( expected_certificate_records, certificate_records) ;
594
686
}
595
687
596
688
#[ test]
@@ -608,4 +700,47 @@ mod tests {
608
700
assert_eq ! ( certificate_record, certificate_record_saved) ;
609
701
}
610
702
}
703
+
704
+ #[ tokio:: test]
705
+ async fn test_store_adapter ( ) {
706
+ let ( certificates, _) = setup_certificate_chain ( 5 , 2 ) ;
707
+
708
+ let connection = Connection :: open ( ":memory:" ) . unwrap ( ) ;
709
+ setup_certificate_db ( & connection, Vec :: new ( ) ) . unwrap ( ) ;
710
+
711
+ let mut certificate_store_adapter =
712
+ CertificateStoreAdapter :: new ( Arc :: new ( Mutex :: new ( connection) ) ) ;
713
+
714
+ for certificate in & certificates {
715
+ assert ! ( certificate_store_adapter
716
+ . store_record( & certificate. hash, certificate)
717
+ . await
718
+ . is_ok( ) ) ;
719
+ }
720
+
721
+ for certificate in & certificates {
722
+ assert ! ( certificate_store_adapter
723
+ . record_exists( & certificate. hash)
724
+ . await
725
+ . unwrap( ) ) ;
726
+ assert_eq ! (
727
+ Some ( certificate. to_owned( ) ) ,
728
+ certificate_store_adapter
729
+ . get_record( & certificate. hash)
730
+ . await
731
+ . unwrap( )
732
+ ) ;
733
+ }
734
+
735
+ assert_eq ! (
736
+ certificates,
737
+ certificate_store_adapter
738
+ . get_last_n_records( certificates. len( ) )
739
+ . await
740
+ . unwrap( )
741
+ . into_iter( )
742
+ . map( |( _k, v) | v)
743
+ . collect:: <Vec <Certificate >>( )
744
+ )
745
+ }
611
746
}
0 commit comments