Skip to content

Commit 0618412

Browse files
committed
Implement CertificateRecordStore adapter
1 parent 1f1a303 commit 0618412

File tree

1 file changed

+137
-2
lines changed

1 file changed

+137
-2
lines changed

mithril-aggregator/src/database/provider/certificate.rs

Lines changed: 137 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
use std::sync::Arc;
2+
13
use sqlite::{Connection, Value};
24

5+
use async_trait::async_trait;
6+
37
use mithril_common::{
48
entities::{
59
Beacon, Certificate, CertificateMetadata, Epoch, HexEncodedAgregateVerificationKey,
@@ -9,9 +13,11 @@ use mithril_common::{
913
EntityCursor, HydrationError, Projection, Provider, SourceAlias, SqLiteEntity,
1014
WhereCondition,
1115
},
16+
store::adapter::{AdapterError, StoreAdapter},
1217
};
1318

1419
use mithril_common::StdError;
20+
use tokio::sync::Mutex;
1521

1622
/// Certificate record is the representation of a stored certificate.
1723
#[derive(Debug, PartialEq, Clone)]
@@ -288,6 +294,14 @@ impl<'client> CertificateRecordProvider<'client> {
288294

289295
Ok(certificate_record)
290296
}
297+
298+
/// Get all CertificateRecords.
299+
pub fn get_all(&self) -> Result<EntityCursor<CertificateRecord>, StdError> {
300+
let filters = WhereCondition::default();
301+
let certificate_record = self.find(filters)?;
302+
303+
Ok(certificate_record)
304+
}
291305
}
292306

293307
impl<'client> Provider<'client> for CertificateRecordProvider<'client> {
@@ -377,6 +391,81 @@ impl<'conn> Provider<'conn> for InsertCertificateRecordProvider<'conn> {
377391
}
378392
}
379393

394+
/// Service to deal with certificate (read & write).
395+
pub struct CertificateStoreAdapter {
396+
connection: Arc<Mutex<Connection>>,
397+
}
398+
399+
impl CertificateStoreAdapter {
400+
/// Create a new CertificateStoreAdapter service
401+
pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
402+
Self { connection }
403+
}
404+
}
405+
406+
#[async_trait]
407+
impl StoreAdapter for CertificateStoreAdapter {
408+
type Key = String;
409+
type Record = Certificate;
410+
411+
async fn store_record(
412+
&mut self,
413+
_key: &Self::Key,
414+
record: &Self::Record,
415+
) -> Result<(), AdapterError> {
416+
let connection = &*self.connection.lock().await;
417+
let provider = InsertCertificateRecordProvider::new(connection);
418+
let _certificate_record = provider
419+
.persist(record.to_owned().into())
420+
.map_err(|e| AdapterError::GeneralError(format!("{e}")))?;
421+
422+
Ok(())
423+
}
424+
425+
async fn get_record(&self, key: &Self::Key) -> Result<Option<Self::Record>, AdapterError> {
426+
let connection = &*self.connection.lock().await;
427+
let provider = CertificateRecordProvider::new(connection);
428+
let mut cursor = provider
429+
.get_by_certificate_id(key.to_string())
430+
.map_err(|e| AdapterError::GeneralError(format!("{e}")))?;
431+
let certificate = cursor
432+
.next()
433+
.map(|certificate_record| certificate_record.into());
434+
435+
Ok(certificate)
436+
}
437+
438+
async fn record_exists(&self, key: &Self::Key) -> Result<bool, AdapterError> {
439+
Ok(self.get_record(key).await?.is_some())
440+
}
441+
442+
async fn get_last_n_records(
443+
&self,
444+
how_many: usize,
445+
) -> Result<Vec<(Self::Key, Self::Record)>, AdapterError> {
446+
Ok(self
447+
.get_iter()
448+
.await?
449+
.take(how_many)
450+
.map(|c| (c.hash.to_owned(), c))
451+
.collect())
452+
}
453+
454+
async fn remove(&mut self, _key: &Self::Key) -> Result<Option<Self::Record>, AdapterError> {
455+
unimplemented!()
456+
}
457+
458+
async fn get_iter(&self) -> Result<Box<dyn Iterator<Item = Self::Record> + '_>, AdapterError> {
459+
let connection = &*self.connection.lock().await;
460+
let provider = CertificateRecordProvider::new(connection);
461+
let cursor = provider
462+
.get_all()
463+
.map_err(|e| AdapterError::GeneralError(format!("{e}")))?;
464+
let certificates: Vec<Certificate> = cursor.map(|c| c.into()).collect();
465+
Ok(Box::new(certificates.into_iter()))
466+
}
467+
}
468+
380469
#[cfg(test)]
381470
mod tests {
382471
use crate::database::migration::get_migrations;
@@ -573,7 +662,6 @@ mod tests {
573662

574663
let certificate_records: Vec<CertificateRecord> =
575664
provider.get_by_epoch(&Epoch(1)).unwrap().collect();
576-
577665
let expected_certificate_records: Vec<CertificateRecord> = certificates
578666
.iter()
579667
.filter_map(|c| (c.beacon.epoch == Epoch(1)).then_some(c.to_owned().into()))
@@ -582,7 +670,6 @@ mod tests {
582670

583671
let certificate_records: Vec<CertificateRecord> =
584672
provider.get_by_epoch(&Epoch(3)).unwrap().collect();
585-
586673
let expected_certificate_records: Vec<CertificateRecord> = certificates
587674
.iter()
588675
.filter_map(|c| (c.beacon.epoch == Epoch(3)).then_some(c.to_owned().into()))
@@ -591,6 +678,11 @@ mod tests {
591678

592679
let cursor = provider.get_by_epoch(&Epoch(5)).unwrap();
593680
assert_eq!(0, cursor.count());
681+
682+
let certificate_records: Vec<CertificateRecord> = provider.get_all().unwrap().collect();
683+
let expected_certificate_records: Vec<CertificateRecord> =
684+
certificates.iter().map(|c| c.to_owned().into()).collect();
685+
assert_eq!(expected_certificate_records, certificate_records);
594686
}
595687

596688
#[test]
@@ -608,4 +700,47 @@ mod tests {
608700
assert_eq!(certificate_record, certificate_record_saved);
609701
}
610702
}
703+
704+
#[tokio::test]
705+
async fn test_store_adapter() {
706+
let (certificates, _) = setup_certificate_chain(5, 2);
707+
708+
let connection = Connection::open(":memory:").unwrap();
709+
setup_certificate_db(&connection, Vec::new()).unwrap();
710+
711+
let mut certificate_store_adapter =
712+
CertificateStoreAdapter::new(Arc::new(Mutex::new(connection)));
713+
714+
for certificate in &certificates {
715+
assert!(certificate_store_adapter
716+
.store_record(&certificate.hash, certificate)
717+
.await
718+
.is_ok());
719+
}
720+
721+
for certificate in &certificates {
722+
assert!(certificate_store_adapter
723+
.record_exists(&certificate.hash)
724+
.await
725+
.unwrap());
726+
assert_eq!(
727+
Some(certificate.to_owned()),
728+
certificate_store_adapter
729+
.get_record(&certificate.hash)
730+
.await
731+
.unwrap()
732+
);
733+
}
734+
735+
assert_eq!(
736+
certificates,
737+
certificate_store_adapter
738+
.get_last_n_records(certificates.len())
739+
.await
740+
.unwrap()
741+
.into_iter()
742+
.map(|(_k, v)| v)
743+
.collect::<Vec<Certificate>>()
744+
)
745+
}
611746
}

0 commit comments

Comments
 (0)