diff --git a/src/lib.rs b/src/lib.rs index b24fd35..d0c7b8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -180,6 +180,15 @@ pub(crate) struct JoinPeers { #[derive(Debug, Serialize, Deserialize)] pub(crate) struct Shutdown; +#[derive(Debug, Serialize, Deserialize)] +pub struct Neighbors; + +#[derive(Debug, Serialize, Deserialize, Clone, Copy, Eq, PartialEq)] +pub enum NeighborEvent { + Up(NodeId), + Down(NodeId), +} + #[derive(Debug, Serialize, Deserialize)] #[rpc_requests(message = Message)] enum Proto { @@ -191,6 +200,8 @@ enum Proto { Subscribe(Subscribe), #[rpc(tx = oneshot::Sender<()>)] JoinPeers(JoinPeers), + #[rpc(tx = mpsc::Sender)] + Neighbors(Neighbors), #[rpc(tx = oneshot::Sender<()>)] Shutdown(Shutdown), } @@ -477,6 +488,33 @@ impl Client { self.0.rpc(peers) } + /// Watch neighbor up and down events. + /// + /// The list of neighbors at the time of calling will be emitted as [`NeighborEvent::Up`] right away. + pub fn neighbors( + &self, + ) -> impl n0_future::Stream> + Send + 'static { + self.0 + .server_streaming(Neighbors, 16) + .map_ok(|r| r.into_stream().err_into()) + .try_flatten_stream() + } + + /// Wait until at least one peer connection was established. + pub async fn joined(&self) -> irpc::Result<()> { + let stream = self.neighbors(); + tokio::pin!(stream); + loop { + match stream.next().await { + Some(Ok(NeighborEvent::Up(_))) => break Ok(()), + Some(Ok(NeighborEvent::Down(_))) => {} + Some(Err(err)) => break Err(err), + // TODO: proper error + None => break Err(irpc::channel::oneshot::RecvError::SenderClosed.into()), + } + } + } + pub async fn shutdown(&self) -> Result<(), irpc::Error> { let _ = self.0.rpc(Shutdown).await?; Ok(()) @@ -588,18 +626,21 @@ struct Actor { config: Config, state: State, broadcast_tx: broadcast::Sender, + neighbor_tx: broadcast::Sender, } impl Actor { fn new(topic: GossipTopic, rx: tokio::sync::mpsc::Receiver, config: Config) -> Self { let (sender, receiver) = topic.split(); let (broadcast_tx, _) = tokio::sync::broadcast::channel(32); + let (neighbor_tx, _) = tokio::sync::broadcast::channel(32); Self { state: State::new(), sender, receiver, rx, broadcast_tx, + neighbor_tx, config, } } @@ -725,6 +766,33 @@ impl Actor { } } + async fn handle_neighbors( + mut rx: broadcast::Receiver, + tx: mpsc::Sender, + initial: impl IntoIterator, + ) { + for peer in initial { + if tx.send(NeighborEvent::Up(peer)).await.is_err() { + return; + } + } + loop { + tokio::select! { + _ = tx.closed() => return, + ev = rx.recv() => match ev { + Ok(ev) => if let Err(_) = tx.send(ev).await { + return; + } + Err(broadcast::error::RecvError::Closed) => return, + Err(broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("Slow watch_peers subscriber (lagged by {n})"); + continue; + } + } + } + } + } + async fn run(mut self) { let mut tasks = FuturesUnordered::>::new(); let mut buf = bytes::BytesMut::with_capacity(4096); @@ -778,6 +846,11 @@ impl Actor { break; } } + Message::Neighbors(msg) => { + let rx = self.neighbor_tx.subscribe(); + let initial: Vec<_> = self.receiver.neighbors().collect(); + tasks.push(Box::pin(Self::handle_neighbors(rx, msg.tx, initial))); + } Message::Shutdown(msg) => { msg.tx.send(()).await.ok(); break; @@ -802,10 +875,12 @@ impl Actor { Event::NeighborUp(peer) => { trace!("New peer {}, starting fast anti-entropy", peer.fmt_short()); anti_entropy.set(Self::anti_entropy(self.state.snapshot(), self.sender.clone(), self.config.fast_anti_entropy_interval)); + self.neighbor_tx.send(NeighborEvent::Up(peer)).ok(); continue; }, Event::NeighborDown(peer) => { trace!("Peer down: {}, goodbye!", peer.fmt_short()); + self.neighbor_tx.send(NeighborEvent::Down(peer)).ok(); continue; }, e => { @@ -1213,3 +1288,48 @@ mod peg_parser { filter_parser::filter(input).map_err(|e| e.to_string()) } } + +#[cfg(test)] +mod tests { + + use iroh::protocol::Router; + use iroh_gossip::{net::Gossip, proto::TopicId}; + use n0_snafu::{Result, ResultExt}; + + use super::*; + + async fn spawn(bootstrap: Vec) -> Result<(Router, Client)> { + let topic_id = TopicId::from_bytes([0u8; 32]); + let ep = iroh::Endpoint::builder().discovery_n0().bind().await?; + ep.online().await; + let gossip = Gossip::builder().spawn(ep.clone()); + let router = Router::builder(ep) + .accept(iroh_gossip::ALPN, gossip.clone()) + .spawn(); + let topic = gossip.subscribe(topic_id, bootstrap).await?; + let kv = Client::local(topic, Default::default()); + Ok((router, kv)) + } + + #[tokio::test] + async fn test_watch_neighbors() -> Result<()> { + tracing_subscriber::fmt::init(); + let (r1, kv1) = spawn(vec![]).await?; + let (r2, kv2) = spawn(vec![r1.endpoint().node_id()]).await?; + + let s = kv1.neighbors(); + tokio::pin!(s); + let ev = s.next().await.unwrap().unwrap(); + assert_eq!(ev, NeighborEvent::Up(r2.endpoint().node_id())); + + let s = kv2.neighbors(); + tokio::pin!(s); + let ev = s.next().await.unwrap().unwrap(); + assert_eq!(ev, NeighborEvent::Up(r1.endpoint().node_id())); + + r1.shutdown().await.e()?; + r2.shutdown().await.e()?; + + Ok(()) + } +}