Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 51 additions & 49 deletions protocols/mdns/src/behaviour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
io,
net::IpAddr,
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::Instant,
};

use futures::{channel::mpsc, Stream, StreamExt};
use if_watch::IfEvent;
use libp2p_core::{transport::PortUse, Endpoint, Multiaddr};
use iface::ListenAddressUpdate;
use libp2p_core::{multiaddr::Protocol, transport::PortUse, Endpoint, Multiaddr};
use libp2p_identity::PeerId;
use libp2p_swarm::{
behaviour::FromSwarm, dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour,
Expand All @@ -64,30 +64,22 @@
/// The IfWatcher type.
type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;

type TaskHandle: Abort;

/// Create a new instance of the `IfWatcher` type.
fn new_watcher() -> Result<Self::Watcher, std::io::Error>;

#[track_caller]
fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle;
}

#[allow(unreachable_pub)] // Not re-exported.
pub trait Abort {
fn abort(self);
fn spawn(task: impl Future<Output = ()> + Send + 'static);
}

/// The type of a [`Behaviour`] using the `async-io` implementation.
#[cfg(feature = "async-io")]
pub mod async_io {
use std::future::Future;

use async_std::task::JoinHandle;
use if_watch::smol::IfWatcher;

use super::Provider;
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer, Abort};
use crate::behaviour::{socket::asio::AsyncUdpSocket, timer::asio::AsyncTimer};

#[doc(hidden)]
pub enum AsyncIo {}
Expand All @@ -96,20 +88,13 @@
type Socket = AsyncUdpSocket;
type Timer = AsyncTimer;
type Watcher = IfWatcher;
type TaskHandle = JoinHandle<()>;

fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
IfWatcher::new()
}

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> JoinHandle<()> {
async_std::task::spawn(task)
}
}

impl Abort for JoinHandle<()> {
fn abort(self) {
async_std::task::spawn(self.cancel());
fn spawn(task: impl Future<Output = ()> + Send + 'static) {
async_std::task::spawn(task);
}
}

Expand All @@ -122,10 +107,9 @@
use std::future::Future;

use if_watch::tokio::IfWatcher;
use tokio::task::JoinHandle;

use super::Provider;
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer, Abort};
use crate::behaviour::{socket::tokio::TokioUdpSocket, timer::tokio::TokioTimer};

#[doc(hidden)]
pub enum Tokio {}
Expand All @@ -134,20 +118,13 @@
type Socket = TokioUdpSocket;
type Timer = TokioTimer;
type Watcher = IfWatcher;
type TaskHandle = JoinHandle<()>;

fn new_watcher() -> Result<Self::Watcher, std::io::Error> {
IfWatcher::new()
}

fn spawn(task: impl Future<Output = ()> + Send + 'static) -> Self::TaskHandle {
tokio::spawn(task)
}
}

impl Abort for JoinHandle<()> {
fn abort(self) {
JoinHandle::abort(&self)
fn spawn(task: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(task);
}
}

Expand All @@ -168,7 +145,7 @@
if_watch: P::Watcher,

/// Handles to tasks running the mDNS queries.
if_tasks: HashMap<IpAddr, P::TaskHandle>,
if_tasks: HashMap<IpAddr, mpsc::UnboundedSender<ListenAddressUpdate>>,

query_response_receiver: mpsc::Receiver<(PeerId, Multiaddr, Instant)>,
query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
Expand All @@ -186,10 +163,10 @@

/// The current set of listen addresses.
///
/// This is shared across all interface tasks using an [`RwLock`].

Check failure on line 166 in protocols/mdns/src/behaviour.rs

View workflow job for this annotation

GitHub Actions / Check rustdoc intra-doc links

unresolved link to `RwLock`
/// The [`Behaviour`] updates this upon new [`FromSwarm`]
/// events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
listen_addresses: Arc<RwLock<ListenAddresses>>,
listen_addresses: ListenAddresses,

local_peer_id: PeerId,

Expand Down Expand Up @@ -301,10 +278,20 @@
}

fn on_swarm_event(&mut self, event: FromSwarm) {
self.listen_addresses
.write()
.unwrap_or_else(|e| e.into_inner())
.on_swarm_event(&event);
if !self.listen_addresses.on_swarm_event(&event) {
return;
}
if let Some(update) = ListenAddressUpdate::from_swarm(event) {
// Send address update to matching interface task.
if let Some(ip) = update.ip_addr() {
if let Some(tx) = self.if_tasks.get_mut(&ip) {
if tx.unbounded_send(update).is_err() {
tracing::error!("`InterfaceState` for ip {ip} dropped");
self.if_tasks.remove(&ip);
}
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this sender should be bounded. I still have to look into how to solve this best, I am currently unsure how to handle the case that a interface-task is busy and the channel full.

}
}
}

#[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self, cx))]
Expand All @@ -322,25 +309,34 @@
while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
match event {
Ok(IfEvent::Up(inet)) => {
let addr = inet.addr();
if addr.is_loopback() {
let ip_addr = inet.addr();
if ip_addr.is_loopback() {
continue;
}
if addr.is_ipv4() && self.config.enable_ipv6
|| addr.is_ipv6() && !self.config.enable_ipv6
if ip_addr.is_ipv4() && self.config.enable_ipv6
|| ip_addr.is_ipv6() && !self.config.enable_ipv6
{
continue;
}
if let Entry::Vacant(e) = self.if_tasks.entry(addr) {
if let Entry::Vacant(e) = self.if_tasks.entry(ip_addr) {
let (addr_tx, addr_rx) = mpsc::unbounded();

Check failure on line 322 in protocols/mdns/src/behaviour.rs

View workflow job for this annotation

GitHub Actions / clippy (beta)

use of a disallowed method `futures::channel::mpsc::unbounded`

Check failure on line 322 in protocols/mdns/src/behaviour.rs

View workflow job for this annotation

GitHub Actions / clippy (1.83.0)

use of a disallowed method `futures::channel::mpsc::unbounded`
let listen_addresses = self
.listen_addresses
.iter()
.filter(|multiaddr| multiaddr_matches_ip(multiaddr, &ip_addr))
.cloned()
.collect();
match InterfaceState::<P::Socket, P::Timer>::new(
addr,
ip_addr,
self.config.clone(),
self.local_peer_id,
self.listen_addresses.clone(),
listen_addresses,
addr_rx,
self.query_response_sender.clone(),
) {
Ok(iface_state) => {
e.insert(P::spawn(iface_state));
P::spawn(iface_state);
e.insert(addr_tx);
}
Err(err) => {
tracing::error!("failed to create `InterfaceState`: {}", err)
Expand All @@ -349,10 +345,8 @@
}
}
Ok(IfEvent::Down(inet)) => {
if let Some(handle) = self.if_tasks.remove(&inet.addr()) {
if self.if_tasks.remove(&inet.addr()).is_some() {
tracing::info!(instance=%inet.addr(), "dropping instance");

handle.abort();
}
}
Err(err) => tracing::error!("if watch returned an error: {}", err),
Expand Down Expand Up @@ -422,6 +416,14 @@
}
}

fn multiaddr_matches_ip(addr: &Multiaddr, ip: &IpAddr) -> bool {
match addr.iter().next() {
Some(Protocol::Ip4(ipv4)) => &IpAddr::V4(ipv4) == ip,
Some(Protocol::Ip6(ipv6)) => &IpAddr::V6(ipv6) == ip,
_ => false,
}
}

/// Event that can be produced by the `Mdns` behaviour.
#[derive(Debug, Clone)]
pub enum Event {
Expand Down
74 changes: 60 additions & 14 deletions protocols/mdns/src/behaviour/iface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ use std::{
io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::{Duration, Instant},
};

use futures::{channel::mpsc, SinkExt, StreamExt};
use libp2p_core::Multiaddr;
use libp2p_core::{multiaddr::Protocol, Multiaddr};
use libp2p_identity::PeerId;
use libp2p_swarm::ListenAddresses;
use libp2p_swarm::{ExpiredListenAddr, FromSwarm, NewListenAddr};
use socket2::{Domain, Socket, Type};

use self::{
Expand Down Expand Up @@ -71,6 +70,38 @@ impl ProbeState {
}
}

/// Event to inform the [`InterfaceState`] of a change in listening addresses.
#[derive(Debug, Clone)]
pub(crate) enum ListenAddressUpdate {
New(Multiaddr),
Expired(Multiaddr),
}

impl ListenAddressUpdate {
pub(crate) fn from_swarm(event: FromSwarm) -> Option<Self> {
match event {
FromSwarm::NewListenAddr(NewListenAddr { addr, .. }) => {
Some(ListenAddressUpdate::New(addr.clone()))
}
FromSwarm::ExpiredListenAddr(ExpiredListenAddr { addr, .. }) => {
Some(ListenAddressUpdate::Expired(addr.clone()))
}
_ => None,
}
}

pub(crate) fn ip_addr(&self) -> Option<IpAddr> {
let addr = match self {
ListenAddressUpdate::New(a) | ListenAddressUpdate::Expired(a) => a,
};
match addr.iter().next()? {
Protocol::Ip4(a) => Some(IpAddr::V4(a)),
Protocol::Ip6(a) => Some(IpAddr::V6(a)),
_ => None,
}
}
}

/// An mDNS instance for a networking interface. To discover all peers when having multiple
/// interfaces an [`InterfaceState`] is required for each interface.
#[derive(Debug)]
Expand All @@ -81,8 +112,10 @@ pub(crate) struct InterfaceState<U, T> {
recv_socket: U,
/// Send socket.
send_socket: U,

listen_addresses: Arc<RwLock<ListenAddresses>>,
/// Current listening addresses.
listen_addresses: Vec<Multiaddr>,
/// Receiver for listening-address updates from the swarm.
listen_addresses_rx: mpsc::UnboundedReceiver<ListenAddressUpdate>,

query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,

Expand Down Expand Up @@ -119,7 +152,8 @@ where
addr: IpAddr,
config: Config,
local_peer_id: PeerId,
listen_addresses: Arc<RwLock<ListenAddresses>>,
listen_addresses: Vec<Multiaddr>,
listen_addresses_rx: mpsc::UnboundedReceiver<ListenAddressUpdate>,
query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
) -> io::Result<Self> {
tracing::info!(address=%addr, "creating instance on iface address");
Expand Down Expand Up @@ -175,6 +209,7 @@ where
recv_socket,
send_socket,
listen_addresses,
listen_addresses_rx,
query_response_sender,
recv_buffer: [0; 4096],
send_buffer: Default::default(),
Expand Down Expand Up @@ -210,7 +245,21 @@ where
let this = self.get_mut();

loop {
// 1st priority: Low latency: Create packet ASAP after timeout.
// 1st priority: Poll for a change in listen addresses.
match this.listen_addresses_rx.poll_next_unpin(cx) {
Poll::Ready(Some(ListenAddressUpdate::New(addr))) => {
this.listen_addresses.push(addr);
continue;
}
Poll::Ready(Some(ListenAddressUpdate::Expired(addr))) => {
this.listen_addresses.retain(|a| a != &addr);
continue;
}
Poll::Ready(None) => return Poll::Ready(()),
Poll::Pending => {}
}

// 2nd priority: Low latency: Create packet ASAP after timeout.
if this.timeout.poll_next_unpin(cx).is_ready() {
tracing::trace!(address=%this.addr, "sending query on iface");
this.send_buffer.push_back(build_query());
Expand All @@ -229,7 +278,7 @@ where
this.reset_timer();
}

// 2nd priority: Keep local buffers small: Send packets to remote.
// 3d priority: Keep local buffers small: Send packets to remote.
if let Some(packet) = this.send_buffer.pop_front() {
match this.send_socket.poll_write(cx, &packet, this.mdns_socket()) {
Poll::Ready(Ok(_)) => {
Expand All @@ -246,7 +295,7 @@ where
}
}

// 3rd priority: Keep local buffers small: Return discovered addresses.
// 4th priority: Keep local buffers small: Return discovered addresses.
if this.query_response_sender.poll_ready_unpin(cx).is_ready() {
if let Some(discovered) = this.discovered.pop_front() {
match this.query_response_sender.try_send(discovered) {
Expand All @@ -263,7 +312,7 @@ where
}
}

// 4th priority: Remote work: Answer incoming requests.
// 5th priority: Remote work: Answer incoming requests.
match this
.recv_socket
.poll_read(cx, &mut this.recv_buffer)
Expand All @@ -279,10 +328,7 @@ where
this.send_buffer.extend(build_query_response(
query.query_id(),
this.local_peer_id,
this.listen_addresses
.read()
.unwrap_or_else(|e| e.into_inner())
.iter(),
this.listen_addresses.iter(),
this.ttl,
));
continue;
Expand Down
Loading