Skip to content

Commit dc147fe

Browse files
committed
feat(aggregator-discovery): add 'ShuffleAggregatorDiscoverer' decorator
1 parent aa0a62f commit dc147fe

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/mithril-aggregator-discovery/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@ include = ["**/*.rs", "Cargo.toml", "README.md", ".gitignore"]
1313
[lib]
1414
crate-type = ["lib", "cdylib", "staticlib"]
1515

16+
[features]
17+
rand = ["dep:rand"]
18+
1619
[dependencies]
1720
anyhow = { workspace = true }
1821
async-trait = { workspace = true }
1922
mithril-common = { path = "../../mithril-common" }
23+
rand = { version = "0.9.2", optional = true}
2024
reqwest = { workspace = true, features = [
2125
"default",
2226
"gzip",

internal/mithril-aggregator-discovery/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,12 @@
44
mod http_config_discoverer;
55
mod interface;
66
mod model;
7+
#[cfg(feature = "rand")]
8+
mod rand_discoverer;
79
pub mod test;
810

911
pub use http_config_discoverer::HttpConfigAggregatorDiscoverer;
1012
pub use interface::AggregatorDiscoverer;
1113
pub use model::{AggregatorEndpoint, MithrilNetwork};
14+
#[cfg(feature = "rand")]
15+
pub use rand_discoverer::ShuffleAggregatorDiscoverer;
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
use std::sync::Arc;
2+
3+
use rand::{Rng, seq::SliceRandom};
4+
use tokio::sync::Mutex;
5+
6+
use mithril_common::StdResult;
7+
8+
use crate::{AggregatorDiscoverer, AggregatorEndpoint, MithrilNetwork};
9+
10+
/// A discoverer that returns a random set of aggregators
11+
pub struct ShuffleAggregatorDiscoverer<R: Rng + Send + Sized> {
12+
random_generator: Arc<Mutex<R>>,
13+
inner_discoverer: Arc<dyn AggregatorDiscoverer>,
14+
}
15+
16+
impl<R: Rng + Send + Sized> ShuffleAggregatorDiscoverer<R> {
17+
/// Creates a new `ShuffleAggregatorDiscoverer` instance with the provided inner discoverer.
18+
pub fn new(inner_discoverer: Arc<dyn AggregatorDiscoverer>, random_generator: R) -> Self {
19+
Self {
20+
inner_discoverer,
21+
random_generator: Arc::new(Mutex::new(random_generator)),
22+
}
23+
}
24+
}
25+
26+
#[async_trait::async_trait]
27+
impl<R: Rng + Send + Sized> AggregatorDiscoverer for ShuffleAggregatorDiscoverer<R> {
28+
async fn get_available_aggregators(
29+
&self,
30+
network: MithrilNetwork,
31+
) -> StdResult<Box<dyn Iterator<Item = AggregatorEndpoint>>> {
32+
let mut aggregators: Vec<AggregatorEndpoint> = self
33+
.inner_discoverer
34+
.get_available_aggregators(network)
35+
.await?
36+
.collect();
37+
let mut rng = self.random_generator.lock().await;
38+
aggregators.shuffle(&mut *rng);
39+
40+
Ok(Box::new(aggregators.into_iter()))
41+
}
42+
}
43+
44+
#[cfg(test)]
45+
mod tests {
46+
use rand::{SeedableRng, rngs::StdRng};
47+
48+
use crate::test::double::AggregatorDiscovererFake;
49+
50+
use super::*;
51+
52+
#[tokio::test]
53+
async fn shuffle_aggregator_discoverer() {
54+
let inner_discoverer = AggregatorDiscovererFake::new(vec![Ok(vec![
55+
AggregatorEndpoint::new("https://release-devnet-aggregator1".to_string()),
56+
AggregatorEndpoint::new("https://release-devnet-aggregator2".to_string()),
57+
AggregatorEndpoint::new("https://release-devnet-aggregator3".to_string()),
58+
])]);
59+
let seed = [0u8; 32];
60+
let rng = StdRng::from_seed(seed);
61+
let discoverer = ShuffleAggregatorDiscoverer::new(Arc::new(inner_discoverer), rng);
62+
63+
let aggregators = discoverer
64+
.get_available_aggregators(MithrilNetwork::new("release-devnet".into()))
65+
.await
66+
.unwrap();
67+
68+
assert_eq!(
69+
vec![
70+
AggregatorEndpoint::new("https://release-devnet-aggregator3".into()),
71+
AggregatorEndpoint::new("https://release-devnet-aggregator2".into()),
72+
AggregatorEndpoint::new("https://release-devnet-aggregator1".into()),
73+
],
74+
aggregators.collect::<Vec<_>>()
75+
);
76+
}
77+
}

0 commit comments

Comments
 (0)