Merged
137 changes: 65 additions & 72 deletions msg-socket/src/rep/driver.rs
@@ -1,5 +1,4 @@
use std::{
collections::VecDeque,
io,
pin::Pin,
sync::Arc,
@@ -56,7 +55,10 @@ pub(crate) struct PeerState<T: AsyncRead + AsyncWrite, A: Address> {
write_buffer_size: usize,
/// The address of the peer.
addr: A,
egress_queue: VecDeque<WithSpan<reqrep::Message>>,
/// Single pending outgoing message waiting to be sent.
pending_egress: Option<WithSpan<reqrep::Message>>,
/// High-water mark for pending responses. When reached, new responses are dropped.
pending_responses_hwm: Option<usize>,
state: Arc<SocketState>,
/// The optional message compressor.
compressor: Option<Arc<dyn Compressor>>,
@@ -166,7 +168,8 @@ where
linger_timer,
write_buffer_size: this.options.write_buffer_size,
addr: auth.addr,
egress_queue: VecDeque::with_capacity(128),
pending_egress: None,
pending_responses_hwm: this.options.pending_responses_hwm,
state: Arc::clone(&this.state),
compressor: this.compressor.clone(),
}),
@@ -262,7 +265,8 @@ where
linger_timer,
write_buffer_size: self.options.write_buffer_size,
addr,
egress_queue: VecDeque::with_capacity(128),
pending_egress: None,
pending_responses_hwm: self.options.pending_responses_hwm,
state: Arc::clone(&self.state),
compressor: self.compressor.clone(),
}),
@@ -313,24 +317,24 @@ where
}

impl<T: AsyncRead + AsyncWrite + Unpin, A: Address> PeerState<T, A> {
/// Prepares for shutting down by sending and flushing all messages in [`Self::egress_queue`].
/// Prepares for shutting down by sending and flushing all pending messages.
/// When [`Poll::Ready`] is returned, the connection with this peer can be shutdown.
///
/// TODO: there might be some [`Self::pending_requests`] yet to be processed. TBD how to handle
/// them, for now they're dropped.
Comment on lines 363 to 364
Contributor:

is this still relevant?

Contributor Author:

Yes, because we would still have requests in pending_requests.
This only flushes pending_egress, where before we flushed the whole egress_queue. In a way it's the same as an egress_queue of size 1.
This TODO just signals that we never send any response for the remaining pending_requests; we could, for example, send a connection reset instead... Or at least that's how I understand it.
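
To make the "egress_queue of size 1" point concrete, here is a minimal, self-contained sketch of the pattern (the type and method names are illustrative, not from the PR): an `Option<T>` field behaves like a bounded queue of capacity one, with `take()` as the pop operation.

```rust
/// Illustrative single-slot "queue": `None` is empty, `Some` is full.
struct SingleSlot<T> {
    pending: Option<T>,
}

impl<T> SingleSlot<T> {
    /// Accept an item only when the slot is empty; otherwise hand it back
    /// to the caller, which must retry later (this is the backpressure).
    fn try_push(&mut self, item: T) -> Result<(), T> {
        if self.pending.is_some() {
            return Err(item);
        }
        self.pending = Some(item);
        Ok(())
    }

    /// Take the pending item out for sending, leaving the slot empty.
    fn pop(&mut self) -> Option<T> {
        self.pending.take()
    }
}
```
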

fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<()> {
let messages = std::mem::take(&mut self.egress_queue);
let pending_msg = self.pending_egress.take();
let buffer_size = self.conn.write_buffer().len();
if messages.is_empty() && buffer_size == 0 {
if pending_msg.is_none() && buffer_size == 0 {
debug!("flushed everything, closing connection");
return Poll::Ready(());
}

debug!(messages = ?messages.len(), write_buffer_size = ?buffer_size, "found data to send");
debug!(has_pending = ?pending_msg.is_some(), write_buffer_size = ?buffer_size, "found data to send");

for msg in messages {
if let Some(msg) = pending_msg {
if let Err(e) = self.conn.start_send_unpin(msg.inner) {
error!(?e, "failed to send final messages to socket, closing");
error!(?e, "failed to send final message to socket, closing");
return Poll::Ready(());
}
}
@@ -351,15 +355,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin, A: Address + Unpin> Stream for PeerState
let this = self.get_mut();

loop {
let mut progress = false;
if let Some(msg) = this.egress_queue.pop_front().enter() {
// First, try to send the pending egress message if we have one.
if let Some(msg) = this.pending_egress.take().enter() {
let msg_len = msg.size();
debug!(msg_id = msg.id(), "sending response");
match this.conn.start_send_unpin(msg.inner) {
Ok(_) => {
this.state.stats.specific.increment_tx(msg_len);

// We might be able to send more queued messages
progress = true;
// Continue to potentially send more or flush
continue;
}
Err(e) => {
this.state.stats.specific.increment_failed_requests();
@@ -392,68 +396,54 @@ impl<T: AsyncRead + AsyncWrite + Unpin, A: Address + Unpin> Stream for PeerState
}
}

// Then, try to drain the egress queue.
if this.conn.poll_ready_unpin(cx).is_ready() {
if let Some(msg) = this.egress_queue.pop_front().enter() {
let msg_len = msg.size();

debug!(msg_id = msg.id(), "sending response");
match this.conn.start_send_unpin(msg.inner) {
Ok(_) => {
this.state.stats.specific.increment_tx(msg_len);

// We might be able to send more queued messages
continue;
}
Err(e) => {
this.state.stats.specific.increment_failed_requests();
error!(?e, "failed to send message to socket");
// End this stream as we can't send any more messages
return Poll::Ready(None);
}
}
}
}

// Then we check for completed requests, and push them onto the egress queue.
if let Poll::Ready(Some(result)) = this.pending_requests.poll_next_unpin(cx).enter() {
match result.inner {
Err(_) => tracing::error!("response channel closed unexpectedly"),
Ok(Response { msg_id, mut response }) => {
let mut compression_type = 0;
let len_before = response.len();
if let Some(ref compressor) = this.compressor {
match compressor.compress(&response) {
Ok(compressed) => {
response = compressed;
compression_type = compressor.compression_type() as u8;
}
Err(e) => {
error!(?e, "failed to compress message");
continue;
// Check for completed requests, and set pending_egress (only if empty).
if this.pending_egress.is_none() {
if let Poll::Ready(Some(result)) = this.pending_requests.poll_next_unpin(cx).enter()
{
match result.inner {
Err(_) => tracing::error!("response channel closed unexpectedly"),
Ok(Response { msg_id, mut response }) => {
let mut compression_type = 0;
let len_before = response.len();
if let Some(ref compressor) = this.compressor {
match compressor.compress(&response) {
Ok(compressed) => {
response = compressed;
compression_type = compressor.compression_type() as u8;
}
Err(e) => {
error!(?e, "failed to compress message");
continue;
}
}
}

debug!(
msg_id,
len_before,
len_after = response.len(),
"compressed message"
)
}
debug!(
msg_id,
len_before,
len_after = response.len(),
"compressed message"
)
}

debug!(msg_id, "received response to send");
debug!(msg_id, "received response to send");

let msg = reqrep::Message::new(msg_id, compression_type, response);
this.egress_queue.push_back(msg.with_span(result.span));
let msg = reqrep::Message::new(msg_id, compression_type, response);
this.pending_egress = Some(msg.with_span(result.span));

continue;
continue;
}
}
}
}

// Finally we accept incoming requests from the peer.
{
// Accept incoming requests from the peer.
// Only accept new requests if we're under the HWM for pending responses.
let under_hwm = this
.pending_responses_hwm
.map(|hwm| this.pending_requests.len() < hwm)
.unwrap_or(true);

if under_hwm {
let _g = this.span.clone().entered();
match this.conn.poll_next_unpin(cx) {
Poll::Ready(Some(result)) => {
@@ -504,10 +494,13 @@ impl<T: AsyncRead + AsyncWrite + Unpin, A: Address + Unpin> Stream for PeerState
}
Poll::Pending => {}
}
}

if progress {
continue;
} else {
// At HWM: emit a trace and don't accept new requests until responses drain
trace!(
hwm = ?this.pending_responses_hwm,
pending = this.pending_requests.len(),
"at high-water mark, not accepting new requests"
);
}

return Poll::Pending;
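
The gate in the accept path above reduces to a one-line predicate. A standalone restatement (the function name is ours, not from the PR):

```rust
/// `None` means no limit: the peer is always under the mark.
fn under_hwm(hwm: Option<usize>, pending_responses: usize) -> bool {
    hwm.map(|limit| pending_responses < limit).unwrap_or(true)
}

// For example, under_hwm(Some(2), 2) == false: stop reading new requests
// from this peer until some pending responses drain.
// under_hwm(None, 10_000) == true: no limit configured.
```
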
21 changes: 17 additions & 4 deletions msg-socket/src/rep/mod.rs
@@ -44,10 +44,13 @@ impl RepError {
/// The reply socket options.
pub struct RepOptions {
/// The maximum number of concurrent clients.
max_clients: Option<usize>,
min_compress_size: usize,
write_buffer_size: usize,
write_buffer_linger: Option<Duration>,
pub(crate) max_clients: Option<usize>,
pub(crate) min_compress_size: usize,
pub(crate) write_buffer_size: usize,
pub(crate) write_buffer_linger: Option<Duration>,
/// High-water mark for pending responses per peer. When this limit is reached,
/// new responses will be dropped. If `None`, there is no limit (unbounded).
pub(crate) pending_responses_hwm: Option<usize>,
}

impl Default for RepOptions {
Expand All @@ -57,6 +60,7 @@ impl Default for RepOptions {
min_compress_size: DEFAULT_MIN_COMPRESS_SIZE,
write_buffer_size: 8192,
write_buffer_linger: Some(Duration::from_micros(100)),
pending_responses_hwm: None,
}
}
}
@@ -130,6 +134,15 @@ impl RepOptions {
self.write_buffer_linger = duration;
self
}

/// Sets the high-water mark for pending responses per peer. When this limit is reached,
/// new responses will be dropped. If `None`, there is no limit (unbounded).
///
/// Default: `None`
pub fn with_pending_responses_hwm(mut self, hwm: usize) -> Self {
self.pending_responses_hwm = Some(hwm);
self
}
}

/// The request socket state, shared between the backend task and the socket.
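
A usage sketch for the new option; only `RepOptions::default()` and the builder method shown in this diff are assumed:

```rust
// Cap in-flight responses per peer at 512. Once the mark is hit, the
// driver stops reading new requests from that peer until responses drain.
let options = RepOptions::default().with_pending_responses_hwm(512);
```
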
78 changes: 46 additions & 32 deletions msg-socket/src/req/driver.rs
@@ -1,5 +1,4 @@
use std::{
collections::VecDeque,
pin::Pin,
sync::Arc,
task::{Context, Poll, ready},
@@ -14,11 +13,8 @@ use tokio::{
time::Interval,
};

use super::{ReqError, ReqOptions};
use crate::{
SendCommand,
req::{SocketState, conn_manager::ConnManager},
};
use super::{ReqError, ReqOptions, SendCommand};
use crate::req::{SocketState, conn_manager::ConnManager};

use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan};
use msg_transport::{Address, Transport};
@@ -42,8 +38,8 @@ pub(crate) struct ReqDriver<T: Transport<A>, A: Address> {
pub(crate) conn_manager: ConnManager<T, A>,
/// The timer for the write buffer linger.
pub(crate) linger_timer: Option<Interval>,
/// The outgoing message queue.
pub(crate) egress_queue: VecDeque<WithSpan<reqrep::Message>>,
/// The single pending outgoing message waiting to be sent.
pub(crate) pending_egress: Option<WithSpan<reqrep::Message>>,
/// The currently pending requests waiting for a response.
pub(crate) pending_requests: FxHashMap<u32, WithSpan<PendingRequest>>,
/// Interval for checking for request timeouts.
@@ -106,8 +102,23 @@ where
}

/// Handle an incoming command from the socket frontend.
fn on_send(&mut self, cmd: SendCommand) {
/// Returns `true` if the command was accepted, `false` if HWM was reached.
fn on_send(&mut self, cmd: SendCommand) -> bool {
let SendCommand { mut message, response } = cmd;

// Check high-water mark before accepting the request
if let Some(hwm) = self.options.pending_requests_hwm {
if self.pending_requests.len() >= hwm {
tracing::warn!(
hwm,
pending = self.pending_requests.len(),
"high-water mark reached, rejecting request"
);
let _ = response.send(Err(ReqError::HighWaterMarkReached(hwm)));
return false;
}
}

let start = Instant::now();

// We want to inherit the span from the socket frontend
@@ -135,9 +146,11 @@ where
let msg = message.inner.into_wire(self.id_counter);
let msg_id = msg.id();
self.id_counter = self.id_counter.wrapping_add(1);
self.egress_queue.push_back(msg.with_span(span.clone()));
self.pending_egress = Some(msg.with_span(span.clone()));
self.pending_requests
.insert(msg_id, PendingRequest { start, sender: response }.with_span(span));

true
}

/// Check for request timeouts and notify the sender if any requests have timed out.
@@ -215,12 +228,10 @@ where
Poll::Pending => {}
}

// NOTE: We try to drain the egress queue first (the `continue`), writing everything to
// the `Framed` internal buffer. When all messages are written, we move on to flushing
// the connection in the block below. We DO NOT rely on the `Framed` internal
// backpressure boundary, because we do not call `poll_ready`.
if let Some(msg) = this.egress_queue.pop_front().enter() {
// Generate the new message
// Try to send the pending egress message if we have one.
// We only hold a single pending message here; the channel serves as the actual queue.
// This pattern ensures we respect backpressure and don't accumulate unbounded messages.
if let Some(msg) = this.pending_egress.take().enter() {
let size = msg.size();
tracing::debug!("Sending msg {}", msg.id());
// Write the message to the buffer.
Expand All @@ -236,7 +247,7 @@ where
}
}

// We might be able to write more queued messages to the buffer.
// Continue to potentially send more or flush
continue;
Contributor:

This continue is now redundant: we won't be able to write more, since pending_egress is now None. However, in the error case above, we should still continue so that the conn_manager gets polled.

Contributor:

Also, in the error case, you should put the message back into pending_egress or it gets lost

Contributor Author:

Hmm, not super sure I understand.
If we continue after writing pending_egress, we will poll conn_manager again and also poll the given channel to drive it.

In the end, since we no longer have a pending_egress, we will get another message when polling from_socket and process that one...
Basically we try to do as much work as possible before getting "stuck" waiting on the underlying I/O, no?

> Also, in the error case, you should put the message back into pending_egress or it gets lost

Does that also apply to all the other error cases then, like failing to flush? I need to double-check how the channel works here.

Contributor:

> If we continue after writing pending_egress, we will poll conn_manager again and also poll the given channel to drive it.

But we already have the channel in this code path, so there's no need to poll it again. The conn_manager is responsible for ensuring we stay connected in case of a disconnect, so when we get here we don't have to poll it.

> In the end, since we no longer have a pending_egress, we will get another message when polling from_socket and process that one...
> Basically we try to do as much work as possible before getting "stuck" waiting on the underlying I/O, no?

Yes, but in this case what work are you doing by continuing after writing to the channel?

> Does that also apply to all the other error cases then, like failing to flush? I need to double-check how the channel works here.

That's a good catch; technically we should, yes. Maybe there's a way to take the Framed buffer and cache it to be written again when we reconnect?

Contributor Author:

Oh, I thought the conn_manager handled multiple channels, depending on what was ready. Will account for it then.

Contributor Author:

Ok, I understand it now. If we continued, we'd just get the channel again and poll for incoming messages, but not actually prepare anything else to write out. Honestly it's fine, but we can instead drive the rest, which also includes pulling a new message from the queue into pending_egress...
In that scenario we'll try flushing once we reach the write buffer size or when the linger_timer fires anyway, with potentially N messages written (as long as the write and from_socket keep being ready).

Contributor Author:

OTOH, regarding restoring the egress message on error/flush, I'm unsure. I think it needs further testing, so I'd leave it for a new issue.
We'd be changing the semantics too much right now, since the behavior predates this pending_egress addition; it was just the same thing with multiple messages at once.
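
For illustration, a sketch of the "put it back" idea from this thread, which the PR deliberately did not adopt. One hedge: the real `start_send_unpin` consumes the message and returns only an error, so this assumes a send function that hands the item back on failure.

```rust
/// Restore-on-error sketch (not merged in this PR): a failed send puts
/// the message back into the single egress slot for a later retry.
fn send_or_restore<T>(slot: &mut Option<T>, send: impl FnOnce(T) -> Result<(), T>) {
    if let Some(msg) = slot.take() {
        if let Err(msg) = send(msg) {
            // Don't lose the message; it will be retried on the next poll.
            *slot = Some(msg);
        }
    }
}
```
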

}

@@ -267,25 +278,28 @@ where
this.check_timeouts();
}

// Check for outgoing messages from the socket handle
match this.from_socket.poll_recv(cx) {
Poll::Ready(Some(cmd)) => {
this.on_send(cmd);
// Check for outgoing messages from the socket handle.
// Only poll when pending_egress is empty to maintain backpressure.
if this.pending_egress.is_none() {
match this.from_socket.poll_recv(cx) {
Poll::Ready(Some(cmd)) => {
this.on_send(cmd);

continue;
}
Poll::Ready(None) => {
tracing::debug!(
"socket dropped, shutting down backend and flushing connection"
);

if let Some(channel) = this.conn_manager.active_connection() {
let _ = ready!(channel.poll_close_unpin(cx));
continue;
}
Poll::Ready(None) => {
tracing::debug!(
"socket dropped, shutting down backend and flushing connection"
);

return Poll::Ready(());
if let Some(channel) = this.conn_manager.active_connection() {
let _ = ready!(channel.poll_close_unpin(cx));
}

return Poll::Ready(());
}
Poll::Pending => {}
}
Poll::Pending => {}
}

return Poll::Pending;
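
The request driver's core backpressure mechanism is the gated poll visible above: the channel from the socket frontend is only polled while the egress slot is empty. A self-contained restatement using tokio's mpsc (the function name is ours):

```rust
use std::task::{Context, Poll};
use tokio::sync::mpsc::Receiver;

/// Gate the frontend channel on the egress slot: only pull a new command
/// once the slot is free, so the bounded channel itself provides the
/// queueing and the backpressure for senders.
fn poll_next_command<T>(
    pending_egress: &Option<T>,
    from_socket: &mut Receiver<T>,
    cx: &mut Context<'_>,
) -> Poll<Option<T>> {
    if pending_egress.is_some() {
        // Slot occupied: don't take on more work. The driver's outer loop
        // keeps making progress on the slot, so no extra waker is needed.
        return Poll::Pending;
    }
    from_socket.poll_recv(cx)
}
```
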