Skip to content

Commit fba460a

Browse files
committed
add stake distribution service
1 parent b9dee08 commit fba460a

File tree

11 files changed

+428
-119
lines changed

11 files changed

+428
-119
lines changed

mithril-aggregator/src/database/provider/stake_pool.rs

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
use std::{
2-
collections::HashMap,
3-
sync::{Arc, Mutex},
4-
};
1+
use std::sync::{Arc, Mutex};
52

63
use async_trait::async_trait;
74
use chrono::NaiveDateTime;
@@ -121,7 +118,7 @@ impl<'client> Provider<'client> for StakePoolProvider<'client> {
121118
let aliases = SourceAlias::new(&[("{:stake_pool:}", "sp")]);
122119
let projection = Self::Entity::get_projection().expand(aliases);
123120

124-
format!("select {projection} from stake_pool as sp where {condition} order by epoch asc, stake desc")
121+
format!("select {projection} from stake_pool as sp where {condition} order by epoch asc, stake desc, stake_pool_id asc")
125122
}
126123
}
127124

@@ -254,7 +251,7 @@ impl StakeStorer for StakePoolStore {
254251
.lock()
255252
.map_err(|e| AdapterError::GeneralError(format!("{e}")))?;
256253
let provider = UpdateStakePoolProvider::new(connection);
257-
let mut new_stakes: HashMap<PartyId, Stake> = HashMap::new();
254+
let mut new_stakes = StakeDistribution::new();
258255
connection
259256
.execute("begin transaction")
260257
.map_err(|e| AdapterError::QueryError(e.into()))?;
@@ -287,7 +284,7 @@ impl StakeStorer for StakePoolStore {
287284
let cursor = provider
288285
.get_by_epoch(&epoch)
289286
.map_err(|e| AdapterError::GeneralError(format!("Could not get stakes: {e}")))?;
290-
let mut stake_distribution: HashMap<PartyId, Stake> = HashMap::new();
287+
let mut stake_distribution = StakeDistribution::new();
291288

292289
for stake_pool in cursor {
293290
stake_distribution.insert(stake_pool.stake_pool_id, stake_pool.stake);
@@ -297,10 +294,54 @@ impl StakeStorer for StakePoolStore {
297294
}
298295
}
299296

300-
#[cfg(test)]
301-
mod tests {
297+
#[cfg(any(test, feature = "test_only"))]
298+
pub fn setup_stake_db(connection: &Connection) -> Result<(), StdError> {
302299
use crate::database::migration::get_migrations;
303300

301+
let migrations = get_migrations();
302+
let migration = migrations
303+
.iter()
304+
.find(|&m| m.version == 1)
305+
.ok_or_else(|| -> StdError {
306+
"There should be a migration version 1".to_string().into()
307+
})?;
308+
let query = {
309+
// leverage the expanded parameter from this provider which is unit
310+
// tested on its own above.
311+
let update_provider = UpdateStakePoolProvider::new(connection);
312+
let (sql_values, _) = update_provider
313+
.get_update_condition("pool_id", Epoch(1), 1000)
314+
.expand();
315+
316+
connection.execute(&migration.alterations)?;
317+
318+
format!("insert into stake_pool {sql_values}")
319+
};
320+
let stake_distribution: &[(&str, i64, i64); 9] = &[
321+
("pool1", 1, 1000),
322+
("pool2", 1, 1100),
323+
("pool3", 1, 1300),
324+
("pool1", 2, 1230),
325+
("pool2", 2, 1090),
326+
("pool3", 2, 1300),
327+
("pool1", 3, 1250),
328+
("pool2", 3, 1370),
329+
("pool3", 3, 1300),
330+
];
331+
for (pool_id, epoch, stake) in stake_distribution {
332+
let mut statement = connection.prepare(&query)?;
333+
334+
statement.bind(1, *pool_id).unwrap();
335+
statement.bind(2, *epoch).unwrap();
336+
statement.bind(3, *stake).unwrap();
337+
statement.next().unwrap();
338+
}
339+
340+
Ok(())
341+
}
342+
343+
#[cfg(test)]
344+
mod tests {
304345
use super::*;
305346

306347
#[test]
@@ -357,54 +398,10 @@ mod tests {
357398
assert_eq!(vec![Value::Integer(5)], params);
358399
}
359400

360-
fn setup_db(connection: &Connection) -> Result<(), StdError> {
361-
let migrations = get_migrations();
362-
let migration =
363-
migrations
364-
.iter()
365-
.find(|&m| m.version == 1)
366-
.ok_or_else(|| -> StdError {
367-
"There should be a migration version 1".to_string().into()
368-
})?;
369-
let query = {
370-
// leverage the expanded parameter from this provider which is unit
371-
// tested on its own above.
372-
let update_provider = UpdateStakePoolProvider::new(connection);
373-
let (sql_values, _) = update_provider
374-
.get_update_condition("pool_id", Epoch(1), 1000)
375-
.expand();
376-
377-
connection.execute(&migration.alterations)?;
378-
379-
format!("insert into stake_pool {sql_values}")
380-
};
381-
let stake_distribution: &[(&str, i64, i64); 9] = &[
382-
("pool1", 1, 1000),
383-
("pool2", 1, 1100),
384-
("pool3", 1, 1300),
385-
("pool1", 2, 1230),
386-
("pool2", 2, 1090),
387-
("pool3", 2, 1300),
388-
("pool1", 3, 1250),
389-
("pool2", 3, 1370),
390-
("pool3", 3, 1300),
391-
];
392-
for (pool_id, epoch, stake) in stake_distribution {
393-
let mut statement = connection.prepare(&query)?;
394-
395-
statement.bind(1, *pool_id).unwrap();
396-
statement.bind(2, *epoch).unwrap();
397-
statement.bind(3, *stake).unwrap();
398-
statement.next().unwrap();
399-
}
400-
401-
Ok(())
402-
}
403-
404401
#[test]
405402
fn test_get_stake_pools() {
406403
let connection = Connection::open(":memory:").unwrap();
407-
setup_db(&connection).unwrap();
404+
setup_stake_db(&connection).unwrap();
408405

409406
let provider = StakePoolProvider::new(&connection);
410407
let mut cursor = provider.get_by_epoch(&Epoch(1)).unwrap();
@@ -430,7 +427,7 @@ mod tests {
430427
#[test]
431428
fn test_update_stakes() {
432429
let connection = Connection::open(":memory:").unwrap();
433-
setup_db(&connection).unwrap();
430+
setup_stake_db(&connection).unwrap();
434431

435432
let provider = UpdateStakePoolProvider::new(&connection);
436433
let stake_pool = provider.persist("pool4", Epoch(3), 9999).unwrap();
@@ -452,7 +449,7 @@ mod tests {
452449
#[test]
453450
fn test_prune() {
454451
let connection = Connection::open(":memory:").unwrap();
455-
setup_db(&connection).unwrap();
452+
setup_stake_db(&connection).unwrap();
456453

457454
let provider = DeleteStakePoolProvider::new(&connection);
458455
let cursor = provider.prune(Epoch(2)).unwrap();

mithril-aggregator/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ mod signer_registerer;
2525
mod snapshot_stores;
2626
mod snapshot_uploaders;
2727
mod snapshotter;
28-
//pub mod stake_pools;
28+
pub mod stake_distribution_service;
2929
mod store;
3030
mod tools;
3131

mithril-aggregator/src/multi_signer.rs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use async_trait::async_trait;
22
use chrono::prelude::*;
33
use hex::ToHex;
44
use slog_scope::{debug, trace, warn};
5-
use std::{collections::HashMap, sync::Arc};
5+
use std::sync::Arc;
66
use thiserror::Error;
77

88
use mithril_common::{
@@ -12,7 +12,7 @@ use mithril_common::{
1212
ProtocolPartyId, ProtocolRegistrationError, ProtocolSignerVerificationKey,
1313
ProtocolSingleSignature, ProtocolStakeDistribution,
1414
},
15-
entities::{self, Epoch, SignerWithStake},
15+
entities::{self, Epoch, SignerWithStake, StakeDistribution},
1616
store::{StakeStorer, StoreError},
1717
};
1818

@@ -492,7 +492,7 @@ impl MultiSigner for MultiSignerImpl {
492492
.ok_or_else(ProtocolError::UnavailableBeacon)?
493493
.epoch
494494
.offset_to_recording_epoch();
495-
let stakes = HashMap::from_iter(stakes.iter().cloned());
495+
let stakes = StakeDistribution::from_iter(stakes.iter().cloned());
496496
self.stake_store.save_stakes(epoch, stakes).await?;
497497

498498
Ok(())
@@ -742,10 +742,7 @@ mod tests {
742742
None,
743743
);
744744
let stake_store = StakeStore::new(
745-
Box::new(
746-
MemoryAdapter::<Epoch, HashMap<entities::PartyId, entities::Stake>>::new(None)
747-
.unwrap(),
748-
),
745+
Box::new(MemoryAdapter::<Epoch, StakeDistribution>::new(None).unwrap()),
749746
None,
750747
);
751748
let single_signature_store = SingleSignatureStore::new(

mithril-aggregator/src/runtime/runner.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,6 @@ pub mod tests {
755755
use mithril_common::test_utils::MithrilFixtureBuilder;
756756
use mithril_common::{entities::ProtocolMessagePartKey, test_utils::fake_data};
757757
use mithril_common::{BeaconProviderImpl, CardanoNetwork};
758-
use std::collections::HashMap;
759758
use std::path::Path;
760759
use std::sync::Arc;
761760
use tempfile::NamedTempFile;
@@ -876,7 +875,7 @@ pub mod tests {
876875
let beacon = fake_data::beacon();
877876
let recording_epoch = beacon.epoch.offset_to_recording_epoch();
878877
let stake_distribution: StakeDistribution =
879-
HashMap::from([("a".to_string(), 5), ("b".to_string(), 10)]);
878+
StakeDistribution::from([("a".to_string(), 5), ("b".to_string(), 10)]);
880879

881880
stake_store
882881
.save_stakes(recording_epoch, stake_distribution.clone())

0 commit comments

Comments
 (0)