|
| 1 | +use std::collections::{BTreeMap, HashMap, HashSet}; |
| 2 | + |
| 3 | +use crate::{Config, Height, Watcher}; |
| 4 | +use espresso_types::{Header, NamespaceId}; |
| 5 | +use futures::{StreamExt, stream::SelectAll}; |
| 6 | +use tokio::{spawn, sync::mpsc, task::JoinHandle}; |
| 7 | +use tokio_stream::wrappers::ReceiverStream; |
| 8 | +use tracing::{debug, warn}; |
| 9 | +use url::Url; |
| 10 | + |
| 11 | +#[derive(Debug)] |
| 12 | +pub struct Multiwatcher { |
| 13 | + height: Height, |
| 14 | + threshold: usize, |
| 15 | + watchers: Vec<JoinHandle<()>>, |
| 16 | + headers: BTreeMap<Height, HashMap<Header, HashSet<Id>>>, |
| 17 | + stream: SelectAll<ReceiverStream<(Id, Header)>>, |
| 18 | +} |
| 19 | + |
| 20 | +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] |
| 21 | +struct Id(usize); |
| 22 | + |
| 23 | +impl Drop for Multiwatcher { |
| 24 | + fn drop(&mut self) { |
| 25 | + for w in &self.watchers { |
| 26 | + w.abort(); |
| 27 | + } |
| 28 | + } |
| 29 | +} |
| 30 | + |
| 31 | +impl Multiwatcher { |
| 32 | + pub fn new<C, H, I, N>(configs: C, height: H, nsid: N, threshold: usize) -> Self |
| 33 | + where |
| 34 | + C: IntoIterator<Item = Config>, |
| 35 | + H: Into<Height>, |
| 36 | + I: IntoIterator<Item = Url>, |
| 37 | + N: Into<NamespaceId>, |
| 38 | + { |
| 39 | + let height = height.into(); |
| 40 | + let nsid = nsid.into(); |
| 41 | + let mut stream = SelectAll::new(); |
| 42 | + let mut watchers = Vec::new(); |
| 43 | + for (i, c) in configs.into_iter().enumerate() { |
| 44 | + let (tx, rx) = mpsc::channel(32); |
| 45 | + stream.push(ReceiverStream::new(rx)); |
| 46 | + watchers.push(spawn(async move { |
| 47 | + let id = Id(i); |
| 48 | + let mut w = Watcher::new(c, height, nsid); |
| 49 | + while tx.send((id, w.next().await)).await.is_ok() {} |
| 50 | + })); |
| 51 | + } |
| 52 | + Self { |
| 53 | + height, |
| 54 | + threshold, |
| 55 | + stream, |
| 56 | + watchers, |
| 57 | + headers: BTreeMap::from_iter([(height, HashMap::new())]), |
| 58 | + } |
| 59 | + } |
| 60 | + |
| 61 | + pub async fn next(&mut self) -> Option<Header> { |
| 62 | + loop { |
| 63 | + let (i, hdr) = self.stream.next().await?; |
| 64 | + let h = Height::from(hdr.height()); |
| 65 | + if Some(h) < self.headers.first_entry().map(|e| *e.key()) { |
| 66 | + debug!(%h, "ignoring header below minimum height"); |
| 67 | + continue; |
| 68 | + } |
| 69 | + if self.has_voted(h, i) { |
| 70 | + warn!(%h, "source sent multiple headers for same height"); |
| 71 | + continue; |
| 72 | + } |
| 73 | + let votes = self.headers.entry(h).or_default(); |
| 74 | + if let Some(ids) = votes.get(&hdr) |
| 75 | + && ids.len() + 1 >= self.threshold |
| 76 | + { |
| 77 | + self.gc(h); |
| 78 | + return Some(hdr); |
| 79 | + } |
| 80 | + votes.entry(hdr).or_default().insert(i); |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + fn has_voted(&self, height: Height, id: Id) -> bool { |
| 85 | + let Some(m) = self.headers.get(&height) else { |
| 86 | + return false; |
| 87 | + }; |
| 88 | + for v in m.values() { |
| 89 | + if v.contains(&id) { |
| 90 | + return true; |
| 91 | + } |
| 92 | + } |
| 93 | + false |
| 94 | + } |
| 95 | + |
| 96 | + fn gc(&mut self, height: Height) { |
| 97 | + self.headers.retain(|h, _| *h >= height); |
| 98 | + self.height = height; |
| 99 | + } |
| 100 | +} |
0 commit comments