Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 89 additions & 17 deletions iroh/src/discovery/mdns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub struct MdnsDiscovery {
#[allow(dead_code)]
handle: AbortOnDropHandle<()>,
sender: mpsc::Sender<Message>,
advertise: bool,
/// When `local_addrs` changes, we re-publish our info.
local_addrs: Watchable<Option<NodeData>>,
}
Expand Down Expand Up @@ -127,39 +128,73 @@ impl Subscribers {

/// Builder for [`MdnsDiscovery`].
#[derive(Debug)]
pub struct MdnsDiscoveryBuilder;
pub struct MdnsDiscoveryBuilder {
advertise: bool,
}

impl MdnsDiscoveryBuilder {
/// Creates a new [`MdnsDiscoveryBuilder`] with default settings.
pub fn new() -> Self {
Self { advertise: true }
}

/// Sets whether this node should advertise its presence.
///
/// Default is true.
pub fn advertise(mut self, advertise: bool) -> Self {
self.advertise = advertise;
self
}

/// Builds an [`MdnsDiscovery`] instance with the configured settings.
pub fn build(self, node_id: NodeId) -> Result<MdnsDiscovery, IntoDiscoveryError> {
MdnsDiscovery::new(node_id, self.advertise)
}
}

impl Default for MdnsDiscoveryBuilder {
fn default() -> Self {
Self::new()
}
}

impl IntoDiscovery for MdnsDiscoveryBuilder {
fn into_discovery(
self,
context: &DiscoveryContext,
) -> Result<impl Discovery, IntoDiscoveryError> {
MdnsDiscovery::new(context.node_id())
self.build(context.node_id())
}
}

impl MdnsDiscovery {
/// Returns a [`MdnsDiscoveryBuilder`] that implements [`IntoDiscovery`].
pub fn builder() -> MdnsDiscoveryBuilder {
MdnsDiscoveryBuilder
MdnsDiscoveryBuilder::new()
}

/// Create a new [`MdnsDiscovery`] Service.
///
/// This starts a [`Discoverer`] that broadcasts your addresses and receives addresses from other nodes in your local network.
/// This starts a [`Discoverer`] that broadcasts your addresses (if advertise is set to true)
/// and receives addresses from other nodes in your local network.
///
/// # Errors
/// Returns an error if the network does not allow ipv4 OR ipv6.
///
/// # Panics
/// This relies on [`tokio::runtime::Handle::current`] and will panic if called outside of the context of a tokio runtime.
pub fn new(node_id: NodeId) -> Result<Self, IntoDiscoveryError> {
pub fn new(node_id: NodeId, advertise: bool) -> Result<Self, IntoDiscoveryError> {
debug!("Creating new MdnsDiscovery service");
let (send, mut recv) = mpsc::channel(64);
let task_sender = send.clone();
let rt = tokio::runtime::Handle::current();
let discovery =
MdnsDiscovery::spawn_discoverer(node_id, task_sender.clone(), BTreeSet::new(), &rt)?;
let discovery = MdnsDiscovery::spawn_discoverer(
node_id,
advertise,
task_sender.clone(),
BTreeSet::new(),
&rt,
)?;

let local_addrs: Watchable<Option<NodeData>> = Watchable::default();
let mut addrs_change = local_addrs.watch();
Expand Down Expand Up @@ -311,13 +346,15 @@ impl MdnsDiscovery {
let handle = task::spawn(discovery_fut.instrument(info_span!("swarm-discovery.actor")));
Ok(Self {
handle: AbortOnDropHandle::new(handle),
advertise,
sender: send,
local_addrs,
})
}

fn spawn_discoverer(
node_id: PublicKey,
advertise: bool,
sender: mpsc::Sender<Message>,
socketaddrs: BTreeSet<SocketAddr>,
rt: &tokio::runtime::Handle,
Expand All @@ -337,15 +374,17 @@ impl MdnsDiscovery {
sender.send(Message::Discovery(node_id, peer)).await.ok();
});
};
let addrs = MdnsDiscovery::socketaddrs_to_addrs(&socketaddrs);
let node_id_str = data_encoding::BASE32_NOPAD
.encode(node_id.as_bytes())
.to_ascii_lowercase();
let mut discoverer = Discoverer::new_interactive(N0_LOCAL_SWARM.to_string(), node_id_str)
.with_callback(callback)
.with_ip_class(IpClass::Auto);
for addr in addrs {
discoverer = discoverer.with_addrs(addr.0, addr.1);
if advertise {
let addrs = MdnsDiscovery::socketaddrs_to_addrs(&socketaddrs);
for addr in addrs {
discoverer = discoverer.with_addrs(addr.0, addr.1);
}
}
discoverer
.spawn(rt)
Expand Down Expand Up @@ -406,7 +445,9 @@ impl Discovery for MdnsDiscovery {
}

fn publish(&self, data: &NodeData) {
self.local_addrs.set(Some(data.clone())).ok();
if self.advertise {
self.local_addrs.set(Some(data.clone())).ok();
}
}

fn subscribe(&self) -> Option<BoxStream<DiscoveryItem>> {
Expand Down Expand Up @@ -440,8 +481,10 @@ mod tests {
#[tokio::test]
#[traced_test]
async fn mdns_publish_resolve() -> Result {
let (_, discovery_a) = make_discoverer()?;
let (node_id_b, discovery_b) = make_discoverer()?;
// Create discoverer A with advertise=false (only listens)
let (_, discovery_a) = make_discoverer(false)?;
// Create discoverer B with advertise=true (will broadcast)
let (node_id_b, discovery_b) = make_discoverer(true)?;

// make addr info for discoverer b
let user_data: UserData = "foobar".parse()?;
Expand Down Expand Up @@ -477,11 +520,11 @@ mod tests {
let mut node_ids = BTreeSet::new();
let mut discoverers = vec![];

let (_, discovery) = make_discoverer()?;
let (_, discovery) = make_discoverer(false)?;
let node_data = NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()]));

for i in 0..num_nodes {
let (node_id, discovery) = make_discoverer()?;
let (node_id, discovery) = make_discoverer(true)?;
let user_data: UserData = format!("node{i}").parse()?;
let node_data = node_data.clone().with_user_data(Some(user_data.clone()));
node_ids.insert((node_id, Some(user_data)));
Expand Down Expand Up @@ -513,9 +556,38 @@ mod tests {
.context("timeout")?
}

fn make_discoverer() -> Result<(PublicKey, MdnsDiscovery)> {
#[tokio::test]
#[traced_test]
async fn non_advertising_node_not_discovered() -> Result {
let (_, discovery_a) = make_discoverer(false)?;
let (node_id_b, discovery_b) = make_discoverer(false)?;

let (node_id_c, discovery_c) = make_discoverer(true)?;
let node_data_c =
NodeData::new(None, BTreeSet::from(["0.0.0.0:22222".parse().unwrap()]));
discovery_c.publish(&node_data_c);

let node_data_b =
NodeData::new(None, BTreeSet::from(["0.0.0.0:11111".parse().unwrap()]));
discovery_b.publish(&node_data_b);

let mut stream_c = discovery_a.resolve(node_id_c).unwrap();
let result_c = tokio::time::timeout(Duration::from_secs(2), stream_c.next()).await;
assert!(result_c.is_ok(), "Advertising node should be discoverable");

let mut stream_b = discovery_a.resolve(node_id_b).unwrap();
let result_b = tokio::time::timeout(Duration::from_secs(2), stream_b.next()).await;
assert!(
result_b.is_err(),
"Expected timeout since node b isn't advertising"
);

Ok(())
}

fn make_discoverer(advertise: bool) -> Result<(PublicKey, MdnsDiscovery)> {
let node_id = SecretKey::generate(rand::thread_rng()).public();
Ok((node_id, MdnsDiscovery::new(node_id)?))
Ok((node_id, MdnsDiscovery::new(node_id, advertise)?))
}
}
}
Loading