1
1
use async_trait:: async_trait;
2
2
use chrono:: { DateTime , Utc } ;
3
3
use sqlite:: { Connection , Value } ;
4
+ use std:: iter:: repeat;
4
5
use std:: sync:: Arc ;
5
6
use tokio:: sync:: Mutex ;
6
7
@@ -378,29 +379,52 @@ impl<'conn> InsertCertificateRecordProvider<'conn> {
378
379
}
379
380
380
381
fn get_insert_condition ( & self , certificate_record : & CertificateRecord ) -> WhereCondition {
382
+ self . get_insert_many_condition ( & vec ! [ certificate_record. clone( ) ] )
383
+ }
384
+
385
+ fn get_insert_many_condition (
386
+ & self ,
387
+ certificates_records : & [ CertificateRecord ] ,
388
+ ) -> WhereCondition {
389
+ let columns = "(certificate_id, parent_certificate_id, message, signature, \
390
+ aggregate_verification_key, epoch, beacon, protocol_version, protocol_parameters, \
391
+ protocol_message, signers, initiated_at, sealed_at)";
392
+ let values_columns: Vec < & str > =
393
+ repeat ( "(?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*)" )
394
+ . take ( certificates_records. len ( ) )
395
+ . collect ( ) ;
396
+
397
+ let values: Vec < Value > = certificates_records
398
+ . iter ( )
399
+ . flat_map ( |certificate_record| {
400
+ vec ! [
401
+ Value :: String ( certificate_record. certificate_id. to_owned( ) ) ,
402
+ match certificate_record. parent_certificate_id. to_owned( ) {
403
+ Some ( parent_certificate_id) => Value :: String ( parent_certificate_id) ,
404
+ None => Value :: Null ,
405
+ } ,
406
+ Value :: String ( certificate_record. message. to_owned( ) ) ,
407
+ Value :: String ( certificate_record. signature. to_owned( ) ) ,
408
+ Value :: String ( certificate_record. aggregate_verification_key. to_owned( ) ) ,
409
+ Value :: Integer ( i64 :: try_from( certificate_record. epoch. 0 ) . unwrap( ) ) ,
410
+ Value :: String ( serde_json:: to_string( & certificate_record. beacon) . unwrap( ) ) ,
411
+ Value :: String ( certificate_record. protocol_version. to_owned( ) ) ,
412
+ Value :: String (
413
+ serde_json:: to_string( & certificate_record. protocol_parameters) . unwrap( ) ,
414
+ ) ,
415
+ Value :: String (
416
+ serde_json:: to_string( & certificate_record. protocol_message) . unwrap( ) ,
417
+ ) ,
418
+ Value :: String ( serde_json:: to_string( & certificate_record. signers) . unwrap( ) ) ,
419
+ Value :: String ( certificate_record. initiated_at. to_rfc3339( ) ) ,
420
+ Value :: String ( certificate_record. sealed_at. to_rfc3339( ) ) ,
421
+ ]
422
+ } )
423
+ . collect ( ) ;
424
+
381
425
WhereCondition :: new (
382
- "(certificate_id, parent_certificate_id, message, signature, aggregate_verification_key, epoch, beacon, protocol_version, protocol_parameters, protocol_message, signers, initiated_at, sealed_at) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)" ,
383
- vec ! [
384
- Value :: String ( certificate_record. certificate_id. to_owned( ) ) ,
385
- if let Some ( parent_certificate_id) = certificate_record. parent_certificate_id. to_owned( ) {
386
- Value :: String ( parent_certificate_id)
387
- } else{
388
- Value :: Null
389
- } ,
390
- Value :: String ( certificate_record. message. to_owned( ) ) ,
391
- Value :: String ( certificate_record. signature. to_owned( ) ) ,
392
- Value :: String ( certificate_record. aggregate_verification_key. to_owned( ) ) ,
393
- Value :: Integer ( i64 :: try_from( certificate_record. epoch. 0 ) . unwrap( ) ) ,
394
- Value :: String ( serde_json:: to_string( & certificate_record. beacon) . unwrap( ) ) ,
395
- Value :: String ( certificate_record. protocol_version. to_owned( ) ) ,
396
- Value :: String (
397
- serde_json:: to_string( & certificate_record. protocol_parameters) . unwrap( ) ,
398
- ) ,
399
- Value :: String ( serde_json:: to_string( & certificate_record. protocol_message) . unwrap( ) ) ,
400
- Value :: String ( serde_json:: to_string( & certificate_record. signers) . unwrap( ) ) ,
401
- Value :: String ( certificate_record. initiated_at. to_rfc3339( ) ) ,
402
- Value :: String ( certificate_record. sealed_at. to_rfc3339( ) ) ,
403
- ] ,
426
+ format ! ( "{columns} values {}" , values_columns. join( ", " ) ) . as_str ( ) ,
427
+ values,
404
428
)
405
429
}
406
430
@@ -418,6 +442,15 @@ impl<'conn> InsertCertificateRecordProvider<'conn> {
418
442
419
443
Ok ( entity)
420
444
}
445
+
446
+ fn persist_many (
447
+ & self ,
448
+ certificates_records : Vec < CertificateRecord > ,
449
+ ) -> Result < Vec < CertificateRecord > , StdError > {
450
+ let filters = self . get_insert_many_condition ( & certificates_records) ;
451
+
452
+ Ok ( self . find ( filters) ?. collect ( ) )
453
+ }
421
454
}
422
455
423
456
impl < ' conn > Provider < ' conn > for InsertCertificateRecordProvider < ' conn > {
@@ -587,6 +620,23 @@ impl CertificateRepository {
587
620
Ok ( new_certificate. into ( ) )
588
621
}
589
622
623
+ /// Create many certificates at once in the database.
624
+ pub async fn create_many_certificates (
625
+ & self ,
626
+ certificates : Vec < Certificate > ,
627
+ ) -> StdResult < Vec < Certificate > > {
628
+ let lock = self . connection . lock ( ) . await ;
629
+ let provider = InsertCertificateRecordProvider :: new ( & lock) ;
630
+ let records: Vec < CertificateRecord > =
631
+ certificates. into_iter ( ) . map ( |cert| cert. into ( ) ) . collect ( ) ;
632
+ let new_certificates = provider. persist_many ( records) ?;
633
+
634
+ Ok ( new_certificates
635
+ . into_iter ( )
636
+ . map ( |cert| cert. into ( ) )
637
+ . collect :: < Vec < _ > > ( ) )
638
+ }
639
+
590
640
/// Delete all the given certificates from the database
591
641
pub async fn delete_certificates ( & self , certificates : & [ & Certificate ] ) -> StdResult < ( ) > {
592
642
let ids = certificates
@@ -856,7 +906,7 @@ mod tests {
856
906
}
857
907
858
908
#[ test]
859
- fn insert_certificate_record ( ) {
909
+ fn insert_certificate_condition ( ) {
860
910
let ( certificates, _) = setup_certificate_chain ( 2 , 1 ) ;
861
911
let certificate_record: CertificateRecord = certificates. first ( ) . unwrap ( ) . to_owned ( ) . into ( ) ;
862
912
let connection = Connection :: open ( ":memory:" ) . unwrap ( ) ;
@@ -890,6 +940,57 @@ mod tests {
890
940
) ;
891
941
}
892
942
943
+ #[ test]
944
+ fn insert_many_certificates_condition ( ) {
945
+ let ( certificates, _) = setup_certificate_chain ( 2 , 1 ) ;
946
+ let certificates_records: Vec < CertificateRecord > =
947
+ certificates. into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
948
+ let connection = Connection :: open ( ":memory:" ) . unwrap ( ) ;
949
+ let provider = InsertCertificateRecordProvider :: new ( & connection) ;
950
+ let condition = provider. get_insert_many_condition ( & certificates_records) ;
951
+ let ( values, params) = condition. expand ( ) ;
952
+
953
+ assert_eq ! (
954
+ "(certificate_id, parent_certificate_id, message, signature, \
955
+ aggregate_verification_key, epoch, beacon, protocol_version, protocol_parameters, \
956
+ protocol_message, signers, initiated_at, sealed_at) values \
957
+ (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13), \
958
+ (?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23, ?24, ?25, ?26)"
959
+ . to_string( ) ,
960
+ values
961
+ ) ;
962
+ assert_eq ! (
963
+ certificates_records
964
+ . into_iter( )
965
+ . flat_map( |certificate_record| {
966
+ vec![
967
+ Value :: String ( certificate_record. certificate_id) ,
968
+ match certificate_record. parent_certificate_id {
969
+ Some ( id) => Value :: String ( id) ,
970
+ None => Value :: Null ,
971
+ } ,
972
+ Value :: String ( certificate_record. message) ,
973
+ Value :: String ( certificate_record. signature) ,
974
+ Value :: String ( certificate_record. aggregate_verification_key) ,
975
+ Value :: Integer ( i64 :: try_from( certificate_record. epoch. 0 ) . unwrap( ) ) ,
976
+ Value :: String ( serde_json:: to_string( & certificate_record. beacon) . unwrap( ) ) ,
977
+ Value :: String ( certificate_record. protocol_version) ,
978
+ Value :: String (
979
+ serde_json:: to_string( & certificate_record. protocol_parameters) . unwrap( ) ,
980
+ ) ,
981
+ Value :: String (
982
+ serde_json:: to_string( & certificate_record. protocol_message) . unwrap( ) ,
983
+ ) ,
984
+ Value :: String ( serde_json:: to_string( & certificate_record. signers) . unwrap( ) ) ,
985
+ Value :: String ( certificate_record. initiated_at. to_rfc3339( ) ) ,
986
+ Value :: String ( certificate_record. sealed_at. to_rfc3339( ) ) ,
987
+ ]
988
+ } )
989
+ . collect:: <Vec <_>>( ) ,
990
+ params
991
+ ) ;
992
+ }
993
+
893
994
#[ test]
894
995
fn test_get_certificate_records ( ) {
895
996
let ( certificates, _) = setup_certificate_chain ( 20 , 7 ) ;
@@ -945,6 +1046,23 @@ mod tests {
945
1046
}
946
1047
}
947
1048
1049
+ #[ test]
1050
+ fn test_insert_many_certificates_records ( ) {
1051
+ let ( certificates, _) = setup_certificate_chain ( 5 , 2 ) ;
1052
+ let certificates_records: Vec < CertificateRecord > =
1053
+ certificates. into_iter ( ) . map ( |cert| cert. into ( ) ) . collect ( ) ;
1054
+
1055
+ let connection = Connection :: open ( ":memory:" ) . unwrap ( ) ;
1056
+ setup_certificate_db ( & connection, Vec :: new ( ) ) . unwrap ( ) ;
1057
+
1058
+ let provider = InsertCertificateRecordProvider :: new ( & connection) ;
1059
+ let certificates_records_saved = provider
1060
+ . persist_many ( certificates_records. clone ( ) )
1061
+ . expect ( "saving many records should not fail" ) ;
1062
+
1063
+ assert_eq ! ( certificates_records, certificates_records_saved) ;
1064
+ }
1065
+
948
1066
#[ tokio:: test]
949
1067
async fn test_store_adapter ( ) {
950
1068
let ( certificates, _) = setup_certificate_chain ( 5 , 2 ) ;
0 commit comments