Skip to content

Commit 5449440

Browse files
committed
Rename types & methods to discriminate more between signer import and registration
1 parent cf16c42 commit 5449440

File tree

4 files changed

+131
-78
lines changed

4 files changed

+131
-78
lines changed

mithril-aggregator/src/database/provider/signer.rs

Lines changed: 122 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,17 @@ impl<'client> Provider<'client> for SignerRecordProvider<'client> {
135135
}
136136

137137
/// Query to insert the signer record
138-
pub struct InsertSignerRecordProvider<'conn> {
138+
pub struct RegisterSignerRecordProvider<'conn> {
139139
connection: &'conn Connection,
140140
}
141141

142-
impl<'conn> InsertSignerRecordProvider<'conn> {
142+
impl<'conn> RegisterSignerRecordProvider<'conn> {
143143
/// Create a new instance
144144
pub fn new(connection: &'conn Connection) -> Self {
145145
Self { connection }
146146
}
147147

148-
fn get_insert_condition(&self, signer_record: SignerRecord) -> WhereCondition {
148+
fn get_register_condition(&self, signer_record: SignerRecord) -> WhereCondition {
149149
WhereCondition::new(
150150
"(signer_id, pool_ticker, created_at, updated_at, registered_at) values (?*, ?*, ?*, ?*, ?*)",
151151
vec![
@@ -165,7 +165,7 @@ impl<'conn> InsertSignerRecordProvider<'conn> {
165165
}
166166

167167
fn persist(&self, signer_record: SignerRecord) -> StdResult<SignerRecord> {
168-
let filters = self.get_insert_condition(signer_record.clone());
168+
let filters = self.get_register_condition(signer_record.clone());
169169

170170
let entity = self.find(filters)?.next().unwrap_or_else(|| {
171171
panic!("No entity returned by the persister, signer_record = {signer_record:?}")
@@ -175,7 +175,7 @@ impl<'conn> InsertSignerRecordProvider<'conn> {
175175
}
176176
}
177177

178-
impl<'conn> Provider<'conn> for InsertSignerRecordProvider<'conn> {
178+
impl<'conn> Provider<'conn> for RegisterSignerRecordProvider<'conn> {
179179
type Entity = SignerRecord;
180180

181181
fn get_connection(&'conn self) -> &'conn Connection {
@@ -188,22 +188,25 @@ impl<'conn> Provider<'conn> for InsertSignerRecordProvider<'conn> {
188188
let projection =
189189
Self::Entity::get_projection().expand(SourceAlias::new(&[("{:signer:}", "signer")]));
190190

191-
format!("insert into signer {condition} on conflict (signer_id) do update set updated_at = excluded.updated_at returning {projection}")
191+
format!(
192+
"insert into signer {condition} on conflict (signer_id) do update set \
193+
updated_at = excluded.updated_at, registered_at = excluded.registered_at returning {projection}"
194+
)
192195
}
193196
}
194197

195198
/// Query to update the signer record
196-
pub struct UpdateSignerRecordProvider<'conn> {
199+
pub struct ImportSignerRecordProvider<'conn> {
197200
connection: &'conn Connection,
198201
}
199202

200-
impl<'conn> UpdateSignerRecordProvider<'conn> {
203+
impl<'conn> ImportSignerRecordProvider<'conn> {
201204
/// Create a new instance
202205
pub fn new(connection: &'conn Connection) -> Self {
203206
Self { connection }
204207
}
205208

206-
fn get_update_condition(&self, signer_records: Vec<SignerRecord>) -> WhereCondition {
209+
fn get_import_condition(&self, signer_records: Vec<SignerRecord>) -> WhereCondition {
207210
let columns = "(signer_id, pool_ticker, created_at, updated_at, registered_at)";
208211
let values_columns: Vec<&str> = repeat("(?*, ?*, ?*, ?*, ?*)")
209212
.take(signer_records.len())
@@ -234,7 +237,7 @@ impl<'conn> UpdateSignerRecordProvider<'conn> {
234237
}
235238

236239
fn persist(&self, signer_record: SignerRecord) -> StdResult<SignerRecord> {
237-
let filters = self.get_update_condition(vec![signer_record.clone()]);
240+
let filters = self.get_import_condition(vec![signer_record.clone()]);
238241

239242
let entity = self.find(filters)?.next().unwrap_or_else(|| {
240243
panic!("No entity returned by the persister, signer_record = {signer_record:?}")
@@ -244,13 +247,13 @@ impl<'conn> UpdateSignerRecordProvider<'conn> {
244247
}
245248

246249
fn persist_many(&self, signer_records: Vec<SignerRecord>) -> StdResult<Vec<SignerRecord>> {
247-
let filters = self.get_update_condition(signer_records);
250+
let filters = self.get_import_condition(signer_records);
248251

249252
Ok(self.find(filters)?.collect())
250253
}
251254
}
252255

253-
impl<'conn> Provider<'conn> for UpdateSignerRecordProvider<'conn> {
256+
impl<'conn> Provider<'conn> for ImportSignerRecordProvider<'conn> {
254257
type Entity = SignerRecord;
255258

256259
fn get_connection(&'conn self) -> &'conn Connection {
@@ -263,7 +266,10 @@ impl<'conn> Provider<'conn> for UpdateSignerRecordProvider<'conn> {
263266
let projection =
264267
Self::Entity::get_projection().expand(SourceAlias::new(&[("{:signer:}", "signer")]));
265268

266-
format!("insert into signer {condition} on conflict(signer_id) do update set pool_ticker = excluded.pool_ticker, updated_at = excluded.updated_at returning {projection}")
269+
format!(
270+
"insert into signer {condition} on conflict(signer_id) do update \
271+
set pool_ticker = excluded.pool_ticker, updated_at = excluded.updated_at returning {projection}"
272+
)
267273
}
268274
}
269275

@@ -286,34 +292,15 @@ impl SignerStore {
286292

287293
Ok(cursor.collect())
288294
}
289-
}
290295

291-
#[async_trait]
292-
impl SignerRecorder for SignerStore {
293-
async fn record_signer_id(&self, signer_id: String) -> StdResult<()> {
294-
let connection = &*self.connection.lock().await;
295-
let provider = InsertSignerRecordProvider::new(connection);
296-
let created_at = Utc::now();
297-
let updated_at = created_at;
298-
let signer_record = SignerRecord {
299-
signer_id,
300-
pool_ticker: None,
301-
created_at,
302-
updated_at,
303-
registered_at: None,
304-
};
305-
provider.persist(signer_record)?;
306-
307-
Ok(())
308-
}
309-
310-
async fn record_signer_pool_ticker(
296+
/// Import a signer in the database, its registered_at date will be left empty
297+
pub async fn import_signer(
311298
&self,
312299
signer_id: String,
313300
pool_ticker: Option<String>,
314301
) -> StdResult<()> {
315302
let connection = &*self.connection.lock().await;
316-
let provider = UpdateSignerRecordProvider::new(connection);
303+
let provider = ImportSignerRecordProvider::new(connection);
317304
let created_at = Utc::now();
318305
let updated_at = created_at;
319306
let signer_record = SignerRecord {
@@ -328,12 +315,13 @@ impl SignerRecorder for SignerStore {
328315
Ok(())
329316
}
330317

331-
async fn record_many_signers_pool_tickers(
318+
/// Create many signers at once in the database, their registered_at date will be left empty
319+
pub async fn import_many_signers(
332320
&self,
333321
pool_ticker_by_id: HashMap<String, Option<String>>,
334322
) -> StdResult<()> {
335323
let connection = &*self.connection.lock().await;
336-
let provider = UpdateSignerRecordProvider::new(connection);
324+
let provider = ImportSignerRecordProvider::new(connection);
337325

338326
let created_at = Utc::now();
339327
let updated_at = created_at;
@@ -354,11 +342,33 @@ impl SignerRecorder for SignerStore {
354342
}
355343
}
356344

345+
#[async_trait]
346+
impl SignerRecorder for SignerStore {
347+
async fn record_signer_registration(&self, signer_id: String) -> StdResult<()> {
348+
let connection = &*self.connection.lock().await;
349+
let provider = RegisterSignerRecordProvider::new(connection);
350+
let created_at = Utc::now();
351+
let updated_at = created_at;
352+
let registered_at = Some(created_at);
353+
let signer_record = SignerRecord {
354+
signer_id,
355+
pool_ticker: None,
356+
created_at,
357+
updated_at,
358+
registered_at,
359+
};
360+
provider.persist(signer_record)?;
361+
362+
Ok(())
363+
}
364+
}
365+
357366
#[cfg(test)]
358367
mod tests {
359368
use crate::database::provider::apply_all_migrations_to_db;
360369
use chrono::Duration;
361370
use mithril_common::StdResult;
371+
use std::collections::BTreeMap;
362372

363373
use super::*;
364374

@@ -395,9 +405,9 @@ mod tests {
395405
let query = {
396406
// leverage the expanded parameter from this provider which is unit
397407
// tested on its own above.
398-
let update_provider = UpdateSignerRecordProvider::new(connection);
408+
let update_provider = ImportSignerRecordProvider::new(connection);
399409
let (sql_values, _) = update_provider
400-
.get_update_condition(vec![signer_records.first().unwrap().to_owned()])
410+
.get_import_condition(vec![signer_records.first().unwrap().to_owned()])
401411
.expand();
402412
format!("insert into signer {sql_values}")
403413
};
@@ -461,8 +471,8 @@ mod tests {
461471
fn insert_signer_record() {
462472
let signer_record = fake_signer_records(1).first().unwrap().to_owned();
463473
let connection = Connection::open(":memory:").unwrap();
464-
let provider = InsertSignerRecordProvider::new(&connection);
465-
let condition = provider.get_insert_condition(signer_record.clone());
474+
let provider = RegisterSignerRecordProvider::new(&connection);
475+
let condition = provider.get_register_condition(signer_record.clone());
466476
let (values, params) = condition.expand();
467477

468478
assert_eq!(
@@ -485,8 +495,8 @@ mod tests {
485495
fn update_signer_record() {
486496
let signer_records = fake_signer_records(2);
487497
let connection = Connection::open(":memory:").unwrap();
488-
let provider = UpdateSignerRecordProvider::new(&connection);
489-
let condition = provider.get_update_condition(signer_records.clone());
498+
let provider = ImportSignerRecordProvider::new(&connection);
499+
let condition = provider.get_import_condition(signer_records.clone());
490500
let (values, params) = condition.expand();
491501

492502
assert_eq!(
@@ -551,7 +561,7 @@ mod tests {
551561
let connection = Connection::open(":memory:").unwrap();
552562
setup_signer_db(&connection, Vec::new()).unwrap();
553563

554-
let provider = InsertSignerRecordProvider::new(&connection);
564+
let provider = RegisterSignerRecordProvider::new(&connection);
555565

556566
for signer_record in signer_records_fake.clone() {
557567
let signer_record_saved = provider.persist(signer_record.clone()).unwrap();
@@ -572,7 +582,7 @@ mod tests {
572582
let connection = Connection::open(":memory:").unwrap();
573583
setup_signer_db(&connection, signer_records_fake.clone()).unwrap();
574584

575-
let provider = UpdateSignerRecordProvider::new(&connection);
585+
let provider = ImportSignerRecordProvider::new(&connection);
576586

577587
for signer_record in signer_records_fake.clone() {
578588
let signer_record_saved = provider.persist(signer_record.clone()).unwrap();
@@ -595,7 +605,7 @@ mod tests {
595605
let connection = Connection::open(":memory:").unwrap();
596606
setup_signer_db(&connection, signer_records_fake.clone()).unwrap();
597607

598-
let provider = UpdateSignerRecordProvider::new(&connection);
608+
let provider = ImportSignerRecordProvider::new(&connection);
599609
let mut saved_records = provider.persist_many(signer_records_fake.clone()).unwrap();
600610
saved_records.sort_by(|a, b| a.signer_id.cmp(&b.signer_id));
601611
assert_eq!(signer_records_fake, saved_records);
@@ -638,31 +648,90 @@ mod tests {
638648

639649
for signer_record in signer_records_fake.clone() {
640650
store_recorder
641-
.record_signer_id(signer_record.signer_id.clone())
651+
.record_signer_registration(signer_record.signer_id.clone())
642652
.await
643-
.expect("record_signer_id should not fail");
653+
.expect("record_signer_registration should not fail");
644654
let connection = &*connection.lock().await;
645655
let provider = SignerRecordProvider::new(connection);
646656
let signer_records_stored: Vec<SignerRecord> = provider
647657
.get_by_signer_id(signer_record.signer_id)
648658
.unwrap()
649659
.collect::<Vec<_>>();
650660
assert_eq!(1, signer_records_stored.len());
661+
assert!(
662+
signer_records_stored
663+
.iter()
664+
.all(|s| s.registered_at.is_some()),
665+
"registering a signer should set the registration date"
666+
)
651667
}
668+
}
669+
670+
#[tokio::test]
671+
async fn test_store_import_signer() {
672+
let signer_records_fake = fake_signer_records(5);
673+
674+
let connection = Connection::open(":memory:").unwrap();
675+
setup_signer_db(&connection, Vec::new()).unwrap();
676+
677+
let connection = Arc::new(Mutex::new(connection));
678+
let store = SignerStore::new(connection.clone());
652679

653680
for signer_record in signer_records_fake {
654-
let pool_ticker = Some(format!("new-pool-{}", signer_record.signer_id));
655-
store_recorder
656-
.record_signer_pool_ticker(signer_record.signer_id.clone(), pool_ticker.clone())
681+
store
682+
.import_signer(
683+
signer_record.signer_id.clone(),
684+
signer_record.pool_ticker.clone(),
685+
)
657686
.await
658-
.expect("record_signer_pool_ticker should not fail");
687+
.expect("import_signer should not fail");
659688
let connection = &*connection.lock().await;
660689
let provider = SignerRecordProvider::new(connection);
661690
let signer_records_stored: Vec<SignerRecord> = provider
662691
.get_by_signer_id(signer_record.signer_id)
663692
.unwrap()
664693
.collect::<Vec<_>>();
665-
assert_eq!(pool_ticker, signer_records_stored[0].to_owned().pool_ticker);
694+
assert_eq!(
695+
signer_record.pool_ticker,
696+
signer_records_stored[0].to_owned().pool_ticker
697+
);
698+
assert!(
699+
signer_records_stored
700+
.iter()
701+
.all(|s| s.registered_at.is_none()),
702+
"imported signer should not have a registration date"
703+
)
666704
}
667705
}
706+
707+
#[tokio::test]
708+
async fn test_store_import_many_signers() {
709+
let signers_fake: BTreeMap<_, _> = fake_signer_records(5)
710+
.into_iter()
711+
.map(|r| (r.signer_id, r.pool_ticker))
712+
.collect();
713+
714+
let connection = Connection::open(":memory:").unwrap();
715+
setup_signer_db(&connection, Vec::new()).unwrap();
716+
let store = SignerStore::new(Arc::new(Mutex::new(connection)));
717+
718+
store
719+
.import_many_signers(signers_fake.clone().into_iter().collect())
720+
.await
721+
.expect("import_many_signers should not fail");
722+
723+
let signer_records_stored = store.get_all().await.unwrap();
724+
let signers_stored = signer_records_stored
725+
.iter()
726+
.cloned()
727+
.map(|r| (r.signer_id, r.pool_ticker))
728+
.collect();
729+
assert_eq!(signers_fake, signers_stored);
730+
assert!(
731+
signer_records_stored
732+
.iter()
733+
.all(|s| s.registered_at.is_none()),
734+
"imported signer should not have a registration date"
735+
);
736+
}
668737
}

mithril-aggregator/src/dependency_injection/containers.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@ impl DependencyContainer {
285285
async fn fill_verification_key_store(&self, target_epoch: Epoch, signers: &[SignerWithStake]) {
286286
for signer in signers {
287287
self.signer_recorder
288-
.record_signer_id(signer.party_id.clone())
288+
.record_signer_registration(signer.party_id.clone())
289289
.await
290-
.expect("record_signer_id should not fail");
290+
.expect("record_signer_registration should not fail");
291291
self.verification_key_store
292292
.save_verification_key(target_epoch, signer.clone())
293293
.await

0 commit comments

Comments
 (0)