Skip to content

Commit 48bd20f

Browse files
committed
Add Multiwatcher.
1 parent 18f2d2f commit 48bd20f

File tree

9 files changed

+301
-136
lines changed

9 files changed

+301
-136
lines changed

robusta/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ rust-version.workspace = true
88
[dependencies]
99
bincode = { workspace = true }
1010
bon = { workspace = true }
11+
bytes = { workspace = true }
1112
data-encoding = { workspace = true }
1213
either = { workspace = true }
1314
espresso-types = { workspace = true }
@@ -21,6 +22,7 @@ serde_json = { workspace = true }
2122
thiserror = { workspace = true }
2223
timeboost-types = { path = "../timeboost-types" }
2324
tokio = { workspace = true }
25+
tokio-stream = { workspace = true }
2426
tokio-tungstenite = { workspace = true }
2527
tracing = { workspace = true }
2628
url = { workspace = true }

robusta/src/config.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
use std::{iter::repeat, time::Duration};
1+
use std::{iter::repeat, num::NonZeroUsize, time::Duration};
22

33
use bon::Builder;
4-
use url::{ParseError, Url};
4+
use url::Url;
55

6-
const NUM_DELAYS: usize = 5;
6+
const NUM_DELAYS: NonZeroUsize = NonZeroUsize::new(5).expect("5 > 0");
77

88
#[derive(Debug, Clone, Builder)]
99
pub struct Config {
@@ -12,25 +12,29 @@ pub struct Config {
1212
pub(crate) label: String,
1313

1414
/// Espresso network base URL.
15-
#[builder(with = |s: &str| -> Result<_, ParseError> { Url::parse(s) })]
1615
pub(crate) base_url: Url,
1716

1817
/// Espresso network websocket base URL.
19-
#[builder(with = |s: &str| -> Result<_, ParseError> { Url::parse(s) })]
2018
pub(crate) wss_base_url: Url,
2119

2220
/// The sequence of delays between successive requests.
2321
///
2422
/// The last value is repeated forever.
2523
#[builder(default = [1, 3, 5, 10, 15])]
26-
pub(crate) delays: [u8; NUM_DELAYS],
24+
pub(crate) delays: [u8; NUM_DELAYS.get()],
2725
}
2826

2927
impl Config {
28+
pub fn with_websocket_base_url(&self, u: Url) -> Self {
29+
let mut c = self.clone();
30+
c.wss_base_url = u;
31+
c
32+
}
33+
3034
pub fn delay_iter(&self) -> impl Iterator<Item = Duration> + use<> {
3135
self.delays
3236
.into_iter()
33-
.chain(repeat(self.delays[NUM_DELAYS - 1]))
37+
.chain(repeat(self.delays[NUM_DELAYS.get() - 1]))
3438
.map(|n| Duration::from_secs(n.into()))
3539
}
3640
}

robusta/src/lib.rs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod config;
2+
mod multiwatcher;
23
mod types;
34
mod watcher;
45

@@ -18,8 +19,9 @@ use tracing::{debug, warn};
1819

1920
use crate::types::{TX, TaggedBase64, TransactionsWithProof, VidCommonResponse};
2021

22+
pub use crate::multiwatcher::Multiwatcher;
2123
pub use crate::types::Height;
22-
pub use crate::watcher::{WatchError, watch};
24+
pub use crate::watcher::{WatchError, Watcher};
2325
pub use config::{Config, ConfigBuilder};
2426
pub use espresso_types;
2527

@@ -276,10 +278,7 @@ fn deserialize<T: DeserializeOwned>(d: &[u8]) -> Result<T, Error> {
276278

277279
#[cfg(test)]
278280
mod tests {
279-
use futures::StreamExt;
280-
use tokio::pin;
281-
282-
use super::{Client, Config};
281+
use super::{Client, Config, Watcher};
283282

284283
#[tokio::test]
285284
async fn decaf_smoke() {
@@ -288,17 +287,23 @@ mod tests {
288287
.try_init();
289288

290289
let cfg = Config::builder()
291-
.base_url("https://query.decaf.testnet.espresso.network/v1/")
292-
.unwrap()
293-
.wss_base_url("wss://query.decaf.testnet.espresso.network/v1/")
294-
.unwrap()
290+
.base_url(
291+
"https://query.decaf.testnet.espresso.network/v1/"
292+
.parse()
293+
.unwrap(),
294+
)
295+
.wss_base_url(
296+
"wss://query.decaf.testnet.espresso.network/v1/"
297+
.parse()
298+
.unwrap(),
299+
)
295300
.label("decaf_smoke")
296301
.build();
297302

298303
let clt = Client::new(cfg.clone());
299304
let height = clt.height().await.unwrap();
300-
let headers = super::watch(&cfg, height, None).await.unwrap();
301-
pin!(headers);
302-
assert_eq!(u64::from(height), headers.next().await.unwrap().height());
305+
let mut watcher = Watcher::new(cfg, height, None);
306+
let header = watcher.next().await;
307+
assert_eq!(u64::from(height), header.height());
303308
}
304309
}

robusta/src/multiwatcher.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use std::collections::{BTreeMap, HashMap, HashSet};
2+
3+
use crate::{Config, Height, Watcher};
4+
use espresso_types::{Header, NamespaceId};
5+
use futures::{StreamExt, stream::SelectAll};
6+
use tokio::{spawn, sync::mpsc, task::JoinHandle};
7+
use tokio_stream::wrappers::ReceiverStream;
8+
use tracing::{debug, warn};
9+
use url::Url;
10+
11+
#[derive(Debug)]
12+
pub struct Multiwatcher {
13+
height: Height,
14+
threshold: usize,
15+
watchers: Vec<JoinHandle<()>>,
16+
headers: BTreeMap<Height, HashMap<Header, HashSet<Id>>>,
17+
stream: SelectAll<ReceiverStream<(Id, Header)>>,
18+
}
19+
20+
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
21+
struct Id(usize);
22+
23+
impl Drop for Multiwatcher {
24+
fn drop(&mut self) {
25+
for w in &self.watchers {
26+
w.abort();
27+
}
28+
}
29+
}
30+
31+
impl Multiwatcher {
32+
pub fn new<C, H, I, N>(configs: C, height: H, nsid: N, threshold: usize) -> Self
33+
where
34+
C: IntoIterator<Item = Config>,
35+
H: Into<Height>,
36+
I: IntoIterator<Item = Url>,
37+
N: Into<NamespaceId>,
38+
{
39+
let height = height.into();
40+
let nsid = nsid.into();
41+
let mut stream = SelectAll::new();
42+
let mut watchers = Vec::new();
43+
for (i, c) in configs.into_iter().enumerate() {
44+
let (tx, rx) = mpsc::channel(32);
45+
stream.push(ReceiverStream::new(rx));
46+
watchers.push(spawn(async move {
47+
let id = Id(i);
48+
let mut w = Watcher::new(c, height, nsid);
49+
while tx.send((id, w.next().await)).await.is_ok() {}
50+
}));
51+
}
52+
Self {
53+
height,
54+
threshold,
55+
stream,
56+
watchers,
57+
headers: BTreeMap::from_iter([(height, HashMap::new())]),
58+
}
59+
}
60+
61+
pub async fn next(&mut self) -> Option<Header> {
62+
loop {
63+
let (i, hdr) = self.stream.next().await?;
64+
let h = Height::from(hdr.height());
65+
if Some(h) < self.headers.first_entry().map(|e| *e.key()) {
66+
debug!(%h, "ignoring header below minimum height");
67+
continue;
68+
}
69+
if self.has_voted(h, i) {
70+
warn!(%h, "source sent multiple headers for same height");
71+
continue;
72+
}
73+
let votes = self.headers.entry(h).or_default();
74+
if let Some(ids) = votes.get(&hdr)
75+
&& ids.len() + 1 >= self.threshold
76+
{
77+
self.gc(h);
78+
return Some(hdr);
79+
}
80+
votes.entry(hdr).or_default().insert(i);
81+
}
82+
}
83+
84+
fn has_voted(&self, height: Height, id: Id) -> bool {
85+
let Some(m) = self.headers.get(&height) else {
86+
return false;
87+
};
88+
for v in m.values() {
89+
if v.contains(&id) {
90+
return true;
91+
}
92+
}
93+
false
94+
}
95+
96+
fn gc(&mut self, height: Height) {
97+
self.headers.retain(|h, _| *h >= height);
98+
self.height = height;
99+
}
100+
}

robusta/src/types.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ pub(crate) struct VidCommonResponse {
1919

2020
macro_rules! Primitive {
2121
($name:ident, $t:ty) => {
22-
#[derive(Debug, Copy, Clone, Deserialize, Serialize)]
22+
#[derive(
23+
Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize,
24+
)]
2325
#[serde(transparent)]
2426
pub struct $name($t);
2527

0 commit comments

Comments
 (0)