diff --git a/msg-socket/src/lib.rs b/msg-socket/src/lib.rs
index 5332911e..02c322e3 100644
--- a/msg-socket/src/lib.rs
+++ b/msg-socket/src/lib.rs
@@ -34,7 +34,10 @@ mod connection;
 pub use connection::*;
 
 /// The default buffer size for a socket.
-const DEFAULT_BUFFER_SIZE: usize = 8192;
+pub const DEFAULT_BUFFER_SIZE: usize = 8192;
+
+/// The default queue size for a channel.
+pub const DEFAULT_QUEUE_SIZE: usize = 8192;
 
 /// A request identifier.
 pub struct RequestId(u32);
diff --git a/msg-socket/src/rep/driver.rs b/msg-socket/src/rep/driver.rs
index 13a57adb..cb7da6a6 100644
--- a/msg-socket/src/rep/driver.rs
+++ b/msg-socket/src/rep/driver.rs
@@ -1,5 +1,4 @@
 use std::{
-    collections::VecDeque,
     io,
     pin::Pin,
     sync::Arc,
@@ -56,13 +55,45 @@ pub(crate) struct PeerState {
     write_buffer_size: usize,
     /// The address of the peer.
     addr: A,
-    egress_queue: VecDeque<WithSpan<reqrep::Message>>,
+    /// Single pending outgoing message waiting to be sent.
+    pending_egress: Option<WithSpan<reqrep::Message>>,
+    /// High-water mark for pending responses. When reached, new requests are blocked
+    /// (backpressure is applied) until responses drain below the limit.
+    max_pending_responses: usize,
     state: Arc<SocketState>,
     /// The optional message compressor.
     compressor: Option<Arc<dyn Compressor>>,
     span: tracing::Span,
 }
 
+/// The driver behind a `Rep` socket.
+///
+/// # Driver Event Loop
+/// The event loop of the driver will try to do as much work as possible in a given call, also
+/// accounting for work that might become available as a result of the driver running.
+///
+/// There's an implicit priority based on the order in which the tasks are polled: after an inner
+/// task has completed some work, it restarts the loop, allowing earlier tasks to work again.
+///
+/// Currently, this driver uses the following "tasks":
+/// 1. Connected peers (created in tasks 2 and 3)
+/// 2. Authentication tasks for connecting peers (futures created by task 3)
+/// 3. Incoming new connections for connecting peers (futures created by task 5)
+/// 4. Process control signals for the underlying transport (doesn't restart the loop)
+/// 5. Incoming connections from the underlying transport
+///
+/// ```text
+/// (5) Transport ────> (3) conn_tasks ────> (2) auth_tasks
+///          │                                     │
+///          │ (no auth)                           │
+///          v                                     v
+///        ┌─────────────────────────────┐
+///        │       (1) peer_states       │
+///        └─────────────────────────────┘
+///
+/// (4) control_rx ──on_control──> Transport
+/// ```
 #[allow(clippy::type_complexity)]
 pub(crate) struct RepDriver<T: Transport<A>, A: Address> {
     /// The server transport used to accept incoming connections.
@@ -103,6 +134,7 @@ where
         let this = self.get_mut();
 
         loop {
+            // Check connected peers, handle disconnections and incoming requests
             if let Poll::Ready(Some((peer, maybe_result))) = this.peer_states.poll_next_unpin(cx) {
                 let Some(result) = maybe_result.enter() else {
                     debug!(?peer, "peer disconnected");
@@ -141,6 +173,8 @@ where
                 continue;
             }
 
+            // Drive authentication tasks; when authentication succeeds, we register the peer as
+            // connected
            if let Poll::Ready(Some(Ok(auth))) = this.auth_tasks.poll_join_next(cx).enter() {
                match auth.inner {
                    Ok(auth) => {
@@ -166,7 +200,8 @@ where
                             linger_timer,
                             write_buffer_size: this.options.write_buffer_size,
                             addr: auth.addr,
-                            egress_queue: VecDeque::with_capacity(128),
+                            pending_egress: None,
+                            max_pending_responses: this.options.max_pending_responses,
                             state: Arc::clone(&this.state),
                             compressor: this.compressor.clone(),
                         }),
@@ -181,6 +216,7 @@ where
                 continue;
             }
 
+            // Drive the tasks accepting incoming connections
            if let Poll::Ready(Some(conn)) = this.conn_tasks.poll_next_unpin(cx).enter() {
                match conn.inner {
                    Ok(io) => {
@@ -192,8 +228,8 @@ where
                     Err(e) => {
                         debug!(?e, "failed to accept incoming connection");
 
-                        // Active clients have already been incremented in the initial call to
-                        // `poll_accept`, so we need to decrement them here.
+                        // Active clients have already been incremented when accepting them from the
+                        // underlying transport, so we need to decrement them here.
                         this.state.stats.specific.decrement_active_clients();
                     }
                 }
@@ -201,6 +237,7 @@ where
                 continue;
             }
 
+            // Process control signals for the underlying transport
            if let Poll::Ready(Some(cmd)) = this.control_rx.poll_recv(cx) {
                this.transport.on_control(cmd);
            }
@@ -210,19 +247,20 @@ where
             if let Poll::Ready(accept) = Pin::new(&mut this.transport).poll_accept(cx) {
                 let span = this.span.clone().entered();
 
-                if let Some(max) = this.options.max_clients {
-                    if this.state.stats.specific.active_clients() >= max {
-                        warn!(
-                            limit = max,
-                            "max connections reached, rejecting new incoming connection",
-                        );
+                // Reject incoming connections if we have `max_clients` active clients already
+                let active_clients = this.state.stats.specific.active_clients();
+                if this.options.max_clients.is_some_and(|max| active_clients >= max) {
+                    warn!(
+                        active_clients,
+                        "max connections reached, rejecting new incoming connection",
+                    );
 
-                        continue;
-                    }
+                    continue;
                 }
 
-                // Increment the active clients counter. If the authentication fails, this counter
-                // will be decremented.
+                // Increment the active clients counter.
+                // IMPORTANT: decrement the active clients counter when the connection fails or is
+                // closed.
                 this.state.stats.specific.increment_active_clients();
 
                 this.conn_tasks.push(accept.with_span(span));
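The poll loop above follows the prioritized-restart pattern described in the `RepDriver` doc comment. A minimal, self-contained sketch of that pattern, with hypothetical event sources standing in for `peer_states`, `auth_tasks`, and `conn_tasks`:

```rust
use std::task::{Context, Poll};

/// Hypothetical driver with two prioritized event sources.
struct Driver {
    high_priority: Vec<u32>, // e.g. messages from connected peers
    low_priority: Vec<u32>,  // e.g. newly accepted connections
}

impl Driver {
    /// Sources are checked in priority order; any completed work `continue`s back
    /// to the top so higher-priority sources are always drained first. Only when
    /// no source makes progress (and all wakers are registered) do we return
    /// `Poll::Pending`.
    fn poll_drive(&mut self, _cx: &mut Context<'_>) -> Poll<()> {
        loop {
            if let Some(event) = self.high_priority.pop() {
                let _ = event; // ... handle the event ...
                continue; // restart: more high-priority work may now exist
            }
            if let Some(event) = self.low_priority.pop() {
                let _ = event; // handling this may create high-priority work,
                continue; // so restart instead of polling further sources
            }
            return Poll::Pending;
        }
    }
}
```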
@@ -240,8 +278,12 @@ where
     T: Transport<A>,
     A: Address,
 {
-    /// Handles an accepted connection. If this returns an error, the active connections counter
+    /// Handles an accepted connection.
+    ///
+    /// If this returns an error, the active connections counter
     /// should be decremented.
+    ///
+    /// Schedules an authentication task if `self.auth` is set.
     fn on_accepted_connection(&mut self, io: T::Io) -> Result<(), io::Error> {
         let addr = io.peer_addr()?;
         info!(?addr, "new connection");
@@ -253,6 +295,7 @@ where
         });
 
         let Some(ref auth) = self.auth else {
+            // Create the peer without authenticating
             self.peer_states.insert(
                 addr.clone(),
                 StreamNotifyClose::new(PeerState {
@@ -262,7 +305,8 @@ where
                     linger_timer,
                     write_buffer_size: self.options.write_buffer_size,
                     addr,
-                    egress_queue: VecDeque::with_capacity(128),
+                    pending_egress: None,
+                    max_pending_responses: self.options.max_pending_responses,
                     state: Arc::clone(&self.state),
                     compressor: self.compressor.clone(),
                 }),
@@ -313,24 +357,24 @@ where
 }
 
 impl PeerState {
-    /// Prepares for shutting down by sending and flushing all messages in [`Self::egress_queue`].
+    /// Prepares for shutting down by sending and flushing all pending messages.
     /// When [`Poll::Ready`] is returned, the connection with this peer can be shut down.
     ///
     /// TODO: there might be some [`Self::pending_requests`] yet to be processed. TBD how to handle
     /// them; for now, they're dropped.
     fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<()> {
-        let messages = std::mem::take(&mut self.egress_queue);
+        let pending_msg = self.pending_egress.take();
         let buffer_size = self.conn.write_buffer().len();
 
-        if messages.is_empty() && buffer_size == 0 {
+        if pending_msg.is_none() && buffer_size == 0 {
             debug!("flushed everything, closing connection");
             return Poll::Ready(());
         }
 
-        debug!(messages = ?messages.len(), write_buffer_size = ?buffer_size, "found data to send");
+        debug!(has_pending = ?pending_msg.is_some(), write_buffer_size = ?buffer_size, "found data to send");
 
-        for msg in messages {
+        if let Some(msg) = pending_msg {
             if let Err(e) = self.conn.start_send_unpin(msg.inner) {
-                error!(?e, "failed to send final messages to socket, closing");
+                error!(?e, "failed to send final message to socket, closing");
                 return Poll::Ready(());
             }
         }
@@ -351,15 +395,15 @@ impl Stream for PeerState
         let this = self.get_mut();
 
         loop {
-            let mut progress = false;
-            if let Some(msg) = this.egress_queue.pop_front().enter() {
+            // First, try to send the pending egress message if we have one.
+            if let Some(msg) = this.pending_egress.take().enter() {
                 let msg_len = msg.size();
+                debug!(msg_id = msg.id(), "sending response");
                 match this.conn.start_send_unpin(msg.inner) {
                     Ok(_) => {
                         this.state.stats.specific.increment_tx(msg_len);
-
-                        // We might be able to send more queued messages
-                        progress = true;
+                        // Continue to potentially send more or flush
+                        continue;
                     }
                     Err(e) => {
                         this.state.stats.specific.increment_failed_requests();
@@ -392,68 +436,51 @@ impl Stream for PeerState
                 }
             }
 
-            // Then, try to drain the egress queue.
-            if this.conn.poll_ready_unpin(cx).is_ready() {
-                if let Some(msg) = this.egress_queue.pop_front().enter() {
-                    let msg_len = msg.size();
-
-                    debug!(msg_id = msg.id(), "sending response");
-                    match this.conn.start_send_unpin(msg.inner) {
-                        Ok(_) => {
-                            this.state.stats.specific.increment_tx(msg_len);
-
-                            // We might be able to send more queued messages
-                            continue;
-                        }
-                        Err(e) => {
-                            this.state.stats.specific.increment_failed_requests();
-                            error!(?e, "failed to send message to socket");
-                            // End this stream as we can't send any more messages
-                            return Poll::Ready(None);
-                        }
-                    }
-                }
-            }
-
-            // Then we check for completed requests, and push them onto the egress queue.
-            if let Poll::Ready(Some(result)) = this.pending_requests.poll_next_unpin(cx).enter() {
-                match result.inner {
-                    Err(_) => tracing::error!("response channel closed unexpectedly"),
-                    Ok(Response { msg_id, mut response }) => {
-                        let mut compression_type = 0;
-                        let len_before = response.len();
-                        if let Some(ref compressor) = this.compressor {
-                            match compressor.compress(&response) {
-                                Ok(compressed) => {
-                                    response = compressed;
-                                    compression_type = compressor.compression_type() as u8;
-                                }
-                                Err(e) => {
-                                    error!(?e, "failed to compress message");
-                                    continue;
-                                }
-                            }
-
-                            debug!(
-                                msg_id,
-                                len_before,
-                                len_after = response.len(),
-                                "compressed message"
-                            )
-                        }
-
-                        debug!(msg_id, "received response to send");
-
-                        let msg = reqrep::Message::new(msg_id, compression_type, response);
-                        this.egress_queue.push_back(msg.with_span(result.span));
-
-                        continue;
-                    }
-                }
-            }
+            // Check for completed requests, and set `pending_egress` (only if it is empty).
+            if this.pending_egress.is_none() {
+                if let Poll::Ready(Some(result)) = this.pending_requests.poll_next_unpin(cx).enter()
+                {
+                    match result.inner {
+                        Err(_) => tracing::error!("response channel closed unexpectedly"),
+                        Ok(Response { msg_id, mut response }) => {
+                            let mut compression_type = 0;
+                            let len_before = response.len();
+                            if let Some(ref compressor) = this.compressor {
+                                match compressor.compress(&response) {
+                                    Ok(compressed) => {
+                                        response = compressed;
+                                        compression_type = compressor.compression_type() as u8;
+                                    }
+                                    Err(e) => {
+                                        error!(?e, "failed to compress message");
+                                        continue;
+                                    }
+                                }
+
+                                debug!(
+                                    msg_id,
+                                    len_before,
+                                    len_after = response.len(),
+                                    "compressed message"
+                                )
+                            }
+
+                            debug!(msg_id, "received response to send");
+
+                            let msg = reqrep::Message::new(msg_id, compression_type, response);
+                            this.pending_egress = Some(msg.with_span(result.span));
+
+                            continue;
+                        }
+                    }
+                }
+            }
 
-            // Finally we accept incoming requests from the peer.
-            {
+            // Accept incoming requests from the peer.
+            // Only accept new requests if we're under the HWM for pending responses.
+            let under_hwm = this.pending_requests.len() < this.max_pending_responses;
+
+            if under_hwm {
                 let _g = this.span.clone().entered();
                 match this.conn.poll_next_unpin(cx) {
                     Poll::Ready(Some(result)) => {
@@ -504,10 +531,15 @@ impl Stream for PeerState
                     }
                     Poll::Pending => {}
                 }
-            }
-
-            if progress {
-                continue;
+            } else {
+                // At HWM: not polling the underlying connection until responses drain.
+                // The waker is registered on pending_requests, so we'll wake when responses
+                // complete.
+                trace!(
+                    hwm = this.max_pending_responses,
+                    pending = this.pending_requests.len(),
+                    "at high-water mark, not polling from underlying connection until responses drain"
+                );
             }
 
             return Poll::Pending;
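The `else` branch above is the heart of the new backpressure behavior: at the high-water mark the driver simply stops polling the connection, so unread requests stay in the transport's buffers and flow control pushes back on the peer. A hedged sketch of the same gating idea in isolation (`Gated`, `inner`, and `in_flight` are illustrative names, not part of this PR):

```rust
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::{Stream, StreamExt, ready};

/// Wrapper that stops reading from `inner` once `in_flight` (requests handed
/// out but not yet answered) reaches the high-water mark.
struct Gated<S> {
    inner: S,
    in_flight: usize,
    hwm: usize,
}

impl<S: Stream<Item = u64> + Unpin> Stream for Gated<S> {
    type Item = u64;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        let this = self.get_mut();
        if this.in_flight >= this.hwm {
            // Don't poll `inner`; like the driver above, we rely on the response
            // path to wake the task once `in_flight` drops below the limit.
            return Poll::Pending;
        }
        let item = ready!(this.inner.poll_next_unpin(cx));
        if item.is_some() {
            this.in_flight += 1; // a response handler would decrement this
        }
        Poll::Ready(item)
    }
}
```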
diff --git a/msg-socket/src/rep/mod.rs b/msg-socket/src/rep/mod.rs
index d00ab28c..6d731737 100644
--- a/msg-socket/src/rep/mod.rs
+++ b/msg-socket/src/rep/mod.rs
@@ -8,12 +8,12 @@ use tokio::sync::oneshot;
 mod driver;
 mod socket;
-mod stats;
-use crate::{Profile, stats::SocketStats};
 pub use socket::*;
+
+mod stats;
 use stats::RepStats;
 
-const DEFAULT_MIN_COMPRESS_SIZE: usize = 8192;
+use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, Profile, stats::SocketStats};
 
 /// Errors that can occur when using a reply socket.
 #[derive(Debug, Error)]
@@ -44,19 +44,30 @@ impl RepError {
 /// The reply socket options.
 pub struct RepOptions {
     /// The maximum number of concurrent clients.
-    max_clients: Option<usize>,
-    min_compress_size: usize,
-    write_buffer_size: usize,
-    write_buffer_linger: Option<Duration>,
+    pub(crate) max_clients: Option<usize>,
+    /// Minimum payload size in bytes for compression to be used.
+    ///
+    /// If the payload is smaller than this threshold, it will not be compressed.
+    pub(crate) min_compress_size: usize,
+    /// The size of the write buffer in bytes.
+    pub(crate) write_buffer_size: usize,
+    /// The maximum duration between flushes to the underlying transport.
+    pub(crate) write_buffer_linger: Option<Duration>,
+    /// High-water mark for pending responses per peer.
+    ///
+    /// When this limit is reached, new requests will not be read from the underlying connection
+    /// until pending responses are fulfilled.
+    pub(crate) max_pending_responses: usize,
 }
 
 impl Default for RepOptions {
     fn default() -> Self {
         Self {
             max_clients: None,
-            min_compress_size: DEFAULT_MIN_COMPRESS_SIZE,
-            write_buffer_size: 8192,
+            min_compress_size: DEFAULT_BUFFER_SIZE,
+            write_buffer_size: DEFAULT_BUFFER_SIZE,
             write_buffer_linger: Some(Duration::from_micros(100)),
+            max_pending_responses: DEFAULT_QUEUE_SIZE,
         }
     }
 }
@@ -108,6 +119,8 @@ impl RepOptions {
     /// Sets the minimum payload size for compression.
     /// If the payload is smaller than this value, it will not be compressed.
+    ///
+    /// Default: [`DEFAULT_BUFFER_SIZE`]
     pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
         self.min_compress_size = min_compress_size;
         self
@@ -116,7 +129,7 @@ impl RepOptions {
     /// Sets the size (max capacity) of the write buffer in bytes. When the buffer is full, it will
     /// be flushed to the underlying transport.
     ///
-    /// Default: 8KiB
+    /// Default: [`DEFAULT_BUFFER_SIZE`]
     pub fn with_write_buffer_size(mut self, size: usize) -> Self {
         self.write_buffer_size = size;
         self
@@ -130,6 +143,16 @@ impl RepOptions {
         self.write_buffer_linger = duration;
         self
     }
+
+    /// Sets the high-water mark for pending responses per peer. When this limit is reached,
+    /// new requests will not be read from the underlying connection until pending
+    /// responses are fulfilled.
+    ///
+    /// Default: [`DEFAULT_QUEUE_SIZE`]
+    pub fn with_max_pending_responses(mut self, hwm: usize) -> Self {
+        self.max_pending_responses = hwm;
+        self
+    }
 }
 
 /// The reply socket state, shared between the backend task and the socket.
diff --git a/msg-socket/src/rep/socket.rs b/msg-socket/src/rep/socket.rs
index 82eca03a..50dca9a6 100644
--- a/msg-socket/src/rep/socket.rs
+++ b/msg-socket/src/rep/socket.rs
@@ -16,7 +16,7 @@ use tokio_stream::StreamMap;
 use tracing::{debug, warn};
 
 use crate::{
-    Authenticator, DEFAULT_BUFFER_SIZE, RepOptions, Request,
+    Authenticator, DEFAULT_QUEUE_SIZE, RepOptions, Request,
     rep::{RepError, SocketState, driver::RepDriver},
 };
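For reference, a usage sketch combining the new option with the existing builders. `RepSocket::with_options` is assumed here by analogy with `ReqSocket::with_options` used in the test later in this diff; the values are illustrative:

```rust
use msg_socket::{RepOptions, RepSocket};
use msg_transport::tcp::Tcp;

#[tokio::main]
async fn main() {
    let options = RepOptions::default()
        // flush to the transport once 8 KiB of responses are buffered
        .with_write_buffer_size(8192)
        // stop reading new requests from a peer once 1024 responses are pending
        .with_max_pending_responses(1024);

    let mut rep = RepSocket::with_options(Tcp::default(), options);
    rep.bind("0.0.0.0:0").await.unwrap();
}
```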
@@ -110,8 +110,8 @@ where
     /// Binds the socket to the given address. This spawns the socket driver task.
     pub async fn try_bind(&mut self, addresses: Vec<A>) -> Result<(), RepError> {
-        let (to_socket, from_backend) = mpsc::channel(DEFAULT_BUFFER_SIZE);
-        let (control_tx, control_rx) = mpsc::channel(DEFAULT_BUFFER_SIZE);
+        let (to_socket, from_backend) = mpsc::channel(DEFAULT_QUEUE_SIZE);
+        let (control_tx, control_rx) = mpsc::channel(DEFAULT_QUEUE_SIZE);
 
         let mut transport = self.transport.take().expect("transport has been moved already");
diff --git a/msg-socket/src/req/driver.rs b/msg-socket/src/req/driver.rs
index a87bd7a6..259fffc1 100644
--- a/msg-socket/src/req/driver.rs
+++ b/msg-socket/src/req/driver.rs
@@ -1,5 +1,4 @@
 use std::{
-    collections::VecDeque,
     pin::Pin,
     sync::Arc,
     task::{Context, Poll, ready},
@@ -14,11 +13,8 @@ use tokio::{
     time::Interval,
 };
 
-use super::{ReqError, ReqOptions};
-use crate::{
-    SendCommand,
-    req::{SocketState, conn_manager::ConnManager},
-};
+use super::{ReqError, ReqOptions, SendCommand};
+use crate::req::{SocketState, conn_manager::ConnManager};
 
 use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan};
 use msg_transport::{Address, Transport};
 
@@ -42,8 +38,8 @@ pub(crate) struct ReqDriver<T: Transport<A>, A: Address> {
     pub(crate) conn_manager: ConnManager<T, A>,
     /// The timer for the write buffer linger.
     pub(crate) linger_timer: Option<Interval>,
-    /// The outgoing message queue.
-    pub(crate) egress_queue: VecDeque<WithSpan<reqrep::Message>>,
+    /// The single pending outgoing message waiting to be sent.
+    pub(crate) pending_egress: Option<WithSpan<reqrep::Message>>,
     /// The currently pending requests waiting for a response.
     pub(crate) pending_requests: FxHashMap<u32, WithSpan<PendingRequest>>,
     /// Interval for checking for request timeouts.
@@ -135,7 +131,7 @@ where
         let msg = message.inner.into_wire(self.id_counter);
         let msg_id = msg.id();
         self.id_counter = self.id_counter.wrapping_add(1);
-        self.egress_queue.push_back(msg.with_span(span.clone()));
+        self.pending_egress = Some(msg.with_span(span.clone()));
         self.pending_requests
             .insert(msg_id, PendingRequest { start, sender: response }.with_span(span));
     }
@@ -215,15 +211,14 @@ where
                 Poll::Pending => {}
             }
 
-            // NOTE: We try to drain the egress queue first (the `continue`), writing everything to
-            // the `Framed` internal buffer. When all messages are written, we move on to flushing
-            // the connection in the block below. We DO NOT rely on the `Framed` internal
-            // backpressure boundary, because we do not call `poll_ready`.
-            if let Some(msg) = this.egress_queue.pop_front().enter() {
-                // Generate the new message
+            // Try to send the pending egress message if we have one.
+            // We only hold a single pending message here; the channel serves as the actual queue.
+            // This pattern ensures we respect backpressure and don't accumulate unbounded messages.
+            if let Some(msg) = this.pending_egress.take().enter() {
                 let size = msg.size();
                 tracing::debug!("Sending msg {}", msg.id());
                 // Write the message to the buffer.
+                // FIXME: handle restoring the message into `pending_egress` if send/flush fails
                 match channel.start_send_unpin(msg.inner) {
                     Ok(_) => {
                         this.socket_state.stats.specific.increment_tx(size);
@@ -233,13 +228,12 @@ where
                         // set the connection to inactive, so that it will be retried
                         this.conn_manager.reset_connection();
+                        continue;
                     }
                 }
-
-                // We might be able to write more queued messages to the buffer.
-                continue;
             }
 
+            // Flush if the write buffer is full according to the configured `write_buffer_size`
             if channel.write_buffer().len() >= this.options.write_buffer_size {
                 if let Poll::Ready(Err(e)) = channel.poll_flush_unpin(cx) {
                     tracing::error!(err = ?e, "Failed to flush connection");
@@ -253,6 +247,7 @@ where
                 }
             }
 
+            // Flush if we have some data and the `linger_timer` is ready
             if let Some(ref mut linger_timer) = this.linger_timer {
                 if !channel.write_buffer().is_empty() && linger_timer.poll_tick(cx).is_ready() {
                     if let Poll::Ready(Err(e)) = channel.poll_flush_unpin(cx) {
@@ -267,25 +262,31 @@ where
                 this.check_timeouts();
             }
 
-            // Check for outgoing messages from the socket handle
-            match this.from_socket.poll_recv(cx) {
-                Poll::Ready(Some(cmd)) => {
-                    this.on_send(cmd);
-
-                    continue;
-                }
-                Poll::Ready(None) => {
-                    tracing::debug!(
-                        "socket dropped, shutting down backend and flushing connection"
-                    );
-
-                    if let Some(channel) = this.conn_manager.active_connection() {
-                        let _ = ready!(channel.poll_close_unpin(cx));
-                    }
-
-                    return Poll::Ready(());
-                }
-                Poll::Pending => {}
-            }
+            // Check for outgoing messages from the socket handle.
+            // Only poll for new requests when `pending_egress` is empty AND we're under the HWM,
+            // to maintain backpressure.
+            let under_hwm = this.pending_requests.len() < this.options.max_pending_requests;
+
+            if this.pending_egress.is_none() && under_hwm {
+                match this.from_socket.poll_recv(cx) {
+                    Poll::Ready(Some(cmd)) => {
+                        this.on_send(cmd);
+
+                        continue;
+                    }
+                    Poll::Ready(None) => {
+                        tracing::debug!(
+                            "socket dropped, shutting down backend and flushing connection"
+                        );
+
+                        if let Some(channel) = this.conn_manager.active_connection() {
+                            let _ = ready!(channel.poll_close_unpin(cx));
+                        }
+
+                        return Poll::Ready(());
+                    }
+                    Poll::Pending => {}
+                }
+            }
 
             return Poll::Pending;
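The driver half of the request-side change is the same single-slot idea: the bounded socket-to-driver channel is the real queue, and the driver holds at most one message it is currently writing. A standalone sketch of that pattern (all names here are hypothetical; `send_to_transport` stands in for `start_send` plus flushing on the framed connection):

```rust
use tokio::sync::mpsc;

struct Egress {
    rx: mpsc::Receiver<Vec<u8>>, // bounded: created with `mpsc::channel(max_queue_size)`
    pending: Option<Vec<u8>>,    // the single "pending_egress" slot
}

impl Egress {
    async fn drive(&mut self) {
        loop {
            // Pull a new message from the queue only when the slot is free.
            if self.pending.is_none() {
                match self.rx.recv().await {
                    Some(msg) => self.pending = Some(msg),
                    None => return, // all senders dropped, shut down
                }
            }
            if let Some(msg) = self.pending.take() {
                send_to_transport(msg).await;
            }
        }
    }
}

// Hypothetical transport write; a failure here would reset the connection (and
// the FIXME above is about restoring the message into the slot instead of losing it).
async fn send_to_transport(_msg: Vec<u8>) {}
```

Because the channel is bounded, producers see backpressure as soon as the driver stops draining it, which is what the socket-side `try_send` below surfaces as `HighWaterMarkReached`.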
diff --git a/msg-socket/src/req/mod.rs b/msg-socket/src/req/mod.rs
index c70b9eaa..1130fd97 100644
--- a/msg-socket/src/req/mod.rs
+++ b/msg-socket/src/req/mod.rs
@@ -23,8 +23,7 @@ pub use socket::*;
 use crate::{Profile, stats::SocketStats};
 use stats::ReqStats;
 
-/// The default buffer size for the socket.
-const DEFAULT_BUFFER_SIZE: usize = 1024;
+use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE};
 
 pub(crate) static DRIVER_ID: AtomicUsize = AtomicUsize::new(0);
 
@@ -45,6 +44,8 @@ pub enum ReqError {
     NoValidEndpoints,
     #[error("Failed to connect to the target endpoint: {0:?}")]
     Connect(Box<dyn std::error::Error + Send + Sync>),
+    #[error("High-water mark reached")]
+    HighWaterMarkReached,
 }
 
 /// A command to send a request message and wait for a response.
@@ -108,6 +109,15 @@ pub struct ReqOptions {
     pub write_buffer_size: usize,
     /// The linger duration for the write buffer (how long to wait before flushing).
     pub write_buffer_linger: Option<Duration>,
+    /// The size of the channel buffer between the socket and the driver.
+    /// This controls how many requests can be queued, on top of the current pending requests,
+    /// before the socket returns [`ReqError::HighWaterMarkReached`].
+    pub max_queue_size: usize,
+    /// High-water mark for pending requests. When this limit is reached, new requests
+    /// will not be processed and will be queued up to [`max_queue_size`](Self::max_queue_size)
+    /// elements. Once both limits are reached, new requests will return
+    /// [`ReqError::HighWaterMarkReached`].
+    pub max_pending_requests: usize,
 }
 
 impl ReqOptions {
@@ -184,6 +194,8 @@ impl ReqOptions {
     /// Sets the minimum payload size in bytes for compression to be used.
     ///
     /// If the payload is smaller than this threshold, it will not be compressed.
+    ///
+    /// Default: [`DEFAULT_BUFFER_SIZE`]
     pub fn with_min_compress_size(mut self, min_compress_size: usize) -> Self {
         self.min_compress_size = min_compress_size;
         self
@@ -192,7 +204,7 @@ impl ReqOptions {
     /// Sets the size (max capacity) of the write buffer in bytes.
     /// When the buffer is full, it will be flushed to the underlying transport.
     ///
-    /// Default: 8KiB
+    /// Default: [`DEFAULT_BUFFER_SIZE`]
     pub fn with_write_buffer_size(mut self, size: usize) -> Self {
         self.write_buffer_size = size;
         self
@@ -206,6 +218,26 @@ impl ReqOptions {
         self.write_buffer_linger = duration;
         self
     }
+
+    /// Sets the size of the channel buffer between the socket and the driver.
+    /// This controls how many requests can be queued, on top of the current pending requests,
+    /// before the socket returns [`ReqError::HighWaterMarkReached`].
+    ///
+    /// Default: [`DEFAULT_QUEUE_SIZE`]
+    pub fn with_max_queue_size(mut self, size: usize) -> Self {
+        self.max_queue_size = size;
+        self
+    }
+
+    /// Sets the high-water mark for pending requests. When this limit is reached, new requests
+    /// will not be processed and will be queued up to [`Self::with_max_queue_size`] elements.
+    /// Once both limits are reached, new requests will return [`ReqError::HighWaterMarkReached`].
+    ///
+    /// Default: [`DEFAULT_QUEUE_SIZE`]
+    pub fn with_max_pending_requests(mut self, hwm: usize) -> Self {
+        self.max_pending_requests = hwm;
+        self
+    }
 }
 
 impl Default for ReqOptions {
@@ -214,9 +246,11 @@ impl Default for ReqOptions {
             conn: ConnOptions::default(),
             timeout: Duration::from_secs(5),
             blocking_connect: false,
-            min_compress_size: 8192,
-            write_buffer_size: 8192,
+            min_compress_size: DEFAULT_BUFFER_SIZE,
+            write_buffer_size: DEFAULT_BUFFER_SIZE,
             write_buffer_linger: Some(Duration::from_micros(100)),
+            max_queue_size: DEFAULT_QUEUE_SIZE,
+            max_pending_requests: DEFAULT_QUEUE_SIZE,
         }
     }
 }
diff --git a/msg-socket/src/req/socket.rs b/msg-socket/src/req/socket.rs
index 98f54462..34ac461b 100644
--- a/msg-socket/src/req/socket.rs
+++ b/msg-socket/src/req/socket.rs
@@ -10,7 +10,7 @@ use bytes::Bytes;
 use rustc_hash::FxHashMap;
 use tokio::{
     net::{ToSocketAddrs, lookup_host},
-    sync::{mpsc, oneshot},
+    sync::{mpsc, mpsc::error::TrySendError, oneshot},
 };
 use tokio_util::codec::Framed;
 
@@ -18,7 +18,7 @@ use msg_common::span::WithSpan;
 use msg_transport::{Address, MeteredIo, Transport};
 use msg_wire::{compression::Compressor, reqrep};
 
-use super::{DEFAULT_BUFFER_SIZE, ReqError, ReqOptions};
+use super::{ReqError, ReqOptions};
 use crate::{
     ConnectionState, DRIVER_ID, ExponentialBackoff, ReqMessage, SendCommand,
     req::{
@@ -130,9 +130,11 @@ where
         self.to_driver
             .as_ref()
             .ok_or(ReqError::SocketClosed)?
-            .send(SendCommand::new(WithSpan::current(msg), response_tx))
-            .await
-            .map_err(|_| ReqError::SocketClosed)?;
+            .try_send(SendCommand::new(WithSpan::current(msg), response_tx))
+            .map_err(|err| match err {
+                TrySendError::Full(_) => ReqError::HighWaterMarkReached,
+                TrySendError::Closed(_) => ReqError::SocketClosed,
+            })?;
 
         response_rx.await.map_err(|_| ReqError::SocketClosed)?
     }
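With `try_send`, `request` now fails fast instead of parking until channel capacity frees up, so the socket can have roughly `max_queue_size + max_pending_requests` requests outstanding before callers see an error. A hedged sketch of caller-side handling (the endpoint, attempt count, and backoff are illustrative):

```rust
use std::time::Duration;

use bytes::Bytes;
use msg_socket::{ReqError, ReqOptions, ReqSocket};
use msg_transport::tcp::Tcp;

#[tokio::main]
async fn main() {
    let options = ReqOptions::default()
        .with_max_queue_size(512) // socket -> driver channel capacity
        .with_max_pending_requests(256); // in-flight window at the driver
    let mut req = ReqSocket::with_options(Tcp::default(), options);
    req.connect("127.0.0.1:4444").await.unwrap();

    // Retry a few times when both limits are hit; previously `send(..).await`
    // would have blocked here instead of returning an error.
    let mut attempts = 0;
    let response = loop {
        match req.request(Bytes::from_static(b"hello")).await {
            Err(ReqError::HighWaterMarkReached) if attempts < 3 => {
                attempts += 1;
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
            other => break other,
        }
    };
    println!("{response:?}");
}
```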
@@ -169,8 +171,7 @@ where
     /// Internal method to initialize and spawn the driver.
     fn spawn_driver(&mut self, endpoint: A, transport: T, conn_ctl: ConnCtl) {
-        // Initialize communication channels
-        let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
+        let (to_driver, from_socket) = mpsc::channel(self.options.max_queue_size);
 
         let timeout_check_interval = tokio::time::interval(self.options.timeout / 10);
@@ -207,7 +208,7 @@ where
             linger_timer,
             pending_requests,
             timeout_check_interval,
-            egress_queue: Default::default(),
+            pending_egress: None,
             compressor: self.compressor.clone(),
             id,
             span,
diff --git a/msg-socket/src/sub/mod.rs b/msg-socket/src/sub/mod.rs
index 6a65791c..9adfe35a 100644
--- a/msg-socket/src/sub/mod.rs
+++ b/msg-socket/src/sub/mod.rs
@@ -12,8 +12,6 @@ mod socket;
 pub use socket::*;
 
 mod stats;
-
-use crate::stats::SocketStats;
 use stats::SubStats;
 
 mod stream;
@@ -21,7 +19,7 @@ mod stream;
 use msg_transport::Address;
 use msg_wire::pubsub;
 
-use crate::DEFAULT_BUFFER_SIZE;
+use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_QUEUE_SIZE, stats::SocketStats};
 
 #[derive(Debug, Error)]
 pub enum SubError {
@@ -61,7 +59,7 @@ pub struct SubOptions {
     auth_token: Option<Bytes>,
     /// The maximum number of incoming messages that will be buffered before being dropped due to
     /// a slow consumer.
-    ingress_buffer_size: usize,
+    ingress_queue_size: usize,
     /// The read buffer size for each session.
     read_buffer_size: usize,
     /// The initial backoff for reconnecting to a publisher.
@@ -78,15 +76,19 @@ impl SubOptions {
         self
     }
 
-    /// Sets the ingress buffer size. This is the maximum amount of incoming messages that will be
+    /// Sets the ingress queue size. This is the maximum number of incoming messages that will be
     /// buffered. If the consumer cannot keep up with the incoming messages, messages will start
     /// being dropped.
-    pub fn with_ingress_buffer_size(mut self, ingress_buffer_size: usize) -> Self {
-        self.ingress_buffer_size = ingress_buffer_size;
+    ///
+    /// Default: [`DEFAULT_QUEUE_SIZE`]
+    pub fn with_ingress_queue_size(mut self, ingress_queue_size: usize) -> Self {
+        self.ingress_queue_size = ingress_queue_size;
         self
     }
 
     /// Sets the read buffer size. This sets the size of the read buffer for each session.
+    ///
+    /// Default: [`DEFAULT_BUFFER_SIZE`]
     pub fn with_read_buffer_size(mut self, read_buffer_size: usize) -> Self {
         self.read_buffer_size = read_buffer_size;
         self
@@ -110,7 +112,7 @@ impl Default for SubOptions {
     fn default() -> Self {
         Self {
             auth_token: None,
-            ingress_buffer_size: DEFAULT_BUFFER_SIZE,
+            ingress_queue_size: DEFAULT_QUEUE_SIZE,
             read_buffer_size: 8192,
             initial_backoff: Duration::from_millis(100),
             retry_attempts: Some(24),
diff --git a/msg-socket/src/sub/socket.rs b/msg-socket/src/sub/socket.rs
index 6dd85c5c..f3ce7acb 100644
--- a/msg-socket/src/sub/socket.rs
+++ b/msg-socket/src/sub/socket.rs
@@ -17,19 +17,9 @@ use tokio::{
 use msg_common::{IpAddrExt, JoinMap};
 use msg_transport::{Address, Transport};
 
-// ADDED: Import the specific SubStats struct for the API
-use super::stats::SubStats;
-// Import the rest from the parent module (sub/mod.rs)
-use super::{
-    // REMOVED: Old/removed stats structs
-    // Command, PubMessage, SocketState, SocketStats, SocketWideStats, SubDriver, SubError,
-    Command,
-    DEFAULT_BUFFER_SIZE,
-    PubMessage,
-    SocketState,
-    SubDriver,
-    SubError,
-    SubOptions,
+use crate::sub::{
+    Command, DEFAULT_BUFFER_SIZE, PubMessage, SocketState, SubDriver, SubError, SubOptions,
+    stats::SubStats,
 };
 
 /// A subscriber socket. This socket implements [`Stream`] and yields incoming [`PubMessage`]s.
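A usage sketch of the renamed subscriber option, mirroring the benchmark and example changes below (the endpoint is illustrative, and the imports assume the same re-exports those files use):

```rust
use msg_socket::{SubOptions, SubSocket};
use msg_transport::tcp::Tcp;

#[tokio::main]
async fn main() {
    // `with_ingress_queue_size` (formerly `with_ingress_buffer_size`) bounds how many
    // undelivered messages are queued before a slow consumer starts losing them.
    let mut sub = SubSocket::with_options(
        Tcp::default(),
        SubOptions::default()
            .with_read_buffer_size(8192)
            .with_ingress_queue_size(4096),
    );

    sub.connect("127.0.0.1:5555").await.unwrap();
}
```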
@@ -136,7 +126,7 @@ where
     /// Creates a new subscriber socket with the given transport and options.
     pub fn with_options(transport: T, options: SubOptions) -> Self {
         let (to_driver, from_socket) = mpsc::channel(DEFAULT_BUFFER_SIZE);
-        let (to_socket, from_driver) = mpsc::channel(options.ingress_buffer_size);
+        let (to_socket, from_driver) = mpsc::channel(options.ingress_queue_size);
 
         let options = Arc::new(options);
diff --git a/msg-socket/tests/it/reqrep.rs b/msg-socket/tests/it/reqrep.rs
index a1a67d67..697b7306 100644
--- a/msg-socket/tests/it/reqrep.rs
+++ b/msg-socket/tests/it/reqrep.rs
@@ -1,7 +1,7 @@
 use std::time::Duration;
 
 use bytes::Bytes;
-use msg_socket::{RepSocket, ReqSocket};
+use msg_socket::{DEFAULT_QUEUE_SIZE, RepSocket, ReqOptions, ReqSocket};
 use msg_transport::{
     tcp::Tcp,
     tcp_tls::{self, TcpTls},
@@ -236,3 +236,87 @@ async fn reqrep_late_bind_works() {
     let hello = Bytes::from_static(b"hello");
     assert_eq!(hello, response, "expected {hello:?}, got {response:?}");
 }
+
+/// Tests that the high-water mark for pending requests is enforced.
+/// When the HWM is reached, new requests should return a `HighWaterMarkReached` error.
+#[tokio::test]
+async fn reqrep_hwm_reached() {
+    let _ = tracing_subscriber::fmt::try_init();
+
+    const HWM: usize = 2;
+
+    let mut rep = RepSocket::new(Tcp::default());
+    // Set the HWM for pending requests
+    let options =
+        ReqOptions::default().with_max_pending_requests(HWM).with_timeout(Duration::from_secs(30));
+    let mut req = ReqSocket::with_options(Tcp::default(), options);
+
+    rep.bind("0.0.0.0:0").await.unwrap();
+    req.connect(rep.local_addr().unwrap()).await.unwrap();
+
+    // Give the connection time to establish
+    tokio::time::sleep(Duration::from_millis(100)).await;
+
+    // Spawn a rep handler that never responds but keeps the requests alive
+    tokio::spawn(async move {
+        let mut requests = Vec::new();
+        // Collect requests so their response channels stay open
+        loop {
+            tokio::select! {
+                Some(request) = rep.next() => {
+                    requests.push(request);
+                }
+            }
+        }
+    });
+
+    // Share `req` via Arc for concurrent access
+    let req = std::sync::Arc::new(req);
+
+    const TOTAL_CAPACITY: usize = HWM + DEFAULT_QUEUE_SIZE;
+
+    // Send requests until the channel is full (HighWaterMarkReached error)
+    // - HWM requests will be moved to pending_requests
+    // - DEFAULT_QUEUE_SIZE requests will be buffered in the channel (driver stops polling at HWM)
+    // - The next request will fail with HighWaterMarkReached
+    let mut success_receivers = Vec::new();
+    let mut sent_count = 0;
+
+    let (loop_tx, mut loop_rx) = tokio::sync::mpsc::channel(1);
+    loop {
+        let (tx, rx) = tokio::sync::oneshot::channel();
+        let req_clone = std::sync::Arc::clone(&req);
+        let loop_tx = loop_tx.clone();
+
+        let i = sent_count;
+
+        // Spawn the request task; it will block waiting for a response
+        tokio::spawn(async move {
+            let result = req_clone.request(Bytes::from(format!("request{}", i))).await;
+            if result.is_err() {
+                _ = loop_tx.send(()).await;
+            }
+
+            let _ = tx.send(result);
+        });
+
+        success_receivers.push(rx);
+        sent_count += 1;
+
+        // Give time for the request to be processed by the driver
+        tokio::time::sleep(Duration::from_millis(1)).await;
+
+        // Check whether a spawned task reported an error. If so, it's either a timeout
+        // or the HWM limit, since the rep handler never responds.
+        if loop_rx.try_recv().is_ok() {
+            break;
+        }
+    }
+
+    let expected_limit = TOTAL_CAPACITY + 1;
+    assert_eq!(
+        sent_count, expected_limit,
+        "Expected to send {} requests before hitting the HWM, but sent {}",
+        expected_limit, sent_count
+    );
+}
diff --git a/msg/benches/pubsub.rs b/msg/benches/pubsub.rs
index f778f577..2fee6fca 100644
--- a/msg/benches/pubsub.rs
+++ b/msg/benches/pubsub.rs
@@ -159,7 +159,7 @@ fn pubsub_single_thread_tcp(c: &mut Criterion) {
         Tcp::default(),
         SubOptions::default()
             .with_read_buffer_size(buffer_size)
-            .with_ingress_buffer_size(N_REQS * 2),
+            .with_ingress_queue_size(N_REQS * 2),
     );
 
     let mut bench = PairBenchmark {
@@ -201,7 +201,7 @@ fn pubsub_multi_thread_tcp(c: &mut Criterion) {
         Tcp::default(),
         SubOptions::default()
             .with_read_buffer_size(buffer_size)
-            .with_ingress_buffer_size(N_REQS * 2),
+            .with_ingress_queue_size(N_REQS * 2),
     );
 
     let mut bench = PairBenchmark {
@@ -242,7 +242,7 @@ fn pubsub_single_thread_quic(c: &mut Criterion) {
         Quic::default(),
         SubOptions::default()
             .with_read_buffer_size(buffer_size)
-            .with_ingress_buffer_size(N_REQS * 2),
+            .with_ingress_queue_size(N_REQS * 2),
     );
 
     let mut bench = PairBenchmark {
@@ -284,7 +284,7 @@ fn pubsub_multi_thread_quic(c: &mut Criterion) {
         Quic::default(),
         SubOptions::default()
             .with_read_buffer_size(buffer_size)
-            .with_ingress_buffer_size(N_REQS * 2),
+            .with_ingress_queue_size(N_REQS * 2),
     );
 
     let mut bench = PairBenchmark {
@@ -325,7 +325,7 @@ fn pubsub_single_thread_ipc(c: &mut Criterion) {
         Ipc::default(),
         SubOptions::default()
             .with_read_buffer_size(buffer_size)
-            .with_ingress_buffer_size(N_REQS * 2),
+            .with_ingress_queue_size(N_REQS * 2),
     );
 
     let mut bench = PairBenchmark {
@@ -367,7 +367,7 @@ fn pubsub_multi_thread_ipc(c: &mut Criterion) {
         Ipc::default(),
         SubOptions::default()
             .with_read_buffer_size(buffer_size)
-            .with_ingress_buffer_size(N_REQS * 2),
+            .with_ingress_queue_size(N_REQS * 2),
     );
 
     let mut bench = PairBenchmark {
diff --git a/msg/examples/pubsub.rs b/msg/examples/pubsub.rs
index 90cd97ab..a83050fb 100644
--- a/msg/examples/pubsub.rs
+++ b/msg/examples/pubsub.rs
@@ -21,13 +21,13 @@ async fn main() {
     // Configure the subscribers with options
     let mut sub1 = SubSocket::with_options(
         Tcp::default(),
-        SubOptions::default().with_ingress_buffer_size(1024),
+        SubOptions::default().with_ingress_queue_size(1024),
     );
 
     let mut sub2 = SubSocket::with_options(
         // TCP transport with blocking connect; usually the connection happens in the background.
         Tcp::default(),
-        SubOptions::default().with_ingress_buffer_size(1024),
+        SubOptions::default().with_ingress_queue_size(1024),
     );
 
     tracing::info!("Setting up the sockets...");
diff --git a/msg/examples/quic_vs_tcp.rs b/msg/examples/quic_vs_tcp.rs
index 155e6b7d..3b8e0a44 100644
--- a/msg/examples/quic_vs_tcp.rs
+++ b/msg/examples/quic_vs_tcp.rs
@@ -27,7 +27,7 @@ async fn run_tcp() {
     // Configure the subscribers with options
     let mut sub1 = SubSocket::with_options(
         Tcp::default(),
-        SubOptions::default().with_ingress_buffer_size(1024),
+        SubOptions::default().with_ingress_queue_size(1024),
     );
 
     tracing::info!("Setting up the sockets...");
@@ -60,7 +60,7 @@ async fn run_quic() {
     let mut sub1 = SubSocket::with_options(
         // QUIC transport with blocking connect; usually the connection happens in the background.
         Quic::default(),
-        SubOptions::default().with_ingress_buffer_size(1024),
+        SubOptions::default().with_ingress_queue_size(1024),
     );
 
     tracing::info!("Setting up the sockets...");