diff --git a/monad-dataplane/examples/node.rs b/monad-dataplane/examples/node.rs index a1401d0ace..15b105615c 100644 --- a/monad-dataplane/examples/node.rs +++ b/monad-dataplane/examples/node.rs @@ -26,7 +26,10 @@ use futures::{executor, Stream}; use futures_util::FutureExt; use monad_dataplane::{ udp::DEFAULT_SEGMENT_SIZE, BroadcastMsg, Dataplane, DataplaneBuilder, RecvUdpMsg, TcpMsg, + UdpSocketHandle, }; + +const LEGACY_SOCKET: &str = "legacy"; use rand::Rng; const NODE_ONE_ADDR: &str = "127.0.0.1:60000"; @@ -80,16 +83,18 @@ fn main() { let t1 = std::thread::spawn(move || { let b = buf.freeze(); + let tx_socket = tx.udp_socket.as_ref().unwrap(); + println!("START: {:?}", Instant::now()); for i in 0..num_pkts { - tx.network.udp_write_broadcast(BroadcastMsg { + tx_socket.write_broadcast(BroadcastMsg { targets: vec![tx.target], payload: b.slice(i * pkt_size..(i + 1) * pkt_size), stride: DEFAULT_SEGMENT_SIZE, }) } - tx.network.tcp_write( + tx.dataplane.tcp_write( tx.target, TcpMsg { msg: Bytes::from(&b"Hello world"[..]), @@ -107,14 +112,18 @@ fn main() { } struct Node { - network: Dataplane, + dataplane: Dataplane, + udp_socket: Option, target: SocketAddr, } impl Node { pub fn new(addr: &SocketAddr, target_addr: &str) -> Self { + let mut dataplane = DataplaneBuilder::new(addr, 1_000).build(); + let udp_socket = dataplane.take_udp_socket_handle(LEGACY_SOCKET); Self { - network: DataplaneBuilder::new(addr, 1_000).build(), + dataplane, + udp_socket, target: target_addr.parse().unwrap(), } } @@ -129,8 +138,10 @@ impl Stream for Node { ) -> Poll> { let this = self.deref_mut(); - if let Poll::Ready(message) = pin!(this.network.udp_read()).poll_unpin(cx) { - return Poll::Ready(Some(message)); + if let Some(socket) = &mut this.udp_socket { + if let Poll::Ready(message) = pin!(socket.recv()).poll_unpin(cx) { + return Poll::Ready(Some(message)); + } } Poll::Pending } diff --git a/monad-dataplane/src/lib.rs b/monad-dataplane/src/lib.rs index 989d02118a..b7089c7d6b 100644 --- a/monad-dataplane/src/lib.rs +++ b/monad-dataplane/src/lib.rs @@ -41,6 +41,11 @@ pub mod udp; pub(crate) use udp::UdpMessageType; +pub struct UdpSocketConfig { + pub socket_addr: SocketAddr, + pub label: String, +} + pub struct DataplaneBuilder { local_addr: SocketAddr, trusted_addresses: Vec, @@ -49,7 +54,7 @@ pub struct DataplaneBuilder { udp_buffer_size: Option, tcp_config: TcpConfig, ban_duration: Duration, - direct_socket_port: Option, + udp_sockets: Vec, } impl DataplaneBuilder { @@ -67,27 +72,25 @@ impl DataplaneBuilder { connections_limit: 10000, per_ip_connections_limit: 100, }, - ban_duration: Duration::from_secs(5 * 60), // 5 minutes - direct_socket_port: None, + ban_duration: Duration::from_secs(5 * 60), + udp_sockets: vec![UdpSocketConfig { + socket_addr: *local_addr, + label: "legacy".to_string(), + }], } } - /// with_udp_buffer_size sets the buffer size for udp socket that is managed by dataplane - /// to a requested value. pub fn with_udp_buffer_size(mut self, buffer_size: usize) -> Self { self.udp_buffer_size = Some(buffer_size); self } - /// with_tcp_connections_limit sets total and per_ip connection limit. if per_ip is zero it will be - /// equal to total. pub fn with_tcp_connections_limit(mut self, total: usize, per_ip: usize) -> Self { self.tcp_config.connections_limit = total; self.tcp_config.per_ip_connections_limit = if per_ip == 0 { total } else { per_ip }; self } - /// with_tcp_rps_burst sets the rate limit and burst for tcp connections. pub fn with_tcp_rps_burst(mut self, rps: u32, burst: u32) -> Self { self.tcp_config.rate_limit.rps = NonZeroU32::new(rps).expect("rps must be non-zero"); self.tcp_config.rate_limit.rps_burst = @@ -95,15 +98,13 @@ impl DataplaneBuilder { self } - /// with trusted_ips sets the list of trusted ip addresses. pub fn with_trusted_ips(mut self, ips: Vec) -> Self { self.trusted_addresses = ips; self } - /// with_direct_socket configures an additional UDP socket for direct peer communication - pub fn with_direct_socket(mut self, port: u16) -> Self { - self.direct_socket_port = Some(port); + pub fn extend_udp_sockets(mut self, sockets: Vec) -> Self { + self.udp_sockets.extend(sockets); self } @@ -115,15 +116,25 @@ impl DataplaneBuilder { trusted_addresses: trusted, tcp_config, ban_duration, - direct_socket_port, + udp_sockets, } = self; let (tcp_ingress_tx, tcp_ingress_rx) = mpsc::channel(TCP_INGRESS_CHANNEL_SIZE); let (tcp_egress_tx, tcp_egress_rx) = mpsc::channel(TCP_EGRESS_CHANNEL_SIZE); - let (udp_ingress_tx, udp_ingress_rx) = mpsc::channel(UDP_INGRESS_CHANNEL_SIZE); - let (udp_direct_ingress_tx, udp_direct_ingress_rx) = - mpsc::channel(UDP_INGRESS_CHANNEL_SIZE); - let (udp_egress_tx, udp_egress_rx) = mpsc::channel(UDP_EGRESS_CHANNEL_SIZE); + + let (udp_multi_egress_tx, udp_multi_egress_rx) = mpsc::channel(UDP_EGRESS_CHANNEL_SIZE); + + let mut udp_socket_handles = Vec::new(); + let mut socket_configs = Vec::new(); + + for (socket_id, UdpSocketConfig { socket_addr, label }) in + udp_sockets.into_iter().enumerate() + { + let (handle, config) = + create_socket_handle(socket_id, socket_addr, label, udp_multi_egress_tx.clone()); + udp_socket_handles.push(Some(handle)); + socket_configs.push(config); + } let ready = Arc::new(AtomicBool::new(false)); let ready_clone = ready.clone(); @@ -156,12 +167,9 @@ impl DataplaneBuilder { tcp_ingress_tx, tcp_egress_rx, ); - udp::spawn_tasks( - local_addr, - direct_socket_port, - udp_ingress_tx, - udp_direct_ingress_tx, - udp_egress_rx, + udp::spawn_multi_socket_tasks( + socket_configs, + udp_multi_egress_rx, up_bandwidth_mbps, udp_buffer_size, ); @@ -174,19 +182,14 @@ impl DataplaneBuilder { }) .expect("failed to spawn dataplane thread"); - let writer = DataplaneWriter::new( - tcp_egress_tx, - udp_egress_tx, - tcp_control_map, - banned_ips_tx, - addrlist, - ); - let reader = DataplaneReader::new(tcp_ingress_rx, udp_ingress_rx, udp_direct_ingress_rx); + let writer = DataplaneWriter::new(tcp_egress_tx, tcp_control_map, banned_ips_tx, addrlist); + let reader = DataplaneReader::new(tcp_ingress_rx); Dataplane { writer, reader, ready, + udp_socket_handles, } } } @@ -195,12 +198,224 @@ pub struct Dataplane { writer: DataplaneWriter, reader: DataplaneReader, ready: Arc, + udp_socket_handles: Vec>, +} + +pub struct UdpSocketReader { + socket_id: usize, + label: String, + ingress_rx: mpsc::Receiver, +} + +impl UdpSocketReader { + pub async fn recv(&mut self) -> RecvUdpMsg { + self.ingress_rx.recv().await.unwrap_or_else(|| { + panic!( + "socket {} ({}) ingress channel closed", + self.socket_id, self.label + ) + }) + } +} + +#[derive(Clone)] +pub struct UdpSocketWriter { + socket_id: usize, + socket_addr: SocketAddr, + label: String, + egress_tx: mpsc::Sender<(usize, SocketAddr, UdpMsg)>, + msgs_dropped: Arc, +} + +pub struct UdpSocketHandle { + reader: UdpSocketReader, + writer: UdpSocketWriter, +} + +impl UdpSocketHandle { + pub fn split(self) -> (UdpSocketReader, UdpSocketWriter) { + (self.reader, self.writer) + } + + pub async fn recv(&mut self) -> RecvUdpMsg { + self.reader.recv().await + } + + pub fn write(&self, dst: SocketAddr, payload: Bytes, stride: u16) { + self.writer.write(dst, payload, stride) + } + + pub fn write_broadcast(&self, msg: BroadcastMsg) { + self.writer.write_broadcast(msg) + } + + pub fn write_broadcast_with_priority(&self, msg: BroadcastMsg, priority: UdpPriority) { + self.writer.write_broadcast_with_priority(msg, priority) + } + + pub fn write_unicast(&self, msg: UnicastMsg) { + self.writer.write_unicast(msg) + } + + pub fn write_unicast_with_priority(&self, msg: UnicastMsg, priority: UdpPriority) { + self.writer.write_unicast_with_priority(msg, priority) + } + + pub fn writer(&self) -> &UdpSocketWriter { + &self.writer + } + + pub fn label(&self) -> &str { + &self.writer.label + } +} + +impl UdpSocketWriter { + pub fn write(&self, dst: SocketAddr, payload: Bytes, stride: u16) { + let msg_length = payload.len(); + let result = self.egress_tx.try_send(( + self.socket_id, + dst, + UdpMsg { + payload, + stride, + msg_type: UdpMessageType::Broadcast, + priority: UdpPriority::Regular, + }, + )); + + match result { + Ok(()) => {} + Err(TrySendError::Full(_)) => { + let total = self.msgs_dropped.fetch_add(1, Ordering::Relaxed); + warn!( + socket_id = self.socket_id, + label = %self.label, + ?dst, + msg_length, + total_msgs_dropped = total, + "udp egress channel full, dropping message" + ); + } + Err(TrySendError::Closed(_)) => { + panic!( + "socket {} ({}) egress channel closed", + self.socket_id, self.label + ) + } + } + } + + pub fn write_broadcast(&self, msg: BroadcastMsg) { + self.write_broadcast_with_priority(msg, UdpPriority::Regular); + } + + pub fn write_broadcast_with_priority(&self, msg: BroadcastMsg, priority: UdpPriority) { + let msg_len = msg.payload.len(); + let mut pending_count = msg.msg_count(); + + for (dst, udp_msg) in msg.into_iter_with_priority(priority) { + match self.egress_tx.try_send((self.socket_id, dst, udp_msg)) { + Ok(()) => pending_count -= 1, + Err(TrySendError::Full(_)) => break, + Err(TrySendError::Closed(_)) => { + panic!( + "socket {} ({}) egress channel closed", + self.socket_id, self.label + ) + } + } + } + + if pending_count > 0 { + let total = self + .msgs_dropped + .fetch_add(pending_count, Ordering::Relaxed); + warn!( + socket_id = self.socket_id, + label = %self.label, + num_msgs_dropped = pending_count, + total_msgs_dropped = total, + msg_length = msg_len, + ?priority, + "udp egress channel full, dropping broadcast messages" + ); + } + } + + pub fn write_unicast(&self, msg: UnicastMsg) { + self.write_unicast_with_priority(msg, UdpPriority::Regular); + } + + pub fn write_unicast_with_priority(&self, msg: UnicastMsg, priority: UdpPriority) { + let mut pending_count = msg.msg_count(); + + for (dst, udp_msg) in msg.into_iter_with_priority(priority) { + match self.egress_tx.try_send((self.socket_id, dst, udp_msg)) { + Ok(()) => pending_count -= 1, + Err(TrySendError::Full(_)) => break, + Err(TrySendError::Closed(_)) => { + panic!( + "socket {} ({}) egress channel closed", + self.socket_id, self.label + ) + } + } + } + + if pending_count > 0 { + let total = self + .msgs_dropped + .fetch_add(pending_count, Ordering::Relaxed); + warn!( + socket_id = self.socket_id, + label = %self.label, + num_msgs_dropped = pending_count, + total_msgs_dropped = total, + ?priority, + "udp egress channel full, dropping unicast messages" + ); + } + } +} + +impl std::fmt::Debug for UdpSocketHandle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("UdpSocketHandle") + .field("socket_id", &self.writer.socket_id) + .field("label", &self.writer.label) + .field("socket_addr", &self.writer.socket_addr) + .finish() + } +} + +pub struct UdpDataplane { + socket_handles: Vec>, +} + +impl UdpDataplane { + pub fn socket_handles(&mut self) -> &mut [Option] { + &mut self.socket_handles + } + + pub fn take_socket(&mut self, label: &str) -> Option { + self.socket_handles + .iter_mut() + .find_map(|h| match h.as_ref()?.label() == label { + true => h.take(), + false => None, + }) + } + + pub fn take_socket_by_id(&mut self, socket_id: usize) -> Option { + self.socket_handles + .get_mut(socket_id) + .and_then(Option::take) + } } pub struct DataplaneReader { tcp_ingress_rx: mpsc::Receiver, - udp_ingress_rx: mpsc::Receiver, - udp_direct_ingress_rx: mpsc::Receiver, } #[derive(Clone)] @@ -210,14 +425,12 @@ pub struct DataplaneWriter { struct DataplaneWriterInner { tcp_egress_tx: mpsc::Sender<(SocketAddr, TcpMsg)>, - udp_egress_tx: mpsc::Sender, tcp_control_map: TcpControl, notify_ban_expiry: mpsc::UnboundedSender<(IpAddr, Instant)>, addrlist: Arc, tcp_msgs_dropped: AtomicUsize, - udp_msgs_dropped: AtomicUsize, } #[derive(Clone)] @@ -232,18 +445,22 @@ impl BroadcastMsg { self.targets.len() } - fn into_iter_with_priority(self, priority: UdpPriority) -> impl Iterator { + fn into_iter_with_priority(self, priority: UdpPriority) -> impl Iterator { let Self { targets, payload, stride, } = self; - targets.into_iter().map(move |dst| UdpMsg { - dst, - payload: payload.clone(), - stride, - msg_type: UdpMessageType::Broadcast, - priority, + targets.into_iter().map(move |dst| { + ( + dst, + UdpMsg { + payload: payload.clone(), + stride, + msg_type: UdpMessageType::Broadcast, + priority, + }, + ) }) } } @@ -259,14 +476,18 @@ impl UnicastMsg { self.msgs.len() } - fn into_iter_with_priority(self, priority: UdpPriority) -> impl Iterator { + fn into_iter_with_priority(self, priority: UdpPriority) -> impl Iterator { let Self { msgs, stride } = self; - msgs.into_iter().map(move |(dst, payload)| UdpMsg { - dst, - payload, - stride, - msg_type: UdpMessageType::Broadcast, - priority, + msgs.into_iter().map(move |(dst, payload)| { + ( + dst, + UdpMsg { + payload, + stride, + msg_type: UdpMessageType::Broadcast, + priority, + }, + ) }) } } @@ -290,7 +511,6 @@ pub struct TcpMsg { } pub(crate) struct UdpMsg { - pub(crate) dst: SocketAddr, pub(crate) payload: Bytes, pub(crate) stride: u16, pub(crate) msg_type: UdpMessageType, @@ -302,9 +522,56 @@ const TCP_EGRESS_CHANNEL_SIZE: usize = 1024; const UDP_INGRESS_CHANNEL_SIZE: usize = 12_800; const UDP_EGRESS_CHANNEL_SIZE: usize = 12_800; +fn create_socket_handle( + socket_id: usize, + socket_addr: SocketAddr, + label: String, + egress_tx: mpsc::Sender<(usize, SocketAddr, UdpMsg)>, +) -> ( + UdpSocketHandle, + (usize, SocketAddr, String, mpsc::Sender), +) { + let (ingress_tx, ingress_rx) = mpsc::channel(UDP_INGRESS_CHANNEL_SIZE); + let msgs_dropped = Arc::new(AtomicUsize::new(0)); + + let reader = UdpSocketReader { + socket_id, + label: label.clone(), + ingress_rx, + }; + + let writer = UdpSocketWriter { + socket_id, + socket_addr, + label: label.clone(), + egress_tx, + msgs_dropped, + }; + + let handle = UdpSocketHandle { reader, writer }; + let config = (socket_id, socket_addr, label, ingress_tx); + (handle, config) +} + impl Dataplane { - pub fn split(self) -> (DataplaneReader, DataplaneWriter) { - (self.reader, self.writer) + pub fn split(self) -> (DataplaneReader, DataplaneWriter, UdpDataplane) { + let udp = UdpDataplane { + socket_handles: self.udp_socket_handles, + }; + (self.reader, self.writer, udp) + } + + pub fn udp_socket_handles(&mut self) -> &mut [Option] { + &mut self.udp_socket_handles + } + + pub fn take_udp_socket_handle(&mut self, label: &str) -> Option { + self.udp_socket_handles + .iter_mut() + .find_map(|h| match h.as_ref()?.label() == label { + true => h.take(), + false => None, + }) } /// add_trusted marks ip address as trusted. @@ -345,34 +612,6 @@ impl Dataplane { self.writer.tcp_write(addr, msg) } - pub async fn udp_read(&mut self) -> RecvUdpMsg { - self.reader.udp_read().await - } - - pub async fn udp_direct_read(&mut self) -> RecvUdpMsg { - self.reader.udp_direct_read().await - } - - pub fn udp_write_broadcast(&self, msg: BroadcastMsg) { - self.writer.udp_write_broadcast(msg); - } - - pub fn udp_write_broadcast_with_priority(&self, msg: BroadcastMsg, priority: UdpPriority) { - self.writer.udp_write_broadcast_with_priority(msg, priority); - } - - pub fn udp_write_unicast(&self, msg: UnicastMsg) { - self.writer.udp_write_unicast(msg); - } - - pub fn udp_write_direct(&self, dst: SocketAddr, payload: Bytes, stride: u16) { - self.writer.udp_write_direct(dst, payload, stride); - } - - pub fn udp_write_unicast_with_priority(&self, msg: UnicastMsg, priority: UdpPriority) { - self.writer.udp_write_unicast_with_priority(msg, priority); - } - pub fn ready(&self) -> bool { self.ready.load(Ordering::Acquire) } @@ -390,16 +629,8 @@ impl Dataplane { } impl DataplaneReader { - fn new( - tcp_ingress_rx: mpsc::Receiver, - udp_ingress_rx: mpsc::Receiver, - udp_direct_ingress_rx: mpsc::Receiver, - ) -> Self { - Self { - tcp_ingress_rx, - udp_ingress_rx, - udp_direct_ingress_rx, - } + fn new(tcp_ingress_rx: mpsc::Receiver) -> Self { + Self { tcp_ingress_rx } } pub async fn tcp_read(&mut self) -> RecvTcpMsg { @@ -409,36 +640,12 @@ impl DataplaneReader { } } - pub async fn udp_read(&mut self) -> RecvUdpMsg { - match self.udp_ingress_rx.recv().await { - Some(msg) => msg, - None => panic!("udp_ingress_rx channel closed"), - } - } - - pub async fn udp_direct_read(&mut self) -> RecvUdpMsg { - match self.udp_direct_ingress_rx.recv().await { - Some(msg) => msg, - None => panic!("udp_direct_ingress_rx channel closed"), - } - } - - pub fn split(self) -> (TcpReader, UdpReader) { - ( - TcpReader(self.tcp_ingress_rx), - UdpReader { - udp_ingress_rx: self.udp_ingress_rx, - udp_direct_ingress_rx: self.udp_direct_ingress_rx, - }, - ) + pub fn split(self) -> TcpReader { + TcpReader(self.tcp_ingress_rx) } } pub struct TcpReader(mpsc::Receiver); -pub struct UdpReader { - udp_ingress_rx: mpsc::Receiver, - udp_direct_ingress_rx: mpsc::Receiver, -} impl TcpReader { pub async fn read(&mut self) -> RecvTcpMsg { @@ -449,38 +656,19 @@ impl TcpReader { } } -impl UdpReader { - pub async fn read(&mut self) -> RecvUdpMsg { - match self.udp_ingress_rx.recv().await { - Some(msg) => msg, - None => panic!("udp_ingress_rx channel closed"), - } - } - - pub async fn direct_read(&mut self) -> RecvUdpMsg { - match self.udp_direct_ingress_rx.recv().await { - Some(msg) => msg, - None => panic!("udp_direct_ingress_rx channel closed"), - } - } -} - impl DataplaneWriter { fn new( tcp_egress_tx: mpsc::Sender<(SocketAddr, TcpMsg)>, - udp_egress_tx: mpsc::Sender, tcp_control_map: TcpControl, notify_ban_expiry: mpsc::UnboundedSender<(IpAddr, Instant)>, addrlist: Arc, ) -> Self { let inner = DataplaneWriterInner { tcp_egress_tx, - udp_egress_tx, tcp_control_map, notify_ban_expiry, addrlist, tcp_msgs_dropped: AtomicUsize::new(0), - udp_msgs_dropped: AtomicUsize::new(0), }; Self { inner: Arc::new(inner), @@ -507,90 +695,6 @@ impl DataplaneWriter { } } - pub fn udp_write_broadcast(&self, msg: BroadcastMsg) { - self.udp_write_broadcast_with_priority(msg, UdpPriority::Regular); - } - - pub fn udp_write_unicast(&self, msg: UnicastMsg) { - self.udp_write_unicast_with_priority(msg, UdpPriority::Regular); - } - - #[tracing::instrument( - level="trace", - skip_all, - fields(len = msg.payload.len(), targets = msg.targets.len(), priority = ?priority) - )] - pub fn udp_write_broadcast_with_priority(&self, msg: BroadcastMsg, priority: UdpPriority) { - let mut pending_count = msg.msg_count(); - let msg_len = msg.payload.len(); - - for udp_msg in msg.into_iter_with_priority(priority) { - match self.inner.udp_egress_tx.try_send(udp_msg) { - Ok(()) => { - pending_count -= 1; - } - Err(TrySendError::Full(_)) => { - break; - } - Err(TrySendError::Closed(_)) => panic!("udp_egress_tx channel closed"), - } - } - - if pending_count == 0 { - return; - } - - let udp_msgs_dropped = self - .inner - .udp_msgs_dropped - .fetch_add(pending_count, Ordering::Relaxed); - - warn!( - num_msgs_dropped = pending_count, - total_udp_msgs_dropped = udp_msgs_dropped, - msg_length = msg_len, - ?priority, - "udp_egress_tx channel full, dropping message" - ); - } - - #[tracing::instrument( - level="trace", - skip_all, - fields(msgs = msg.msgs.len(), priority = ?priority) - )] - pub fn udp_write_unicast_with_priority(&self, msg: UnicastMsg, priority: UdpPriority) { - let mut pending_count = msg.msg_count(); - - for udp_msg in msg.into_iter_with_priority(priority) { - match self.inner.udp_egress_tx.try_send(udp_msg) { - Ok(()) => { - pending_count -= 1; - } - Err(TrySendError::Full(_)) => { - break; - } - Err(TrySendError::Closed(_)) => panic!("udp_egress_tx channel closed"), - } - } - - if pending_count == 0 { - return; - } - - let udp_msgs_dropped = self - .inner - .udp_msgs_dropped - .fetch_add(pending_count, Ordering::Relaxed); - - warn!( - num_msgs_dropped = pending_count, - total_udp_msgs_dropped = udp_msgs_dropped, - ?priority, - "udp_egress_tx channel full, dropping message" - ); - } - /// add_trusted marks ip address as trusted. /// connections limits are not applied to trusted ips. pub fn add_trusted(&self, addr: IpAddr) { @@ -630,31 +734,4 @@ impl DataplaneWriter { .tcp_control_map .disconnect_socket(addr.ip(), addr.port()); } - - pub fn udp_write_direct(&self, dst: SocketAddr, payload: Bytes, stride: u16) { - let msg_length = payload.len(); - let udp_msg = UdpMsg { - dst, - payload, - stride, - msg_type: UdpMessageType::Direct, - priority: UdpPriority::Regular, - }; - - match self.inner.udp_egress_tx.try_send(udp_msg) { - Ok(()) => {} - Err(TrySendError::Full(_)) => { - let udp_msgs_dropped = self.inner.udp_msgs_dropped.fetch_add(1, Ordering::Relaxed); - - warn!( - num_msgs_dropped = 1, - total_udp_msgs_dropped = udp_msgs_dropped, - ?dst, - msg_length, - "udp_egress_tx channel full, dropping direct message" - ); - } - Err(TrySendError::Closed(_)) => panic!("udp_egress_tx channel closed"), - } - } } diff --git a/monad-dataplane/src/udp.rs b/monad-dataplane/src/udp.rs index d343982df7..f2a106875f 100644 --- a/monad-dataplane/src/udp.rs +++ b/monad-dataplane/src/udp.rs @@ -21,7 +21,7 @@ use std::{ time::{Duration, Instant}, }; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use futures::future::join_all; use monoio::{net::udp::UdpSocket, spawn, time}; use tokio::sync::mpsc; @@ -37,7 +37,7 @@ pub(crate) enum UdpMessageType { } struct PriorityQueues { - queues: [VecDeque; 2], + queues: [VecDeque<(usize, SocketAddr, UdpMsg)>; 2], } impl PriorityQueues { @@ -47,14 +47,14 @@ impl PriorityQueues { } } - fn push(&mut self, msg: UdpMsg) { - self.queues[msg.priority as usize].push_back(msg); + fn push(&mut self, socket_id: usize, addr: SocketAddr, msg: UdpMsg) { + self.queues[msg.priority as usize].push_back((socket_id, addr, msg)); } - fn pop_highest_priority(&mut self) -> Option { + fn pop_highest_priority(&mut self) -> Option<(usize, SocketAddr, UdpMsg)> { for queue in self.queues.iter_mut() { - if let Some(msg) = queue.pop_front() { - return Some(msg); + if let Some(item) = queue.pop_front() { + return Some(item); } } None @@ -131,37 +131,20 @@ fn set_mtu_discovery(socket: &UdpSocket) { } } -pub(crate) fn spawn_tasks( - local_addr: SocketAddr, - direct_socket_port: Option, - udp_ingress_tx: mpsc::Sender, - udp_direct_ingress_tx: mpsc::Sender, - udp_egress_rx: mpsc::Receiver, +pub(crate) fn spawn_multi_socket_tasks( + socket_configs: Vec<(usize, SocketAddr, String, mpsc::Sender)>, + udp_multi_egress_rx: mpsc::Receiver<(usize, SocketAddr, UdpMsg)>, up_bandwidth_mbps: u64, buffer_size: Option, ) { - let (udp_socket_rx, udp_socket_tx) = create_socket_pair(local_addr, buffer_size); - let (direct_socket_rx, direct_socket_tx) = direct_socket_port - .map(|port| { - let mut direct_addr = local_addr; - direct_addr.set_port(port); - let (rx, tx) = create_socket_pair(direct_addr, buffer_size); - (Some(rx), Some(tx)) - }) - .unwrap_or((None, None)); - - spawn(rx( - udp_socket_rx, - direct_socket_rx, - udp_ingress_tx, - udp_direct_ingress_tx, - )); - spawn(tx( - udp_socket_tx, - direct_socket_tx, - udp_egress_rx, - up_bandwidth_mbps, - )); + if !socket_configs.is_empty() { + spawn(multi_socket_task( + socket_configs, + udp_multi_egress_rx, + up_bandwidth_mbps, + buffer_size, + )); + } } fn create_socket_pair(addr: SocketAddr, buffer_size: Option) -> (UdpSocket, UdpSocket) { @@ -172,23 +155,6 @@ fn create_socket_pair(addr: SocketAddr, buffer_size: Option) -> (UdpSocke (rx, tx) } -async fn rx( - udp_socket_rx: UdpSocket, - direct_socket_rx: Option, - udp_ingress_tx: mpsc::Sender, - udp_direct_ingress_tx: mpsc::Sender, -) { - match direct_socket_rx { - Some(direct_socket) => { - spawn(rx_single_socket(udp_socket_rx, udp_ingress_tx)); - spawn(rx_single_socket(direct_socket, udp_direct_ingress_tx)); - } - None => { - rx_single_socket(udp_socket_rx, udp_ingress_tx).await; - } - } -} - async fn rx_single_socket(socket: UdpSocket, udp_ingress_tx: mpsc::Sender) { loop { let buf = BytesMut::with_capacity(ETHERNET_SEGMENT_SIZE.into()); @@ -217,18 +183,42 @@ async fn rx_single_socket(socket: UdpSocket, udp_ingress_tx: mpsc::Sender, - mut udp_egress_rx: mpsc::Receiver, +const MAX_AGGREGATED_WRITE_SIZE: u16 = 65535 - IPV4_HDR_SIZE - UDP_HDR_SIZE; +const MAX_AGGREGATED_SEGMENTS: u16 = 128; + +fn max_write_size_for_segment_size(segment_size: u16) -> u16 { + (MAX_AGGREGATED_WRITE_SIZE / segment_size).min(MAX_AGGREGATED_SEGMENTS) * segment_size +} + +fn is_eafnosupport(err: &Error) -> bool { + const EAFNOSUPPORT: &str = "Address family not supported by protocol"; + + let err = format!("{}", err); + + err.len() >= EAFNOSUPPORT.len() && &err[0..EAFNOSUPPORT.len()] == EAFNOSUPPORT +} + +async fn multi_socket_task( + socket_configs: Vec<(usize, SocketAddr, String, mpsc::Sender)>, + mut udp_multi_egress_rx: mpsc::Receiver<(usize, SocketAddr, UdpMsg)>, up_bandwidth_mbps: u64, + buffer_size: Option, ) { - let mut next_transmit = Instant::now(); + let mut sockets_tx: Vec> = socket_configs + .into_iter() + .map(|(socket_id, socket_addr, label, ingress_tx)| { + let (socket_rx, socket_tx) = create_socket_pair(socket_addr, buffer_size); + spawn(rx_single_socket(socket_rx, ingress_tx)); + trace!(socket_id, label = %label, ?socket_addr, "created socket"); + Some(socket_tx) + }) + .collect(); + sockets_tx.resize_with(sockets_tx.len(), || None); + let mut next_transmit = Instant::now(); let mut priority_queues = PriorityQueues::new(); - let max_batch_bytes = max_write_size_for_segment_size(DEFAULT_SEGMENT_SIZE) as usize; - let mut send_futures = Vec::with_capacity(MAX_AGGREGATED_SEGMENTS as usize); + let mut send_futures = Vec::new(); loop { let now = Instant::now(); @@ -236,17 +226,17 @@ async fn tx( time::sleep(next_transmit - now).await; } else { let late = now - next_transmit; - if late > PACING_SLEEP_OVERSHOOT_DETECTION_WINDOW { next_transmit = now; } } - if fill_message_queues(&mut udp_egress_rx, &mut priority_queues) - .await - .is_err() - { - return; + while priority_queues.is_empty() || !udp_multi_egress_rx.is_empty() { + let Some((socket_id, addr, udp_msg)) = udp_multi_egress_rx.recv().await else { + return; + }; + + priority_queues.push(socket_id, addr, udp_msg); } let queue_len = priority_queues @@ -262,7 +252,7 @@ async fn tx( && total_bytes < max_batch_bytes && batch_count < MAX_AGGREGATED_SEGMENTS as usize { - let mut msg = priority_queues.pop_highest_priority().unwrap(); + let (socket_id, addr, mut msg) = priority_queues.pop_highest_priority().unwrap(); let chunk_size = msg .payload .len() @@ -270,40 +260,36 @@ async fn tx( .min(max_batch_bytes); if chunk_size + total_bytes > max_batch_bytes { - priority_queues.push(msg); + priority_queues.push(socket_id, addr, msg); break; } let chunk = msg.payload.split_to(chunk_size); total_bytes += chunk.len(); - let socket = match (&msg.msg_type, &direct_socket_tx) { - (UdpMessageType::Direct, Some(direct_socket)) => direct_socket, - _ => &socket_tx, - }; - - let dst = msg.dst; - let msg_type = msg.msg_type; + if let Some(Some(socket)) = sockets_tx.get(socket_id) { + if !msg.payload.is_empty() { + priority_queues.push(socket_id, addr, msg); + } - if !msg.payload.is_empty() { - priority_queues.push(msg); + trace!( + socket_id, + dst_addr = ?addr, + chunk_len = chunk.len(), + "preparing udp send" + ); + + send_futures.push(socket.send_to(chunk, addr)); + batch_count += 1; + } else { + warn!(socket_id, "invalid socket_id, dropping message"); } - - trace!( - dst_addr = ?dst, - chunk_len = chunk.len(), - msg_type = ?msg_type, - "preparing udp send" - ); - - send_futures.push(socket.send_to(chunk, dst)); - batch_count += 1; } if batch_count > 1 { trace!( batch_size = batch_count, - total_bytes = total_bytes, + total_bytes, queue_size = queue_len, "sending udp batch" ); @@ -339,34 +325,3 @@ async fn tx( } } } - -async fn fill_message_queues( - udp_egress_rx: &mut mpsc::Receiver, - priority_queues: &mut PriorityQueues, -) -> Result<(), ()> { - while priority_queues.is_empty() || !udp_egress_rx.is_empty() { - match udp_egress_rx.recv().await { - Some(udp_msg) => { - priority_queues.push(udp_msg); - } - None => return Err(()), - } - } - Ok(()) -} - -const MAX_AGGREGATED_WRITE_SIZE: u16 = 65535 - IPV4_HDR_SIZE - UDP_HDR_SIZE; -const MAX_AGGREGATED_SEGMENTS: u16 = 128; - -fn max_write_size_for_segment_size(segment_size: u16) -> u16 { - (MAX_AGGREGATED_WRITE_SIZE / segment_size).min(MAX_AGGREGATED_SEGMENTS) * segment_size -} - -// This is very very ugly, but there is no other way to figure this out. -fn is_eafnosupport(err: &Error) -> bool { - const EAFNOSUPPORT: &str = "Address family not supported by protocol"; - - let err = format!("{}", err); - - err.len() >= EAFNOSUPPORT.len() && &err[0..EAFNOSUPPORT.len()] == EAFNOSUPPORT -} diff --git a/monad-dataplane/tests/address_family_mismatch.rs b/monad-dataplane/tests/address_family_mismatch.rs index 6be8c01235..feb74ad864 100644 --- a/monad-dataplane/tests/address_family_mismatch.rs +++ b/monad-dataplane/tests/address_family_mismatch.rs @@ -21,6 +21,8 @@ use tracing::debug; /// 1_000 = 1 Gbps, 10_000 = 10 Gbps const UP_BANDWIDTH_MBPS: u64 = 1_000; +const LEGACY_SOCKET: &str = "legacy"; + const BIND_ADDRS: [&str; 2] = ["0.0.0.0:19100", "127.0.0.1:19101"]; const TX_ADDRS: [&str; 2] = ["127.0.0.1:19200", "[::1]:19201"]; @@ -40,22 +42,23 @@ fn address_family_mismatch() { })); for addr in BIND_ADDRS { - let dataplane = DataplaneBuilder::new(&addr.parse().unwrap(), UP_BANDWIDTH_MBPS).build(); + let mut dataplane = + DataplaneBuilder::new(&addr.parse().unwrap(), UP_BANDWIDTH_MBPS).build(); - // Allow Dataplane thread to set itself up. assert!(dataplane.block_until_ready(Duration::from_secs(1))); + let socket = dataplane.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + for tx_addr in TX_ADDRS { debug!("sending to {} from {}", tx_addr, addr); - dataplane.udp_write_broadcast(BroadcastMsg { + socket.write_broadcast(BroadcastMsg { targets: vec![tx_addr.parse().unwrap(); 1], payload: vec![0; DEFAULT_SEGMENT_SIZE.into()].into(), stride: DEFAULT_SEGMENT_SIZE, }); } - // Allow Dataplane thread to catch up. sleep(Duration::from_millis(10)); } } diff --git a/monad-dataplane/tests/tests.rs b/monad-dataplane/tests/tests.rs index aca026db08..0aa29c9730 100644 --- a/monad-dataplane/tests/tests.rs +++ b/monad-dataplane/tests/tests.rs @@ -34,9 +34,11 @@ use rand::Rng; use rstest::*; use tracing_subscriber::fmt::format::FmtSpan; -/// 1_000 = 1 Gbps, 10_000 = 10 Gbps const UP_BANDWIDTH_MBPS: u64 = 1_000; +const LEGACY_SOCKET: &str = "legacy"; +const DIRECT_SOCKET: &str = "direct"; + static ONCE_SETUP: Once = Once::new(); fn once_setup() { @@ -58,9 +60,8 @@ fn udp_broadcast() { let num_msgs = 10; let mut rx = DataplaneBuilder::new(&rx_addr, UP_BANDWIDTH_MBPS).build(); - let tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); - // Allow Dataplane threads to set themselves up. assert!(rx.block_until_ready(Duration::from_secs(1))); assert!(tx.block_until_ready(Duration::from_secs(1))); @@ -68,14 +69,17 @@ fn udp_broadcast() { .map(|_| rand::thread_rng().gen_range(0..255)) .collect(); - tx.udp_write_broadcast(BroadcastMsg { + let mut rx_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + + tx_socket.write_broadcast(BroadcastMsg { targets: vec![rx_addr; num_msgs], payload: payload.clone().into(), stride: DEFAULT_SEGMENT_SIZE, }); for _ in 0..num_msgs { - let msg: RecvUdpMsg = executor::block_on(rx.udp_read()); + let msg: RecvUdpMsg = executor::block_on(rx_socket.recv()); assert_eq!(msg.src_addr, tx_addr); assert_eq!(msg.payload, payload); @@ -92,9 +96,8 @@ fn udp_unicast() { let num_msgs = 10; let mut rx = DataplaneBuilder::new(&rx_addr, UP_BANDWIDTH_MBPS).build(); - let tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); - // Allow Dataplane threads to set themselves up. assert!(rx.block_until_ready(Duration::from_secs(1))); assert!(tx.block_until_ready(Duration::from_secs(1))); @@ -102,13 +105,16 @@ fn udp_unicast() { .map(|_| rand::thread_rng().gen_range(0..255)) .collect(); - tx.udp_write_unicast(UnicastMsg { + let mut rx_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + + tx_socket.write_unicast(UnicastMsg { msgs: vec![(rx_addr, payload.clone().into()); num_msgs], stride: DEFAULT_SEGMENT_SIZE, }); for _ in 0..num_msgs { - let msg: RecvUdpMsg = executor::block_on(rx.udp_read()); + let msg: RecvUdpMsg = executor::block_on(rx_socket.recv()); assert_eq!(msg.src_addr, tx_addr); assert_eq!(msg.payload, payload); @@ -120,17 +126,28 @@ fn udp_unicast() { fn udp_direct_socket() { once_setup(); - let rx_addr = "127.0.0.1:9030".parse().unwrap(); + let rx_addr: std::net::SocketAddr = "127.0.0.1:9030".parse().unwrap(); let rx_direct_port = 9031; - let tx_addr = "127.0.0.1:9032".parse().unwrap(); + let tx_addr: std::net::SocketAddr = "127.0.0.1:9032".parse().unwrap(); let tx_direct_port = 9033; let num_msgs = 10; + let mut rx_direct_addr = rx_addr; + rx_direct_addr.set_port(rx_direct_port); + let mut tx_direct_addr = tx_addr; + tx_direct_addr.set_port(tx_direct_port); + let mut rx = DataplaneBuilder::new(&rx_addr, UP_BANDWIDTH_MBPS) - .with_direct_socket(rx_direct_port) + .extend_udp_sockets(vec![monad_dataplane::UdpSocketConfig { + socket_addr: rx_direct_addr, + label: DIRECT_SOCKET.to_string(), + }]) .build(); - let tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS) - .with_direct_socket(tx_direct_port) + let mut tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS) + .extend_udp_sockets(vec![monad_dataplane::UdpSocketConfig { + socket_addr: tx_direct_addr, + label: DIRECT_SOCKET.to_string(), + }]) .build(); assert!(rx.block_until_ready(Duration::from_secs(2))); @@ -140,27 +157,29 @@ fn udp_direct_socket() { .map(|_| rand::thread_rng().gen_range(0..255)) .collect(); - tx.udp_write_broadcast(BroadcastMsg { + let mut rx_legacy_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + let mut rx_direct_socket = rx.take_udp_socket_handle(DIRECT_SOCKET).unwrap(); + let tx_legacy_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + let tx_direct_socket = tx.take_udp_socket_handle(DIRECT_SOCKET).unwrap(); + + tx_legacy_socket.write_broadcast(BroadcastMsg { targets: vec![rx_addr; num_msgs / 2], payload: payload.clone().into(), stride: DEFAULT_SEGMENT_SIZE, }); - let mut rx_direct_addr = rx_addr; - rx_direct_addr.set_port(rx_direct_port); - for _ in 0..num_msgs / 2 { - tx.udp_write_direct(rx_direct_addr, payload.clone().into(), DEFAULT_SEGMENT_SIZE); + tx_direct_socket.write(rx_direct_addr, payload.clone().into(), DEFAULT_SEGMENT_SIZE); } for _ in 0..num_msgs / 2 { - let msg: RecvUdpMsg = executor::block_on(rx.udp_read()); + let msg: RecvUdpMsg = executor::block_on(rx_legacy_socket.recv()); assert_eq!(msg.src_addr, tx_addr); assert_eq!(msg.payload, payload); } for _ in 0..num_msgs / 2 { - let msg: RecvUdpMsg = executor::block_on(rx.udp_direct_read()); + let msg: RecvUdpMsg = executor::block_on(rx_direct_socket.recv()); assert_eq!(msg.src_addr.ip(), tx_addr.ip()); assert_eq!(msg.payload, payload); } @@ -504,9 +523,8 @@ fn broadcast_all_strides() { let mut rx = DataplaneBuilder::new(&rx_addr, UP_BANDWIDTH_MBPS) .with_udp_buffer_size(400 << 10) .build(); - let tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); - // Allow Dataplane threads to set themselves up. assert!(rx.block_until_ready(Duration::from_secs(1))); assert!(tx.block_until_ready(Duration::from_secs(1))); @@ -516,8 +534,11 @@ fn broadcast_all_strides() { .map(|_| rand::thread_rng().gen_range(0..255)) .collect(); + let mut rx_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + for stride in MINIMUM_SEGMENT_SIZE..=DEFAULT_SEGMENT_SIZE { - tx.udp_write_broadcast(BroadcastMsg { + tx_socket.write_broadcast(BroadcastMsg { targets: vec![rx_addr], payload: payload.clone().into(), stride, @@ -528,7 +549,7 @@ fn broadcast_all_strides() { let num_msgs = total_length.div_ceil(stride); for i in 0..num_msgs { - let msg: RecvUdpMsg = executor::block_on(rx.udp_read()); + let msg: RecvUdpMsg = executor::block_on(rx_socket.recv()); assert_eq!(msg.src_addr, tx_addr); assert_eq!( @@ -550,9 +571,8 @@ fn unicast_all_strides() { let mut rx = DataplaneBuilder::new(&rx_addr, UP_BANDWIDTH_MBPS) .with_udp_buffer_size(400 << 10) .build(); - let tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); - // Allow Dataplane threads to set themselves up. assert!(rx.block_until_ready(Duration::from_secs(1))); assert!(tx.block_until_ready(Duration::from_secs(1))); @@ -562,8 +582,11 @@ fn unicast_all_strides() { .map(|_| rand::thread_rng().gen_range(0..255)) .collect(); + let mut rx_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + for stride in MINIMUM_SEGMENT_SIZE..=DEFAULT_SEGMENT_SIZE { - tx.udp_write_unicast(UnicastMsg { + tx_socket.write_unicast(UnicastMsg { msgs: vec![(rx_addr, payload.clone().into())], stride, }); @@ -573,7 +596,7 @@ fn unicast_all_strides() { let num_msgs = total_length.div_ceil(stride); for i in 0..num_msgs { - let msg: RecvUdpMsg = executor::block_on(rx.udp_read()); + let msg: RecvUdpMsg = executor::block_on(rx_socket.recv()); assert_eq!(msg.src_addr, tx_addr); assert_eq!( @@ -793,14 +816,16 @@ fn udp_large_stride() { .set_read_timeout(Some(Duration::from_secs(1))) .unwrap(); - let tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, UP_BANDWIDTH_MBPS).build(); assert!(tx.block_until_ready(Duration::from_secs(1))); let payload: Vec = (0..65536) .map(|_| rand::thread_rng().gen_range(0..255)) .collect(); - tx.udp_write_broadcast(BroadcastMsg { + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + + tx_socket.write_broadcast(BroadcastMsg { targets: vec![rx_addr], payload: payload.clone().into(), stride: u16::MAX, @@ -834,7 +859,7 @@ fn udp_priority_delivery() { let low_bandwidth_mbps = 10; let mut rx = DataplaneBuilder::new(&rx_addr, low_bandwidth_mbps).build(); - let tx = DataplaneBuilder::new(&tx_addr, low_bandwidth_mbps).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, low_bandwidth_mbps).build(); assert!(rx.block_until_ready(Duration::from_secs(1))); assert!(tx.block_until_ready(Duration::from_secs(1))); @@ -846,10 +871,11 @@ fn udp_priority_delivery() { let expected_total_msgs = 2 * message_size.div_ceil(DEFAULT_SEGMENT_SIZE as usize); let (msg_tx, msg_rx) = mpsc::channel(); + let mut rx_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); let rx_handle = thread::spawn(move || { let mut messages = Vec::new(); loop { - let msg = executor::block_on(rx.udp_read()); + let msg = executor::block_on(rx_socket.recv()); messages.push(msg); if messages.len() == expected_total_msgs { msg_tx.send(messages).unwrap(); @@ -858,7 +884,8 @@ fn udp_priority_delivery() { } }); - tx.udp_write_unicast_with_priority( + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + tx_socket.write_unicast_with_priority( UnicastMsg { msgs: vec![(rx_addr, high_priority_data.into())], stride: DEFAULT_SEGMENT_SIZE, @@ -866,7 +893,7 @@ fn udp_priority_delivery() { UdpPriority::High, ); - tx.udp_write_unicast_with_priority( + tx_socket.write_unicast_with_priority( UnicastMsg { msgs: vec![(rx_addr, regular_priority_data.into())], stride: DEFAULT_SEGMENT_SIZE, @@ -923,7 +950,7 @@ fn udp_priority_with_regular_then_high_traffic() { let low_bandwidth_mbps = 10; let mut rx = DataplaneBuilder::new(&rx_addr, low_bandwidth_mbps).build(); - let tx = DataplaneBuilder::new(&tx_addr, low_bandwidth_mbps).build(); + let mut tx = DataplaneBuilder::new(&tx_addr, low_bandwidth_mbps).build(); assert!(rx.block_until_ready(Duration::from_secs(1))); assert!(tx.block_until_ready(Duration::from_secs(1))); @@ -936,10 +963,11 @@ fn udp_priority_with_regular_then_high_traffic() { let expected_total_msgs = 2 * num_msgs_per_mb; let (msg_tx, msg_rx) = mpsc::channel(); + let mut rx_socket = rx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); let rx_handle = thread::spawn(move || { let mut messages = Vec::new(); loop { - let msg = executor::block_on(rx.udp_read()); + let msg = executor::block_on(rx_socket.recv()); messages.push(msg); if messages.len() == expected_total_msgs { msg_tx.send(messages).unwrap(); @@ -948,7 +976,8 @@ fn udp_priority_with_regular_then_high_traffic() { } }); - tx.udp_write_unicast_with_priority( + let tx_socket = tx.take_udp_socket_handle(LEGACY_SOCKET).unwrap(); + tx_socket.write_unicast_with_priority( UnicastMsg { msgs: vec![(rx_addr, regular_priority_data.into())], stride: DEFAULT_SEGMENT_SIZE, @@ -958,7 +987,7 @@ fn udp_priority_with_regular_then_high_traffic() { thread::sleep(Duration::from_millis(50)); - tx.udp_write_unicast_with_priority( + tx_socket.write_unicast_with_priority( UnicastMsg { msgs: vec![(rx_addr, high_priority_data.into())], stride: DEFAULT_SEGMENT_SIZE, diff --git a/monad-raptorcast/src/lib.rs b/monad-raptorcast/src/lib.rs index f6708cd571..746df3f8d7 100644 --- a/monad-raptorcast/src/lib.rs +++ b/monad-raptorcast/src/lib.rs @@ -40,6 +40,7 @@ use monad_crypto::{ use monad_dataplane::{ udp::{segment_size_for_mtu, DEFAULT_MTU}, BroadcastMsg, DataplaneBuilder, DataplaneReader, DataplaneWriter, RecvTcpMsg, TcpMsg, + UdpSocketHandle, UnicastMsg, }; use monad_executor::{Executor, ExecutorMetrics, ExecutorMetricsChain}; use monad_executor_glue::{ @@ -70,6 +71,8 @@ pub mod util; const SIGNATURE_SIZE: usize = 65; +const RAPTORCAST_SOCKET: &str = "raptorcast"; + pub(crate) type OwnedMessageBuilder = packet::MessageBuilder<'static, ST, Arc>>>; @@ -83,7 +86,6 @@ where signing_key: Arc, is_dynamic_fullnode: bool, - // Raptorcast group with stake information. For the send side (i.e., initiating proposals) epoch_validators: BTreeMap>, rebroadcast_map: ReBroadcastGroupMap, @@ -97,6 +99,7 @@ where dataplane_reader: DataplaneReader, dataplane_writer: DataplaneWriter, + udp_socket: UdpSocketHandle, pending_events: VecDeque>, channel_to_secondary: Option>>, @@ -131,6 +134,7 @@ where secondary_mode: SecondaryRaptorCastModeConfig, dataplane_reader: DataplaneReader, dataplane_writer: DataplaneWriter, + udp_socket: UdpSocketHandle, peer_discovery_driver: Arc>>, current_epoch: Epoch, ) -> Self { @@ -173,6 +177,7 @@ where dataplane_reader, dataplane_writer, + udp_socket, pending_events: Default::default(), channel_to_secondary: None, channel_from_secondary: None, @@ -326,8 +331,7 @@ where .epoch_no(epoch) .build_unicast_msg(&outbound_message, &build_target) { - self.dataplane_writer - .udp_write_unicast_with_priority(rc_chunks, priority); + self.udp_socket.write_unicast_with_priority(rc_chunks, priority); } } @@ -356,8 +360,7 @@ where .message_builder .build_unicast_msg(&outbound_message, &build_target) { - self.dataplane_writer - .udp_write_unicast_with_priority(rc_chunks, priority); + self.udp_socket.write_unicast_with_priority(rc_chunks, priority); } } } @@ -399,9 +402,17 @@ where ..Default::default() }; let up_bandwidth_mbps = 1_000; - let dp = DataplaneBuilder::new(&local_addr, up_bandwidth_mbps).build(); + let mut dp = DataplaneBuilder::new(&local_addr, up_bandwidth_mbps) + .extend_udp_sockets(vec![monad_dataplane::UdpSocketConfig { + socket_addr: local_addr, + label: RAPTORCAST_SOCKET.to_string(), + }]) + .build(); assert!(dp.block_until_ready(Duration::from_secs(1))); - let (dp_reader, dp_writer) = dp.split(); + let udp_socket = dp + .take_udp_socket_handle(RAPTORCAST_SOCKET) + .expect("raptorcast socket"); + let (dp_reader, dp_writer, _udp_dataplane) = dp.split(); let config = config::RaptorCastConfig { shared_key, mtu: DEFAULT_MTU, @@ -431,6 +442,7 @@ where SecondaryRaptorCastModeConfig::None, dp_reader, dp_writer, + udp_socket, shared_pd, Epoch(0), ) @@ -576,7 +588,7 @@ where .epoch_no(epoch) .build_unicast_msg(&outbound_message, &build_target) { - self.dataplane_writer.udp_write_unicast(rc_chunks); + self.udp_socket.write_unicast(rc_chunks); }; } } @@ -685,8 +697,7 @@ where } loop { - let dataplane = &mut this.dataplane_reader; - let Poll::Ready(message) = pin!(dataplane.udp_read()).poll_unpin(cx) else { + let Poll::Ready(message) = pin!(this.udp_socket.recv()).poll_unpin(cx) else { break; }; @@ -717,14 +728,11 @@ where }) .collect(); - this.dataplane_writer.udp_write_broadcast_with_priority( - BroadcastMsg { - targets: target_addrs, - payload, - stride: bcast_stride, - }, - UdpPriority::High, - ); + this.udp_socket.write_broadcast(BroadcastMsg { + targets: target_addrs, + payload, + stride: bcast_stride, + }); }, message, ) @@ -909,7 +917,7 @@ where }; if let Some(rc_chunks) = rc_chunks { - this.dataplane_writer.udp_write_unicast(rc_chunks); + this.udp_socket.write_unicast(rc_chunks); }; }; diff --git a/monad-raptorcast/src/raptorcast_secondary/mod.rs b/monad-raptorcast/src/raptorcast_secondary/mod.rs index d909e5e559..bdbd5ec299 100644 --- a/monad-raptorcast/src/raptorcast_secondary/mod.rs +++ b/monad-raptorcast/src/raptorcast_secondary/mod.rs @@ -33,7 +33,7 @@ use group_message::FullNodesGroupMessage; use monad_crypto::certificate_signature::{ CertificateKeyPair, CertificateSignaturePubKey, CertificateSignatureRecoverable, }; -use monad_dataplane::{udp::segment_size_for_mtu, DataplaneWriter}; +use monad_dataplane::{udp::segment_size_for_mtu, DataplaneWriter, UdpSocketHandle, UnicastMsg}; use monad_executor::{Executor, ExecutorMetrics, ExecutorMetricsChain}; use monad_executor_glue::{Message, PeerEntry, RouterCommand}; use monad_peer_discovery::{driver::PeerDiscoveryDriver, PeerDiscoveryAlgo, PeerDiscoveryEvent}; @@ -84,6 +84,7 @@ where curr_epoch: Epoch, dataplane_writer: DataplaneWriter, + udp_socket: UdpSocketHandle, peer_discovery_driver: Arc>>, message_builder: OwnedMessageBuilder, @@ -104,6 +105,7 @@ where config: RaptorCastConfig, secondary_mode: SecondaryRaptorCastMode, dataplane_writer: DataplaneWriter, + udp_socket: UdpSocketHandle, peer_discovery_driver: Arc>>, channel_from_primary: UnboundedReceiver>, channel_to_primary: UnboundedSender>, @@ -157,6 +159,7 @@ where curr_epoch: current_epoch, message_builder, dataplane_writer, + udp_socket, peer_discovery_driver, channel_from_primary, metrics: Default::default(), @@ -189,7 +192,7 @@ where .message_builder .build_unicast_msg(&msg_bytes, &build_target) { - self.dataplane_writer.udp_write_unicast(rc_chunks); + self.udp_socket.write_unicast(rc_chunks); } } @@ -398,7 +401,7 @@ where .build_unicast_msg(&outbound_message, &build_target) { // Send the raptorcast chunks via UDP to all peers in group - self.dataplane_writer.udp_write_unicast(rc_chunks); + self.udp_socket.write_unicast(rc_chunks); } } }