|
| 1 | +use std::{ |
| 2 | + collections::{BTreeMap, HashMap, HashSet}, |
| 3 | + sync::{ |
| 4 | + Arc, |
| 5 | + atomic::{AtomicU64, Ordering}, |
| 6 | + }, |
| 7 | +}; |
| 8 | + |
| 9 | +use crate::{Config, Height, Watcher}; |
| 10 | +use espresso_types::{Header, NamespaceId}; |
| 11 | +use futures::{StreamExt, stream::SelectAll}; |
| 12 | +use tokio::{ |
| 13 | + spawn, |
| 14 | + sync::{Barrier, mpsc}, |
| 15 | + task::JoinHandle, |
| 16 | +}; |
| 17 | +use tokio_stream::wrappers::ReceiverStream; |
| 18 | +use tracing::{debug, warn}; |
| 19 | + |
| 20 | +#[derive(Debug)] |
| 21 | +pub struct Multiwatcher { |
| 22 | + threshold: usize, |
| 23 | + lower_bound: Arc<AtomicU64>, |
| 24 | + watchers: Vec<JoinHandle<()>>, |
| 25 | + headers: BTreeMap<Height, HashMap<Header, HashSet<Id>>>, |
| 26 | + stream: SelectAll<ReceiverStream<(Id, Header)>>, |
| 27 | +} |
| 28 | + |
| 29 | +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] |
| 30 | +struct Id(usize); |
| 31 | + |
| 32 | +impl Drop for Multiwatcher { |
| 33 | + fn drop(&mut self) { |
| 34 | + for w in &self.watchers { |
| 35 | + w.abort(); |
| 36 | + } |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +impl Multiwatcher { |
| 41 | + pub fn new<C, H, N>(configs: C, height: H, nsid: N, threshold: usize) -> Self |
| 42 | + where |
| 43 | + C: IntoIterator<Item = Config>, |
| 44 | + H: Into<Height>, |
| 45 | + N: Into<NamespaceId>, |
| 46 | + { |
| 47 | + let height = height.into(); |
| 48 | + let nsid = nsid.into(); |
| 49 | + |
| 50 | + // We require `threshold` watchers to deliver the next header. |
| 51 | + // Adversaries may produce headers in quick succession, causing |
| 52 | + // excessive memory usage. |
| 53 | + let barrier = Arc::new(Barrier::new(threshold)); |
| 54 | + |
| 55 | + // We track the last delivered height as a lower bound. |
| 56 | + // Watchers skip over headers up to and including that height. |
| 57 | + let lower_bound = Arc::new(AtomicU64::new(height.into())); |
| 58 | + |
| 59 | + let mut stream = SelectAll::new(); |
| 60 | + let mut watchers = Vec::new(); |
| 61 | + |
| 62 | + for (i, c) in configs.into_iter().enumerate() { |
| 63 | + let (tx, rx) = mpsc::channel(10); |
| 64 | + stream.push(ReceiverStream::new(rx)); |
| 65 | + let barrier = barrier.clone(); |
| 66 | + let lower_bound = lower_bound.clone(); |
| 67 | + watchers.push(spawn(async move { |
| 68 | + let i = Id(i); |
| 69 | + let mut w = Watcher::new(c, height, nsid); |
| 70 | + loop { |
| 71 | + let h = w.next().await; |
| 72 | + if h.height() <= lower_bound.load(Ordering::Relaxed) { |
| 73 | + continue; |
| 74 | + } |
| 75 | + if tx.send((i, h)).await.is_err() { |
| 76 | + break; |
| 77 | + } |
| 78 | + barrier.wait().await; |
| 79 | + } |
| 80 | + })); |
| 81 | + } |
| 82 | + |
| 83 | + assert!(!watchers.is_empty()); |
| 84 | + |
| 85 | + Self { |
| 86 | + threshold, |
| 87 | + stream, |
| 88 | + watchers, |
| 89 | + lower_bound, |
| 90 | + headers: BTreeMap::from_iter([(height, HashMap::new())]), |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + pub async fn next(&mut self) -> Header { |
| 95 | + loop { |
| 96 | + let (i, hdr) = self.stream.next().await.expect("watchers never terminate"); |
| 97 | + let h = Height::from(hdr.height()); |
| 98 | + if Some(h) < self.headers.first_entry().map(|e| *e.key()) { |
| 99 | + debug!(height = %h, "ignoring header below minimum height"); |
| 100 | + continue; |
| 101 | + } |
| 102 | + if self.has_voted(h, i) { |
| 103 | + warn!(height = %h, "source sent multiple headers for same height"); |
| 104 | + continue; |
| 105 | + } |
| 106 | + let counter = self.headers.entry(h).or_default(); |
| 107 | + let votes = counter.get(&hdr).map(|ids| ids.len()).unwrap_or(0) + 1; |
| 108 | + if votes >= self.threshold { |
| 109 | + self.headers.retain(|k, _| *k > h); |
| 110 | + self.lower_bound.store(h.into(), Ordering::Relaxed); |
| 111 | + debug!(height = %h, "header available"); |
| 112 | + return hdr; |
| 113 | + } |
| 114 | + debug!(height = %h, %votes, "vote added"); |
| 115 | + counter.entry(hdr).or_default().insert(i); |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + fn has_voted(&self, height: Height, id: Id) -> bool { |
| 120 | + let Some(m) = self.headers.get(&height) else { |
| 121 | + return false; |
| 122 | + }; |
| 123 | + for ids in m.values() { |
| 124 | + if ids.contains(&id) { |
| 125 | + return true; |
| 126 | + } |
| 127 | + } |
| 128 | + false |
| 129 | + } |
| 130 | +} |
0 commit comments