Skip to content

Commit 13e67be

Browse files
committed
Add a create_many_certificates on the repository to allow bulk insert
This speed up certificates insertion on a db with 1800+ certificates by a factor of 5+ (8s instead of 44s).
1 parent 335ea99 commit 13e67be

File tree

2 files changed

+144
-34
lines changed

2 files changed

+144
-34
lines changed

mithril-aggregator/src/database/provider/certificate.rs

Lines changed: 141 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use async_trait::async_trait;
22
use chrono::{DateTime, Utc};
33
use sqlite::{Connection, Value};
4+
use std::iter::repeat;
45
use std::sync::Arc;
56
use tokio::sync::Mutex;
67

@@ -378,29 +379,52 @@ impl<'conn> InsertCertificateRecordProvider<'conn> {
378379
}
379380

380381
fn get_insert_condition(&self, certificate_record: &CertificateRecord) -> WhereCondition {
382+
self.get_insert_many_condition(&vec![certificate_record.clone()])
383+
}
384+
385+
fn get_insert_many_condition(
386+
&self,
387+
certificates_records: &[CertificateRecord],
388+
) -> WhereCondition {
389+
let columns = "(certificate_id, parent_certificate_id, message, signature, \
390+
aggregate_verification_key, epoch, beacon, protocol_version, protocol_parameters, \
391+
protocol_message, signers, initiated_at, sealed_at)";
392+
let values_columns: Vec<&str> =
393+
repeat("(?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*, ?*)")
394+
.take(certificates_records.len())
395+
.collect();
396+
397+
let values: Vec<Value> = certificates_records
398+
.iter()
399+
.flat_map(|certificate_record| {
400+
vec![
401+
Value::String(certificate_record.certificate_id.to_owned()),
402+
match certificate_record.parent_certificate_id.to_owned() {
403+
Some(parent_certificate_id) => Value::String(parent_certificate_id),
404+
None => Value::Null,
405+
},
406+
Value::String(certificate_record.message.to_owned()),
407+
Value::String(certificate_record.signature.to_owned()),
408+
Value::String(certificate_record.aggregate_verification_key.to_owned()),
409+
Value::Integer(i64::try_from(certificate_record.epoch.0).unwrap()),
410+
Value::String(serde_json::to_string(&certificate_record.beacon).unwrap()),
411+
Value::String(certificate_record.protocol_version.to_owned()),
412+
Value::String(
413+
serde_json::to_string(&certificate_record.protocol_parameters).unwrap(),
414+
),
415+
Value::String(
416+
serde_json::to_string(&certificate_record.protocol_message).unwrap(),
417+
),
418+
Value::String(serde_json::to_string(&certificate_record.signers).unwrap()),
419+
Value::String(certificate_record.initiated_at.to_rfc3339()),
420+
Value::String(certificate_record.sealed_at.to_rfc3339()),
421+
]
422+
})
423+
.collect();
424+
381425
WhereCondition::new(
382-
"(certificate_id, parent_certificate_id, message, signature, aggregate_verification_key, epoch, beacon, protocol_version, protocol_parameters, protocol_message, signers, initiated_at, sealed_at) values (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13)",
383-
vec![
384-
Value::String(certificate_record.certificate_id.to_owned()),
385-
if let Some(parent_certificate_id) = certificate_record.parent_certificate_id.to_owned() {
386-
Value::String(parent_certificate_id)
387-
}else{
388-
Value::Null
389-
},
390-
Value::String(certificate_record.message.to_owned()),
391-
Value::String(certificate_record.signature.to_owned()),
392-
Value::String(certificate_record.aggregate_verification_key.to_owned()),
393-
Value::Integer(i64::try_from(certificate_record.epoch.0).unwrap()),
394-
Value::String(serde_json::to_string(&certificate_record.beacon).unwrap()),
395-
Value::String(certificate_record.protocol_version.to_owned()),
396-
Value::String(
397-
serde_json::to_string(&certificate_record.protocol_parameters).unwrap(),
398-
),
399-
Value::String(serde_json::to_string(&certificate_record.protocol_message).unwrap()),
400-
Value::String(serde_json::to_string(&certificate_record.signers).unwrap()),
401-
Value::String(certificate_record.initiated_at.to_rfc3339()),
402-
Value::String(certificate_record.sealed_at.to_rfc3339()),
403-
],
426+
format!("{columns} values {}", values_columns.join(", ")).as_str(),
427+
values,
404428
)
405429
}
406430

@@ -418,6 +442,15 @@ impl<'conn> InsertCertificateRecordProvider<'conn> {
418442

419443
Ok(entity)
420444
}
445+
446+
fn persist_many(
447+
&self,
448+
certificates_records: Vec<CertificateRecord>,
449+
) -> Result<Vec<CertificateRecord>, StdError> {
450+
let filters = self.get_insert_many_condition(&certificates_records);
451+
452+
Ok(self.find(filters)?.collect())
453+
}
421454
}
422455

423456
impl<'conn> Provider<'conn> for InsertCertificateRecordProvider<'conn> {
@@ -587,6 +620,23 @@ impl CertificateRepository {
587620
Ok(new_certificate.into())
588621
}
589622

623+
/// Create many certificates at once in the database.
624+
pub async fn create_many_certificates(
625+
&self,
626+
certificates: Vec<Certificate>,
627+
) -> StdResult<Vec<Certificate>> {
628+
let lock = self.connection.lock().await;
629+
let provider = InsertCertificateRecordProvider::new(&lock);
630+
let records: Vec<CertificateRecord> =
631+
certificates.into_iter().map(|cert| cert.into()).collect();
632+
let new_certificates = provider.persist_many(records)?;
633+
634+
Ok(new_certificates
635+
.into_iter()
636+
.map(|cert| cert.into())
637+
.collect::<Vec<_>>())
638+
}
639+
590640
/// Delete all the given certificates from the database
591641
pub async fn delete_certificates(&self, certificates: &[&Certificate]) -> StdResult<()> {
592642
let ids = certificates
@@ -856,7 +906,7 @@ mod tests {
856906
}
857907

858908
#[test]
859-
fn insert_certificate_record() {
909+
fn insert_certificate_condition() {
860910
let (certificates, _) = setup_certificate_chain(2, 1);
861911
let certificate_record: CertificateRecord = certificates.first().unwrap().to_owned().into();
862912
let connection = Connection::open(":memory:").unwrap();
@@ -890,6 +940,57 @@ mod tests {
890940
);
891941
}
892942

943+
#[test]
944+
fn insert_many_certificates_condition() {
945+
let (certificates, _) = setup_certificate_chain(2, 1);
946+
let certificates_records: Vec<CertificateRecord> =
947+
certificates.into_iter().map(|c| c.into()).collect();
948+
let connection = Connection::open(":memory:").unwrap();
949+
let provider = InsertCertificateRecordProvider::new(&connection);
950+
let condition = provider.get_insert_many_condition(&certificates_records);
951+
let (values, params) = condition.expand();
952+
953+
assert_eq!(
954+
"(certificate_id, parent_certificate_id, message, signature, \
955+
aggregate_verification_key, epoch, beacon, protocol_version, protocol_parameters, \
956+
protocol_message, signers, initiated_at, sealed_at) values \
957+
(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13), \
958+
(?14, ?15, ?16, ?17, ?18, ?19, ?20, ?21, ?22, ?23, ?24, ?25, ?26)"
959+
.to_string(),
960+
values
961+
);
962+
assert_eq!(
963+
certificates_records
964+
.into_iter()
965+
.flat_map(|certificate_record| {
966+
vec![
967+
Value::String(certificate_record.certificate_id),
968+
match certificate_record.parent_certificate_id {
969+
Some(id) => Value::String(id),
970+
None => Value::Null,
971+
},
972+
Value::String(certificate_record.message),
973+
Value::String(certificate_record.signature),
974+
Value::String(certificate_record.aggregate_verification_key),
975+
Value::Integer(i64::try_from(certificate_record.epoch.0).unwrap()),
976+
Value::String(serde_json::to_string(&certificate_record.beacon).unwrap()),
977+
Value::String(certificate_record.protocol_version),
978+
Value::String(
979+
serde_json::to_string(&certificate_record.protocol_parameters).unwrap(),
980+
),
981+
Value::String(
982+
serde_json::to_string(&certificate_record.protocol_message).unwrap(),
983+
),
984+
Value::String(serde_json::to_string(&certificate_record.signers).unwrap()),
985+
Value::String(certificate_record.initiated_at.to_rfc3339()),
986+
Value::String(certificate_record.sealed_at.to_rfc3339()),
987+
]
988+
})
989+
.collect::<Vec<_>>(),
990+
params
991+
);
992+
}
993+
893994
#[test]
894995
fn test_get_certificate_records() {
895996
let (certificates, _) = setup_certificate_chain(20, 7);
@@ -945,6 +1046,23 @@ mod tests {
9451046
}
9461047
}
9471048

1049+
#[test]
1050+
fn test_insert_many_certificates_records() {
1051+
let (certificates, _) = setup_certificate_chain(5, 2);
1052+
let certificates_records: Vec<CertificateRecord> =
1053+
certificates.into_iter().map(|cert| cert.into()).collect();
1054+
1055+
let connection = Connection::open(":memory:").unwrap();
1056+
setup_certificate_db(&connection, Vec::new()).unwrap();
1057+
1058+
let provider = InsertCertificateRecordProvider::new(&connection);
1059+
let certificates_records_saved = provider
1060+
.persist_many(certificates_records.clone())
1061+
.expect("saving many records should not fail");
1062+
1063+
assert_eq!(certificates_records, certificates_records_saved);
1064+
}
1065+
9481066
#[tokio::test]
9491067
async fn test_store_adapter() {
9501068
let (certificates, _) = setup_certificate_chain(5, 2);

mithril-aggregator/src/tools/certificates_hash_migrator.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,9 @@ impl CertificatesHashMigrator {
107107

108108
// 2 - Certificates migrated, we can insert them in the db
109109
debug!("🔧 Certificate Hash Migrator: inserting migrated certificates in the database");
110-
for migrated_certificate in migrated_certificates {
111-
trace!(
112-
"🔧 Certificate Hash Migrator: inserting migrated certificate {:?}",
113-
migrated_certificate.beacon;
114-
"hash" => &migrated_certificate.hash,
115-
"previous_hash" => &migrated_certificate.previous_hash
116-
);
117-
self.certificate_repository
118-
.create_certificate(migrated_certificate)
119-
.await?;
120-
}
110+
self.certificate_repository
111+
.create_many_certificates(migrated_certificates)
112+
.await?;
121113

122114
Ok((old_certificates, old_and_new_hashes))
123115
}

0 commit comments

Comments
 (0)