diff --git a/DEFAULT_CONFIG.json5 b/DEFAULT_CONFIG.json5 index c080ad64d7..9814b2c74b 100644 --- a/DEFAULT_CONFIG.json5 +++ b/DEFAULT_CONFIG.json5 @@ -113,6 +113,9 @@ /// increase factor for the next timeout until next try period_increase_factor: 2, }, + /// Interval in millisecond to check if the listening endpoints changed (i.e. when listening on 0.0.0.0 or [::]). + /// Also update the multicast scouting listening interfaces. Use -1 to disable. + endpoint_poll_interval_ms: 10000, }, /// Configure the session open behavior. diff --git a/commons/zenoh-config/src/defaults.rs b/commons/zenoh-config/src/defaults.rs index 1a3b74ab39..c933a791d0 100644 --- a/commons/zenoh-config/src/defaults.rs +++ b/commons/zenoh-config/src/defaults.rs @@ -55,6 +55,8 @@ pub mod listen { pub const timeout_ms: ModeDependentValue = ModeDependentValue::Unique(0); pub const exit_on_failure: ModeDependentValue = ModeDependentValue::Unique(true); + + pub const endpoint_poll_interval_ms: Option = Some(10_000); } #[allow(non_upper_case_globals)] @@ -180,6 +182,7 @@ impl Default for ListenConfig { peer: Some(vec![]), client: None, }), + endpoint_poll_interval_ms: Some(10_000), exit_on_failure: None, retry: None, } diff --git a/commons/zenoh-config/src/lib.rs b/commons/zenoh-config/src/lib.rs index 19636f7608..062f6ef708 100644 --- a/commons/zenoh-config/src/lib.rs +++ b/commons/zenoh-config/src/lib.rs @@ -488,6 +488,9 @@ validated_struct::validator! { /// if connection timeout exceed, exit from application pub exit_on_failure: Option>, pub retry: Option, + /// Interval in millisecond to check if the listening endpoints changed (e.g. when listening on 0.0.0.0). + /// Also update the multicast scouting listening interfaces. Use -1 to disable. + pub endpoint_poll_interval_ms: Option, }, /// Configure the session open behavior. pub open: #[derive(Default)] diff --git a/commons/zenoh-util/src/net/mod.rs b/commons/zenoh-util/src/net/mod.rs index 777dc886c6..ff94f2ff79 100644 --- a/commons/zenoh-util/src/net/mod.rs +++ b/commons/zenoh-util/src/net/mod.rs @@ -12,6 +12,8 @@ // ZettaScale Zenoh Team, // use std::net::{IpAddr, Ipv6Addr}; +#[cfg(unix)] +use std::sync::RwLock; #[cfg(unix)] use lazy_static::lazy_static; @@ -20,6 +22,8 @@ use pnet_datalink::NetworkInterface; use tokio::net::{TcpSocket, UdpSocket}; use zenoh_core::zconfigurable; #[cfg(unix)] +use zenoh_core::{zread, zwrite}; +#[cfg(unix)] use zenoh_result::zerror; use zenoh_result::{bail, ZResult}; @@ -30,7 +34,7 @@ zconfigurable! { #[cfg(unix)] lazy_static! { - static ref IFACES: Vec = pnet_datalink::interfaces(); + static ref IFACES: RwLock> = RwLock::new(pnet_datalink::interfaces()); } #[cfg(windows)] @@ -68,7 +72,7 @@ unsafe fn get_adapters_addresses(af_spec: i32) -> ZResult> { pub fn get_interface(name: &str) -> ZResult> { #[cfg(unix)] { - for iface in IFACES.iter() { + for iface in zread!(IFACES).iter() { if iface.name == name { for ifaddr in &iface.ips { if ifaddr.is_ipv4() { @@ -131,7 +135,7 @@ pub fn get_interface(name: &str) -> ZResult> { pub fn get_multicast_interfaces() -> Vec { #[cfg(unix)] { - IFACES + zread!(IFACES) .iter() .filter_map(|iface| { if iface.is_up() && iface.is_running() && iface.is_multicast() { @@ -155,7 +159,7 @@ pub fn get_multicast_interfaces() -> Vec { pub fn get_local_addresses(interface: Option<&str>) -> ZResult> { #[cfg(unix)] { - Ok(IFACES + Ok(zread!(IFACES) .iter() .filter(|iface| { if let Some(interface) = interface.as_ref() { @@ -205,7 +209,7 @@ pub fn get_local_addresses(interface: Option<&str>) -> ZResult> { pub fn get_unicast_addresses_of_multicast_interfaces() -> Vec { #[cfg(unix)] { - IFACES + zread!(IFACES) .iter() .filter(|iface| iface.is_up() && iface.is_running() && iface.is_multicast()) .flat_map(|iface| { @@ -228,7 +232,7 @@ pub fn get_unicast_addresses_of_multicast_interfaces() -> Vec { pub fn get_unicast_addresses_of_interface(name: &str) -> ZResult> { #[cfg(unix)] { - match IFACES.iter().find(|iface| iface.name == name) { + match zread!(IFACES).iter().find(|iface| iface.name == name) { Some(iface) => { if !iface.is_up() { bail!("Interface {name} is not up"); @@ -282,7 +286,7 @@ pub fn get_unicast_addresses_of_interface(name: &str) -> ZResult> { pub fn get_index_of_interface(addr: IpAddr) -> ZResult { #[cfg(unix)] { - IFACES + zread!(IFACES) .iter() .find(|iface| iface.ips.iter().any(|ipnet| ipnet.ip() == addr)) .map(|iface| iface.index) @@ -319,12 +323,12 @@ pub fn get_interface_names_by_addr(addr: IpAddr) -> ZResult> { #[cfg(unix)] { if addr.is_unspecified() { - Ok(IFACES + Ok(zread!(IFACES) .iter() .map(|iface| iface.name.clone()) .collect::>()) } else { - Ok(IFACES + Ok(zread!(IFACES) .iter() .filter(|iface| iface.ips.iter().any(|ipnet| ipnet.ip() == addr)) .map(|iface| iface.name.clone()) @@ -435,6 +439,12 @@ pub fn get_ipv6_ipaddrs(interface: Option<&str>) -> Vec { .collect() } +#[cfg(unix)] +pub fn update_iface_cache() { + let mut interfaces = zwrite!(IFACES); + *interfaces = pnet_datalink::interfaces(); +} + #[cfg(any(target_os = "linux", target_os = "android"))] pub fn set_bind_to_device_tcp_socket(socket: &TcpSocket, iface: &str) -> ZResult<()> { socket.bind_device(Some(iface.as_bytes()))?; diff --git a/zenoh/src/api/scouting.rs b/zenoh/src/api/scouting.rs index f8745f995a..a0a5ae7629 100644 --- a/zenoh/src/api/scouting.rs +++ b/zenoh/src/api/scouting.rs @@ -24,7 +24,7 @@ use crate::{ builders::scouting::ScoutBuilder, handlers::{Callback, CallbackParameter, DefaultHandler}, }, - net::runtime::{orchestrator::Loop, Runtime}, + net::runtime::{orchestrator::Loop, Runtime, Scouting}, Config, }; @@ -181,7 +181,7 @@ pub(crate) fn _scout( let task = TerminatableTask::spawn( zenoh_runtime::ZRuntime::Acceptor, async move { - let scout = Runtime::scout(&sockets, what, &addr, move |hello| { + let scout = Scouting::scout(&sockets, what, &addr, move |hello| { let callback = callback.clone(); async move { callback.call(hello.into()); diff --git a/zenoh/src/net/protocol/network.rs b/zenoh/src/net/protocol/network.rs index 3a4acfa3b0..1fce3679bc 100644 --- a/zenoh/src/net/protocol/network.rs +++ b/zenoh/src/net/protocol/network.rs @@ -987,6 +987,21 @@ impl Network { } } + pub(crate) fn update_locators(&mut self) { + self.graph[self.idx].sn += 1; + self.send_on_links( + vec![( + self.idx, + Details { + zid: false, + locators: true, + links: self.full_linkstate || self.router_peers_failover_brokering, + }, + )], + |link| link.transport.get_whatami().unwrap_or(WhatAmI::Peer) == WhatAmI::Router, + ); + } + fn remove_detached_nodes(&mut self) -> Vec<(NodeIndex, Node)> { let mut dfs_stack = vec![self.idx]; let mut visit_map = self.graph.visit_map(); diff --git a/zenoh/src/net/routing/hat/linkstate_peer/mod.rs b/zenoh/src/net/routing/hat/linkstate_peer/mod.rs index 7e5eb2a952..c50c45bc04 100644 --- a/zenoh/src/net/routing/hat/linkstate_peer/mod.rs +++ b/zenoh/src/net/routing/hat/linkstate_peer/mod.rs @@ -277,6 +277,12 @@ impl HatBaseTrait for HatCode { Ok(()) } + fn update_self_locators(&self, tables: &mut Tables) { + if let Some(net) = hat_mut!(tables).linkstatepeers_net.as_mut() { + net.update_locators(); + } + } + fn close_face( &self, tables: &TablesLock, diff --git a/zenoh/src/net/routing/hat/mod.rs b/zenoh/src/net/routing/hat/mod.rs index 38e30607c7..ae915e05e6 100644 --- a/zenoh/src/net/routing/hat/mod.rs +++ b/zenoh/src/net/routing/hat/mod.rs @@ -106,6 +106,8 @@ pub(crate) trait HatBaseTrait { send_declare: &mut SendDeclare, ) -> ZResult<()>; + fn update_self_locators(&self, _tables: &mut Tables) {} + fn handle_oam( &self, tables: &mut Tables, diff --git a/zenoh/src/net/routing/hat/p2p_peer/gossip.rs b/zenoh/src/net/routing/hat/p2p_peer/gossip.rs index 538e49c785..986fc9b562 100644 --- a/zenoh/src/net/routing/hat/p2p_peer/gossip.rs +++ b/zenoh/src/net/routing/hat/p2p_peer/gossip.rs @@ -605,4 +605,19 @@ impl Network { } vec![] } + + pub(super) fn update_locators(&mut self) { + self.graph[self.idx].sn += 1; + self.send_on_links( + vec![( + self.idx, + Details { + zid: false, + locators: true, + links: self.router_peers_failover_brokering, + }, + )], + |link| link.transport.get_whatami().unwrap_or(WhatAmI::Peer) == WhatAmI::Router, + ); + } } diff --git a/zenoh/src/net/routing/hat/p2p_peer/mod.rs b/zenoh/src/net/routing/hat/p2p_peer/mod.rs index 282917b244..248f4f3c78 100644 --- a/zenoh/src/net/routing/hat/p2p_peer/mod.rs +++ b/zenoh/src/net/routing/hat/p2p_peer/mod.rs @@ -216,6 +216,12 @@ impl HatBaseTrait for HatCode { Ok(()) } + fn update_self_locators(&self, tables: &mut Tables) { + if let Some(net) = hat_mut!(tables).gossip.as_mut() { + net.update_locators(); + } + } + fn close_face( &self, tables: &TablesLock, diff --git a/zenoh/src/net/routing/hat/router/mod.rs b/zenoh/src/net/routing/hat/router/mod.rs index 9b6916d967..fb0ab9b48e 100644 --- a/zenoh/src/net/routing/hat/router/mod.rs +++ b/zenoh/src/net/routing/hat/router/mod.rs @@ -460,6 +460,15 @@ impl HatBaseTrait for HatCode { Ok(()) } + fn update_self_locators(&self, tables: &mut Tables) { + if let Some(net) = hat_mut!(tables).routers_net.as_mut() { + net.update_locators(); + } + if let Some(net) = hat_mut!(tables).linkstatepeers_net.as_mut() { + net.update_locators(); + } + } + fn close_face( &self, tables: &TablesLock, diff --git a/zenoh/src/net/runtime/mod.rs b/zenoh/src/net/runtime/mod.rs index 7656234d76..998fa1f736 100644 --- a/zenoh/src/net/runtime/mod.rs +++ b/zenoh/src/net/runtime/mod.rs @@ -19,6 +19,7 @@ //! [Click here for Zenoh's documentation](https://docs.rs/zenoh/latest/zenoh) mod adminspace; pub mod orchestrator; +mod scouting; #[cfg(feature = "plugins")] use std::sync::{Mutex, MutexGuard}; @@ -29,11 +30,13 @@ use std::{ atomic::{AtomicU32, Ordering}, Arc, Weak, }, + time::Duration, }; pub use adminspace::AdminSpace; use async_trait::async_trait; use futures::{stream::StreamExt, Future}; +pub use scouting::Scouting; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use uhlc::{HLCBuilder, HLC}; @@ -55,6 +58,8 @@ use zenoh_transport::{ multicast::TransportMulticast, unicast::TransportUnicast, TransportEventHandler, TransportManager, TransportMulticastEventHandler, TransportPeer, TransportPeerEventHandler, }; +#[cfg(unix)] +use zenoh_util::net::update_iface_cache; use self::orchestrator::StartConditions; use super::{primitives::DeMux, routing, routing::router::Router}; @@ -87,6 +92,7 @@ pub(crate) struct RuntimeState { plugins_manager: Mutex, start_conditions: Arc, pending_connections: tokio::sync::Mutex>, + scouting: tokio::sync::Mutex>, } pub struct WeakRuntime { @@ -175,6 +181,7 @@ impl RuntimeBuilder { // SHM lazy init flag #[cfg(feature = "shared-memory")] let shm_init_mode = *config.transport.shared_memory.mode(); + let endpoint_poll_interval = config.listen.endpoint_poll_interval_ms().unwrap_or(10_000); let config = Notifier::new(crate::config::Config(config)); let runtime = Runtime { @@ -193,6 +200,7 @@ impl RuntimeBuilder { plugins_manager: Mutex::new(plugins_manager), start_conditions: Arc::new(StartConditions::default()), pending_connections: tokio::sync::Mutex::new(HashSet::new()), + scouting: tokio::sync::Mutex::new(None), }), }; *handler.runtime.write().unwrap() = Runtime::downgrade(&runtime); @@ -240,6 +248,14 @@ impl RuntimeBuilder { zenoh_config::ShmInitMode::Lazy => {} }; + if endpoint_poll_interval > 0 { + let poll_interval = Duration::from_millis(endpoint_poll_interval as u64); + runtime.spawn({ + let runtime2 = runtime.clone(); + async move { runtime2.monitor_available_addrs(poll_interval).await } + }); + } + Ok(runtime) } } @@ -366,6 +382,33 @@ impl Runtime { pub(crate) async fn remove_pending_connection(&self, zid: &ZenohIdProto) -> bool { self.state.pending_connections.lock().await.remove(zid) } + + async fn monitor_available_addrs(&self, poll_interval: Duration) { + let token = self.get_cancellation_token(); + loop { + tokio::select! { + _ = tokio::time::sleep(poll_interval) => self.update_available_addrs().await, + _ = token.cancelled() => return, + } + } + } + + async fn update_available_addrs(&self) { + #[cfg(unix)] + update_iface_cache(); + + if self.update_locators() { + let tables_lock = &self.state.router.tables; + let _ctrl_lock = zlock!(tables_lock.ctrl_lock); + let mut tables = zwrite!(tables_lock.tables); + tables_lock.hat_code.update_self_locators(&mut tables); + } + + let scouting = self.state.scouting.lock().await; + if let Some(scouting) = scouting.as_ref() { + scouting.update_addrs_if_needed().await; + } + } } struct RuntimeTransportEventHandler { @@ -547,6 +590,7 @@ impl Closee for Arc { self.manager.close().await; // clean up to break cyclic reference of self.state to itself self.transport_handlers.write().unwrap().clear(); + zasynclock!(self.scouting).take(); // TODO: the call below is needed to prevent intermittent leak // due to not freed resource Arc, that apparently happens because // the task responsible for resource clean up was aborted earlier than expected. diff --git a/zenoh/src/net/runtime/orchestrator.rs b/zenoh/src/net/runtime/orchestrator.rs index 910e7b4164..4f0dbf10e3 100644 --- a/zenoh/src/net/runtime/orchestrator.rs +++ b/zenoh/src/net/runtime/orchestrator.rs @@ -18,34 +18,22 @@ use std::{ time::Duration, }; -use futures::prelude::*; use socket2::{Domain, Socket, Type}; use tokio::{ net::UdpSocket, sync::{futures::Notified, Mutex, Notify}, }; -use zenoh_buffers::{ - reader::{DidntRead, HasReader}, - writer::HasWriter, -}; -use zenoh_codec::{RCodec, WCodec, Zenoh080}; use zenoh_config::{ get_global_connect_timeout, get_global_listener_timeout, unwrap_or_default, ModeDependent, }; use zenoh_link::{Locator, LocatorInspector}; -use zenoh_protocol::{ - core::{whatami::WhatAmIMatcher, EndPoint, Metadata, PriorityRange, WhatAmI, ZenohIdProto}, - scouting::{HelloProto, Scout, ScoutingBody, ScoutingMessage}, +use zenoh_protocol::core::{ + whatami::WhatAmIMatcher, EndPoint, Metadata, PriorityRange, WhatAmI, ZenohIdProto, }; use zenoh_result::{bail, zerror, ZResult}; use super::{Runtime, RuntimeSession}; -use crate::net::{common::AutoConnect, protocol::linkstate::LinkInfo}; - -const RCV_BUF_SIZE: usize = u16::MAX as usize; -const SCOUT_INITIAL_PERIOD: Duration = Duration::from_millis(1_000); -const SCOUT_MAX_PERIOD: Duration = Duration::from_millis(8_000); -const SCOUT_PERIOD_INCREASE_FACTOR: u32 = 2; +use crate::net::{common::AutoConnect, protocol::linkstate::LinkInfo, runtime::scouting::Scouting}; pub enum Loop { Continue, @@ -276,42 +264,19 @@ impl Runtime { let config = &config_guard.0; unwrap_or_default!(config.scouting().multicast().ttl()) }; - let ifaces = Runtime::get_interfaces(&ifaces); - let mcast_socket = Runtime::bind_mcast_port(&addr, &ifaces, multicast_ttl).await?; - if !ifaces.is_empty() { - let sockets: Vec = ifaces - .into_iter() - .filter_map(|iface| Runtime::bind_ucast_port(iface, multicast_ttl).ok()) - .collect(); - if !sockets.is_empty() { - let this = self.clone(); - match (listen, autoconnect.is_enabled()) { - (true, true) => { - self.spawn_abortable(async move { - tokio::select! { - _ = this.responder(&mcast_socket, &sockets) => {}, - _ = this.autoconnect_all( - &sockets, - autoconnect, - &addr - ) => {}, - } - }); - } - (true, false) => { - self.spawn_abortable(async move { - this.responder(&mcast_socket, &sockets).await; - }); - } - (false, true) => { - self.spawn_abortable(async move { - this.autoconnect_all(&sockets, autoconnect, &addr).await - }); - } - _ => {} - } - } - } + + let scouting = Scouting::new( + listen, + autoconnect, + addr, + ifaces, + multicast_ttl, + self.clone(), + ) + .await?; + scouting.start().await?; + *self.state.scouting.lock().await = Some(scouting.clone()); + Ok(()) } @@ -537,7 +502,7 @@ impl Runtime { self.spawn_add_listener(endpoint, retry_config).await } } - self.print_locators(); + self.update_locators(); Ok(()) } @@ -549,7 +514,7 @@ impl Runtime { let this = self.clone(); self.spawn(async move { this.add_listener_retry(listener, retry_config).await; - this.print_locators(); + this.update_locators(); }); } @@ -579,12 +544,26 @@ impl Runtime { Ok(()) } - fn print_locators(&self) { + pub fn update_locators(&self) -> bool { let mut locators = self.state.locators.write().unwrap(); - *locators = self.manager().get_locators(); - for locator in &*locators { - tracing::info!("Zenoh can be reached at: {}", locator); + let new_locators = self.manager().get_locators(); + if are_locators_equal(&locators, &new_locators) { + return false; + } + if tracing::enabled!(tracing::Level::INFO) { + for locator in &new_locators { + if !locators.contains(locator) { + tracing::info!("Zenoh can be reached at: {}", locator); + } + } + for old_locator in &*locators { + if !new_locators.contains(old_locator) { + tracing::info!("Zenoh can no longer be reached at: {}", old_locator); + } + } } + *locators = new_locators; + true } pub fn get_interfaces(names: &str) -> Vec { @@ -821,102 +800,6 @@ impl Runtime { } } - pub async fn scout( - sockets: &[UdpSocket], - matcher: WhatAmIMatcher, - mcast_addr: &SocketAddr, - f: F, - ) where - F: Fn(HelloProto) -> Fut + std::marker::Send + std::marker::Sync + Clone, - Fut: Future + std::marker::Send, - Self: Sized, - { - let send = async { - let mut delay = SCOUT_INITIAL_PERIOD; - - let scout: ScoutingMessage = Scout { - version: zenoh_protocol::VERSION, - what: matcher, - zid: None, - } - .into(); - let mut wbuf = vec![]; - let mut writer = wbuf.writer(); - let codec = Zenoh080::new(); - codec.write(&mut writer, &scout).unwrap(); - - loop { - for socket in sockets { - tracing::trace!( - "Send {:?} to {} on interface {}", - scout.body, - mcast_addr, - socket - .local_addr() - .map_or("unknown".to_string(), |addr| addr.ip().to_string()) - ); - if let Err(err) = socket - .send_to(wbuf.as_slice(), mcast_addr.to_string()) - .await - { - tracing::debug!( - "Unable to send {:?} to {} on interface {}: {}", - scout.body, - mcast_addr, - socket - .local_addr() - .map_or("unknown".to_string(), |addr| addr.ip().to_string()), - err - ); - } - } - tokio::time::sleep(delay).await; - if delay * SCOUT_PERIOD_INCREASE_FACTOR <= SCOUT_MAX_PERIOD { - delay *= SCOUT_PERIOD_INCREASE_FACTOR; - } - } - }; - let recvs = futures::future::select_all(sockets.iter().map(move |socket| { - let f = f.clone(); - async move { - let mut buf = vec![0; RCV_BUF_SIZE]; - loop { - match socket.recv_from(&mut buf).await { - Ok((n, peer)) => { - let mut reader = buf.as_slice()[..n].reader(); - let codec = Zenoh080::new(); - let res: Result = codec.read(&mut reader); - if let Ok(msg) = res { - tracing::trace!("Received {:?} from {}", msg.body, peer); - if let ScoutingBody::Hello(hello) = &msg.body { - if matcher.matches(hello.whatami) { - if let Loop::Break = f(hello.clone()).await { - break; - } - } else { - tracing::warn!("Received unexpected Hello: {:?}", msg.body); - } - } - } else { - tracing::trace!( - "Received unexpected UDP datagram from {}: {:?}", - peer, - &buf.as_slice()[..n] - ); - } - } - Err(e) => tracing::debug!("Error receiving UDP datagram: {}", e), - } - } - } - .boxed() - })); - tokio::select! { - _ = send => {}, - _ = recvs => {}, - } - } - /// Returns `true` if a new Transport instance is established with `zid` or had already been established. #[must_use] async fn connect(&self, zid: &ZenohIdProto, scouted_locators: &[Locator]) -> bool { @@ -1063,7 +946,7 @@ impl Runtime { timeout: std::time::Duration, ) -> ZResult<()> { let scout = async { - Runtime::scout(sockets, what, addr, move |hello| async move { + Scouting::scout(sockets, what, addr, move |hello| async move { tracing::info!("Found {:?}", hello); if !hello.locators.is_empty() { if self.connect(&hello.zid, &hello.locators).await { @@ -1087,110 +970,6 @@ impl Runtime { } } - async fn autoconnect_all( - &self, - ucast_sockets: &[UdpSocket], - autoconnect: AutoConnect, - addr: &SocketAddr, - ) { - Runtime::scout( - ucast_sockets, - autoconnect.matcher(), - addr, - move |hello| async move { - if hello.locators.is_empty() { - tracing::warn!("Received Hello with no locators: {:?}", hello); - } else if autoconnect.should_autoconnect(hello.zid, hello.whatami) { - self.connect_peer(&hello.zid, &hello.locators).await; - } - Loop::Continue - }, - ) - .await - } - - async fn responder(&self, mcast_socket: &UdpSocket, ucast_sockets: &[UdpSocket]) { - fn get_best_match<'a>(addr: &IpAddr, sockets: &'a [UdpSocket]) -> Option<&'a UdpSocket> { - fn octets(addr: &IpAddr) -> Vec { - match addr { - IpAddr::V4(addr) => addr.octets().to_vec(), - IpAddr::V6(addr) => addr.octets().to_vec(), - } - } - fn matching_octets(addr: &IpAddr, sock: &UdpSocket) -> usize { - octets(addr) - .iter() - .zip(octets(&sock.local_addr().unwrap().ip())) - .map(|(x, y)| x.cmp(&y)) - .position(|ord| ord != std::cmp::Ordering::Equal) - .unwrap_or_else(|| octets(addr).len()) - } - sockets - .iter() - .filter(|sock| sock.local_addr().is_ok()) - .max_by(|sock1, sock2| { - matching_octets(addr, sock1).cmp(&matching_octets(addr, sock2)) - }) - } - - let mut buf = vec![0; RCV_BUF_SIZE]; - let local_addrs: Vec = ucast_sockets - .iter() - .filter_map(|sock| sock.local_addr().ok()) - .collect(); - tracing::debug!("Waiting for UDP datagram..."); - loop { - let (n, peer) = mcast_socket.recv_from(&mut buf).await.unwrap(); - if local_addrs.contains(&peer) { - tracing::trace!("Ignore UDP datagram from own socket"); - continue; - } - - let mut reader = buf.as_slice()[..n].reader(); - let codec = Zenoh080::new(); - let res: Result = codec.read(&mut reader); - if let Ok(msg) = res { - tracing::trace!("Received {:?} from {}", msg.body, peer); - if let ScoutingBody::Scout(Scout { what, .. }) = &msg.body { - if what.matches(self.whatami()) { - let mut wbuf = vec![]; - let mut writer = wbuf.writer(); - let codec = Zenoh080::new(); - - let zid = self.manager().zid(); - let hello: ScoutingMessage = HelloProto { - version: zenoh_protocol::VERSION, - whatami: self.whatami(), - zid, - locators: self.get_locators(), - } - .into(); - let socket = get_best_match(&peer.ip(), ucast_sockets).unwrap(); - tracing::trace!( - "Send {:?} to {} on interface {}", - hello.body, - peer, - socket - .local_addr() - .map_or("unknown".to_string(), |addr| addr.ip().to_string()) - ); - codec.write(&mut writer, &hello).unwrap(); - - if let Err(err) = socket.send_to(wbuf.as_slice(), peer).await { - tracing::error!("Unable to send {:?} to {}: {}", hello.body, peer, err); - } - } - } - } else { - tracing::trace!( - "Received unexpected UDP datagram from {}: {:?}", - peer, - &buf.as_slice()[..n] - ); - } - } - } - pub(super) fn closed_session(session: &RuntimeSession) { if session.runtime.is_closed() { return; @@ -1251,3 +1030,9 @@ impl Runtime { router.tables.hat_code.links_info(&tables) } } + +fn are_locators_equal(a: &[Locator], b: &[Locator]) -> bool { + use std::collections::hash_map::RandomState; + a.len() == b.len() + && HashSet::<&Locator, RandomState>::from_iter(a.iter()) == HashSet::from_iter(b.iter()) +} diff --git a/zenoh/src/net/runtime/scouting.rs b/zenoh/src/net/runtime/scouting.rs new file mode 100644 index 0000000000..c7d4ecd685 --- /dev/null +++ b/zenoh/src/net/runtime/scouting.rs @@ -0,0 +1,466 @@ +use std::{ + collections::HashSet, + future::Future, + io::ErrorKind, + net::{IpAddr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use futures::{lock::Mutex, FutureExt}; +use itertools::Itertools; +use tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; +use tokio_util::sync::CancellationToken; +use zenoh_buffers::{ + reader::{DidntRead, HasReader}, + writer::HasWriter, +}; +use zenoh_codec::{RCodec, WCodec, Zenoh080}; +use zenoh_protocol::{ + core::{Locator, WhatAmI, WhatAmIMatcher, ZenohIdProto}, + scouting::{HelloProto, Scout, ScoutingBody, ScoutingMessage}, +}; +use zenoh_result::ZResult; + +use super::Runtime; +use crate::net::{common::AutoConnect, runtime::orchestrator::Loop}; + +const RCV_BUF_SIZE: usize = u16::MAX as usize; +const SCOUT_INITIAL_PERIOD: Duration = Duration::from_millis(1_000); +const SCOUT_MAX_PERIOD: Duration = Duration::from_millis(8_000); +const SCOUT_PERIOD_INCREASE_FACTOR: u32 = 2; + +#[derive(Clone)] +pub struct Scouting { + state: Arc, +} + +struct ScoutState { + listen: bool, + autoconnect: AutoConnect, + /// The multicast address to send scout messages to. + addr: SocketAddr, + /// Interface constraints, "auto" or a comma-separated IP address list.. + ifaces: String, + multicast_ttl: u32, + runtime: Runtime, + sockets: RwLock, + cancellation_token: Mutex, +} + +struct ScoutSockets { + mcast_socket: UdpSocket, + ucast_sockets: Vec, +} + +impl Scouting { + pub async fn new( + listen: bool, + autoconnect: AutoConnect, + addr: SocketAddr, + ifaces: String, + multicast_ttl: u32, + runtime: Runtime, + ) -> ZResult { + let ifaces_ips = Runtime::get_interfaces(&ifaces); + let mcast_socket = Runtime::bind_mcast_port(&addr, &ifaces_ips, multicast_ttl).await?; + let ucast_sockets = ifaces_ips + .into_iter() + .filter_map(|iface| Runtime::bind_ucast_port(iface, multicast_ttl).ok()) + .collect(); + + let sockets = RwLock::new(ScoutSockets { + mcast_socket, + ucast_sockets, + }); + let cancellation_token = Mutex::new(CancellationToken::new()); + + let state = Arc::new(ScoutState { + listen, + autoconnect, + addr, + ifaces, + multicast_ttl, + runtime, + sockets, + cancellation_token, + }); + + Ok(Scouting { state }) + } + + pub async fn update_addrs_if_needed(&self) { + let available_mcast_addrs = Runtime::get_interfaces(&self.state.ifaces) + .into_iter() + .collect::>(); + let used_mcast_addrs = zasyncread!(self.state.sockets) + .ucast_sockets + .iter() + .filter_map(|s| s.local_addr().ok()) + .map(|s| s.ip()) + .collect::>(); + let new_addrs = available_mcast_addrs + .difference(&used_mcast_addrs) + .collect::>(); + let obsolete_addrs = used_mcast_addrs + .difference(&available_mcast_addrs) + .collect::>(); + + if !new_addrs.is_empty() || !obsolete_addrs.is_empty() { + if let Err(e) = self + .update_scouting_addresses(&new_addrs, &obsolete_addrs) + .await + { + tracing::error!( + "Could not update scouting addresses with +{:?}, -{:?}: {}", + new_addrs, + obsolete_addrs, + e + ); + }; + } + } + + async fn update_scouting_addresses( + &self, + addrs_to_add: &[&IpAddr], + addrs_to_remove: &[&IpAddr], + ) -> ZResult<()> { + tracing::debug!("Join multicast scouting on {addrs_to_add:?}"); + for addr_to_add in addrs_to_add { + self.join_multicast_group(addr_to_add).await; + } + + tracing::debug!("Restarting scout routine"); + // TODO: This may interrupt something important, as a connection establishment... fix that. + zasynclock!(self.state.cancellation_token).cancel(); + + { + let mut sockets = zasyncwrite!(self.state.sockets); + sockets.ucast_sockets.retain(|s| { + s.local_addr().map_or(true, |a| { + let ip = a.ip(); + if addrs_to_remove.iter().copied().contains(&ip) { + tracing::debug!("Removing socket udp/{}", ip); + false + } else { + true + } + }) + }); + sockets.ucast_sockets.extend( + addrs_to_add + .iter() + .filter_map(|&i| Runtime::bind_ucast_port(*i, self.state.multicast_ttl).ok()), + ); + } + + *zasynclock!(self.state.cancellation_token) = CancellationToken::new(); + + self.start().await?; + tracing::debug!("Scout routine restarted"); + + Ok(()) + } + + async fn join_multicast_group(&self, interface_addr: &IpAddr) { + let sockets = zasyncread!(self.state.sockets); + if let (IpAddr::V4(new_addr), IpAddr::V4(mcast_addr)) = + (interface_addr, self.state.addr.ip()) + { + match sockets + .mcast_socket + .join_multicast_v4(mcast_addr, *new_addr) + { + Ok(()) => tracing::debug!( + "Joined multicast group {} on interface {}", + mcast_addr, + new_addr, + ), + // We already joined the multicast group + Err(err) if err.kind() == ErrorKind::AddrInUse => (), + Err(err) => tracing::warn!( + "Unable to join multicast group {} on interface {}: {}", + mcast_addr, + new_addr, + err, + ), + }; + } + } + + pub async fn scout( + sockets: &[UdpSocket], + matcher: WhatAmIMatcher, + mcast_addr: &SocketAddr, + f: F, + ) where + F: Fn(HelloProto) -> Fut + std::marker::Send + std::marker::Sync + Clone, + Fut: Future + std::marker::Send, + Self: Sized, + { + // This can be interrupted anytime: we cannot send "half" beacons, + // and there's no repercusion if we miss one send. + let send = async { + let mut delay = SCOUT_INITIAL_PERIOD; + + let scout: ScoutingMessage = Scout { + version: zenoh_protocol::VERSION, + what: matcher, + zid: None, + } + .into(); + let mut wbuf = vec![]; + let mut writer = wbuf.writer(); + let codec = Zenoh080::new(); + codec.write(&mut writer, &scout).unwrap(); + + loop { + for socket in sockets { + tracing::trace!( + "Send {:?} to {} on interface {}", + scout.body, + mcast_addr, + socket + .local_addr() + .map_or("unknown".to_string(), |addr| addr.ip().to_string()) + ); + if let Err(err) = socket + .send_to(wbuf.as_slice(), mcast_addr.to_string()) + .await + { + tracing::debug!( + "Unable to send {:?} to {} on interface {}: {}", + scout.body, + mcast_addr, + socket + .local_addr() + .map_or("unknown".to_string(), |addr| addr.ip().to_string()), + err + ); + } + } + tokio::time::sleep(delay).await; + if delay * SCOUT_PERIOD_INCREASE_FACTOR <= SCOUT_MAX_PERIOD { + delay *= SCOUT_PERIOD_INCREASE_FACTOR; + } + } + }; + let recvs = futures::future::select_all(sockets.iter().map(move |socket| { + let f = f.clone(); + async move { + let mut buf = vec![0; RCV_BUF_SIZE]; + loop { + match socket.recv_from(&mut buf).await { + Ok((n, peer)) => { + let mut reader = buf.as_slice()[..n].reader(); + let codec = Zenoh080::new(); + let res: Result = codec.read(&mut reader); + if let Ok(msg) = res { + tracing::trace!("Received {:?} from {}", msg.body, peer); + if let ScoutingBody::Hello(hello) = &msg.body { + if matcher.matches(hello.whatami) { + if let Loop::Break = f(hello.clone()).await { + break; + } + } else { + tracing::warn!("Received unexpected Hello: {:?}", msg.body); + } + } + } else { + tracing::trace!( + "Received unexpected UDP datagram from {}: {:?}", + peer, + &buf.as_slice()[..n] + ); + } + } + Err(e) => tracing::debug!("Error receiving UDP datagram: {}", e), + } + } + } + .boxed() + })); + tokio::select! { + _ = send => {}, + _ = recvs => {}, + } + } + + pub async fn start(&self) -> ZResult<()> { + if !zasyncread!(self.state.sockets).ucast_sockets.is_empty() { + let this = self.clone(); + let token = this.get_cancellation_token().await; + match (self.state.listen, self.state.autoconnect.is_enabled()) { + (true, true) => { + self.spawn_abortable(async move { + let sockets = zasyncread!(this.state.sockets); + tokio::select! { + _ = this.responder(&sockets.mcast_socket, &sockets.ucast_sockets) => {}, + _ = this.autoconnect_all( + &sockets.ucast_sockets, + this.state.autoconnect, + &this.state.addr + ) => {}, + _ = token.cancelled() => (), + } + }); + } + (true, false) => { + self.spawn_abortable(async move { + let sockets = zasyncread!(this.state.sockets); + tokio::select! { + _ = this.responder(&sockets.mcast_socket, &sockets.ucast_sockets) => (), + _ = token.cancelled() => (), + } + }); + } + (false, true) => { + self.spawn_abortable(async move { + let sockets = zasyncread!(this.state.sockets); + tokio::select! { + _ = this.autoconnect_all( + &sockets.ucast_sockets, + this.state.autoconnect, + &this.state.addr, + ) => (), + _ = token.cancelled() => (), + } + }); + } + _ => {} + } + } + Ok(()) + } + + async fn autoconnect_all( + &self, + ucast_sockets: &[UdpSocket], + autoconnect: AutoConnect, + addr: &SocketAddr, + ) { + Self::scout( + ucast_sockets, + autoconnect.matcher(), + addr, + move |hello| async move { + if hello.locators.is_empty() { + tracing::warn!("Received Hello with no locators: {:?}", hello); + } else if autoconnect.should_autoconnect(hello.zid, hello.whatami) { + self.connect_peer(&hello.zid, &hello.locators).await; + } + Loop::Continue + }, + ) + .await + } + + async fn responder(&self, mcast_socket: &UdpSocket, ucast_sockets: &[UdpSocket]) { + let mut buf = vec![0; RCV_BUF_SIZE]; + let local_addrs: Vec = ucast_sockets + .iter() + .filter_map(|sock| sock.local_addr().ok()) + .collect(); + tracing::debug!("Waiting for UDP datagram..."); + loop { + let (n, peer) = mcast_socket.recv_from(&mut buf).await.unwrap(); + if local_addrs.contains(&peer) { + tracing::trace!("Ignore UDP datagram from own socket"); + continue; + } + + let mut reader = buf.as_slice()[..n].reader(); + let codec = Zenoh080::new(); + let res: Result = codec.read(&mut reader); + if let Ok(msg) = res { + tracing::trace!("Received {:?} from {}", msg.body, peer); + if let ScoutingBody::Scout(Scout { what, .. }) = &msg.body { + if what.matches(self.whatami()) { + let mut wbuf = vec![]; + let mut writer = wbuf.writer(); + let codec = Zenoh080::new(); + + let zid = self.zid(); + let hello: ScoutingMessage = HelloProto { + version: zenoh_protocol::VERSION, + whatami: self.whatami(), + zid, + locators: self.get_locators(), + } + .into(); + let socket = get_best_match(&peer.ip(), ucast_sockets).unwrap(); + tracing::trace!( + "Send {:?} to {} on interface {}", + hello.body, + peer, + socket + .local_addr() + .map_or("unknown".to_string(), |addr| addr.ip().to_string()) + ); + codec.write(&mut writer, &hello).unwrap(); + + if let Err(err) = socket.send_to(wbuf.as_slice(), peer).await { + tracing::error!("Unable to send {:?} to {}: {}", hello.body, peer, err); + } + } + } + } else { + tracing::trace!( + "Received unexpected UDP datagram from {}: {:?}", + peer, + &buf.as_slice()[..n] + ); + } + } + } + + fn whatami(&self) -> WhatAmI { + self.state.runtime.whatami() + } + + fn get_locators(&self) -> Vec { + self.state.runtime.get_locators() + } + + async fn get_cancellation_token(&self) -> CancellationToken { + zasynclock!(self.state.cancellation_token).child_token() + } + + fn zid(&self) -> ZenohIdProto { + self.state.runtime.manager().zid() + } + + async fn connect_peer(&self, zid: &ZenohIdProto, locators: &[Locator]) -> bool { + self.state.runtime.connect_peer(zid, locators).await + } + + fn spawn_abortable(&self, future: F) -> JoinHandle<()> + where + F: Future + Send + 'static, + T: Send + 'static, + { + self.state.runtime.spawn_abortable(future) + } +} + +fn get_best_match<'a>(addr: &IpAddr, sockets: &'a [UdpSocket]) -> Option<&'a UdpSocket> { + fn octets(addr: &IpAddr) -> Vec { + match addr { + IpAddr::V4(addr) => addr.octets().to_vec(), + IpAddr::V6(addr) => addr.octets().to_vec(), + } + } + fn matching_octets(addr: &IpAddr, sock: &UdpSocket) -> usize { + octets(addr) + .iter() + .zip(octets(&sock.local_addr().unwrap().ip())) + .map(|(x, y)| x.cmp(&y)) + .position(|ord| ord != std::cmp::Ordering::Equal) + .unwrap_or_else(|| octets(addr).len()) + } + sockets + .iter() + .filter(|sock| sock.local_addr().is_ok()) + .max_by(|sock1, sock2| matching_octets(addr, sock1).cmp(&matching_octets(addr, sock2))) +}