diff --git a/iroh/src/discovery/mdns.rs b/iroh/src/discovery/mdns.rs index 96c7607856..25c892db0e 100644 --- a/iroh/src/discovery/mdns.rs +++ b/iroh/src/discovery/mdns.rs @@ -73,6 +73,7 @@ pub struct MdnsDiscovery { #[allow(dead_code)] handle: AbortOnDropHandle<()>, sender: mpsc::Sender, + advertise: bool, /// When `local_addrs` changes, we re-publish our info. local_addrs: Watchable>, } @@ -127,39 +128,73 @@ impl Subscribers { /// Builder for [`MdnsDiscovery`]. #[derive(Debug)] -pub struct MdnsDiscoveryBuilder; +pub struct MdnsDiscoveryBuilder { + advertise: bool, +} + +impl MdnsDiscoveryBuilder { + /// Creates a new [`MdnsDiscoveryBuilder`] with default settings. + pub fn new() -> Self { + Self { advertise: true } + } + + /// Sets whether this node should advertise its presence. + /// + /// Default is true. + pub fn advertise(mut self, advertise: bool) -> Self { + self.advertise = advertise; + self + } + + /// Builds an [`MdnsDiscovery`] instance with the configured settings. + pub fn build(self, node_id: NodeId) -> Result { + MdnsDiscovery::new(node_id, self.advertise) + } +} + +impl Default for MdnsDiscoveryBuilder { + fn default() -> Self { + Self::new() + } +} impl IntoDiscovery for MdnsDiscoveryBuilder { fn into_discovery( self, context: &DiscoveryContext, ) -> Result { - MdnsDiscovery::new(context.node_id()) + self.build(context.node_id()) } } impl MdnsDiscovery { /// Returns a [`MdnsDiscoveryBuilder`] that implements [`IntoDiscovery`]. pub fn builder() -> MdnsDiscoveryBuilder { - MdnsDiscoveryBuilder + MdnsDiscoveryBuilder::new() } /// Create a new [`MdnsDiscovery`] Service. /// - /// This starts a [`Discoverer`] that broadcasts your addresses and receives addresses from other nodes in your local network. + /// This starts a [`Discoverer`] that broadcasts your addresses (if advertise is set to true) + /// and receives addresses from other nodes in your local network. /// /// # Errors /// Returns an error if the network does not allow ipv4 OR ipv6. /// /// # Panics /// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime. - pub fn new(node_id: NodeId) -> Result { + pub fn new(node_id: NodeId, advertise: bool) -> Result { debug!("Creating new MdnsDiscovery service"); let (send, mut recv) = mpsc::channel(64); let task_sender = send.clone(); let rt = tokio::runtime::Handle::current(); - let discovery = - MdnsDiscovery::spawn_discoverer(node_id, task_sender.clone(), BTreeSet::new(), &rt)?; + let discovery = MdnsDiscovery::spawn_discoverer( + node_id, + advertise, + task_sender.clone(), + BTreeSet::new(), + &rt, + )?; let local_addrs: Watchable> = Watchable::default(); let mut addrs_change = local_addrs.watch(); @@ -311,6 +346,7 @@ impl MdnsDiscovery { let handle = task::spawn(discovery_fut.instrument(info_span!("swarm-discovery.actor"))); Ok(Self { handle: AbortOnDropHandle::new(handle), + advertise, sender: send, local_addrs, }) @@ -318,6 +354,7 @@ impl MdnsDiscovery { fn spawn_discoverer( node_id: PublicKey, + advertise: bool, sender: mpsc::Sender, socketaddrs: BTreeSet, rt: &tokio::runtime::Handle, @@ -337,15 +374,17 @@ impl MdnsDiscovery { sender.send(Message::Discovery(node_id, peer)).await.ok(); }); }; - let addrs = MdnsDiscovery::socketaddrs_to_addrs(&socketaddrs); let node_id_str = data_encoding::BASE32_NOPAD .encode(node_id.as_bytes()) .to_ascii_lowercase(); let mut discoverer = Discoverer::new_interactive(N0_LOCAL_SWARM.to_string(), node_id_str) .with_callback(callback) .with_ip_class(IpClass::Auto); - for addr in addrs { - discoverer = discoverer.with_addrs(addr.0, addr.1); + if advertise { + let addrs = MdnsDiscovery::socketaddrs_to_addrs(&socketaddrs); + for addr in addrs { + discoverer = discoverer.with_addrs(addr.0, addr.1); + } } discoverer .spawn(rt) @@ -406,7 +445,9 @@ impl Discovery for MdnsDiscovery { } fn publish(&self, data: &NodeData) { - self.local_addrs.set(Some(data.clone())).ok(); + if self.advertise { + self.local_addrs.set(Some(data.clone())).ok(); + } } fn subscribe(&self) -> Option> { @@ -440,8 +481,10 @@ mod tests { #[tokio::test] #[traced_test] async fn mdns_publish_resolve() -> Result { - let (_, discovery_a) = make_discoverer()?; - let (node_id_b, discovery_b) = make_discoverer()?; + // Create discoverer A with advertise=false (only listens) + let (_, discovery_a) = make_discoverer(false)?; + // Create discoverer B with advertise=true (will broadcast) + let (node_id_b, discovery_b) = make_discoverer(true)?; // make addr info for discoverer b let user_data: UserData = "foobar".parse()?; @@ -477,11 +520,11 @@ mod tests { let mut node_ids = BTreeSet::new(); let mut discoverers = vec![]; - let (_, discovery) = make_discoverer()?; + let (_, discovery) = make_discoverer(false)?; let node_data = NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()])); for i in 0..num_nodes { - let (node_id, discovery) = make_discoverer()?; + let (node_id, discovery) = make_discoverer(true)?; let user_data: UserData = format!("node{i}").parse()?; let node_data = node_data.clone().with_user_data(Some(user_data.clone())); node_ids.insert((node_id, Some(user_data))); @@ -513,9 +556,38 @@ mod tests { .context("timeout")? } - fn make_discoverer() -> Result<(PublicKey, MdnsDiscovery)> { + #[tokio::test] + #[traced_test] + async fn non_advertising_node_not_discovered() -> Result { + let (_, discovery_a) = make_discoverer(false)?; + let (node_id_b, discovery_b) = make_discoverer(false)?; + + let (node_id_c, discovery_c) = make_discoverer(true)?; + let node_data_c = + NodeData::new(None, BTreeSet::from(["0.0.0.0:22222".parse().unwrap()])); + discovery_c.publish(&node_data_c); + + let node_data_b = + NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()])); + discovery_b.publish(&node_data_b); + + let mut stream_c = discovery_a.resolve(node_id_c).unwrap(); + let result_c = tokio::time::timeout(Duration::from_secs(2), stream_c.next()).await; + assert!(result_c.is_ok(), "Advertising node should be discoverable"); + + let mut stream_b = discovery_a.resolve(node_id_b).unwrap(); + let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await; + assert!( + result_b.is_err(), + "Expected timeout since node b isn't advertising" + ); + + Ok(()) + } + + fn make_discoverer(advertise: bool) -> Result<(PublicKey, MdnsDiscovery)> { let node_id = SecretKey::generate(rand::thread_rng()).public(); - Ok((node_id, MdnsDiscovery::new(node_id)?)) + Ok((node_id, MdnsDiscovery::new(node_id, advertise)?)) } } }