From 4a3f5f4d6327f8bcfaa2c64e9bf117d4ac7ad063 Mon Sep 17 00:00:00 2001 From: Marius Eriksen Date: Mon, 18 Aug 2025 02:43:30 -0700 Subject: [PATCH] [hyperactor] net: zero copy framer This change introduces a zero-copy framer, which will also be easily extendable to doing vectorized framing of multipart messages. We eschew tokio's framer (which relies on queuing for cancellation safety), in favor of a simple(r) implementation: the reader maintains a simple state machine, while the writer requires the caller to maintain an explicit write state (since this has to be driven across selects in order to be made cancellation safe). In this way, we have an easily hackable framer that does not introduce additional queues. Differential Revision: [D80365228](https://our.internmc.facebook.com/intern/diff/D80365228/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D80365228/)! [ghstack-poisoned] --- hyperactor/src/channel/net.rs | 420 +++++++++++++++------------ hyperactor/src/channel/net/framed.rs | 203 +++++++++++++ 2 files changed, 442 insertions(+), 181 deletions(-) create mode 100644 hyperactor/src/channel/net/framed.rs diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index 3557af98f..1febcb7e0 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -18,12 +18,14 @@ //! ``` //! //! Thus, each socket connection is a sequence of such framed messages. 
+ use std::any::type_name; use std::collections::VecDeque; use std::fmt; use std::fmt::Debug; use std::future::Future; use std::io; +use std::mem::replace; use std::mem::take; use std::net::ToSocketAddrs; use std::pin::Pin; @@ -32,19 +34,20 @@ use std::task::Poll; use backoff::ExponentialBackoffBuilder; use backoff::backoff::Backoff; +use bytes::Buf; use bytes::BufMut; use bytes::Bytes; use bytes::BytesMut; use dashmap::DashMap; use dashmap::mapref::entry::Entry; use enum_as_inner::EnumAsInner; -use futures::Sink; -use futures::SinkExt; -use futures::stream::SplitSink; -use futures::stream::SplitStream; use serde::de::Error; use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; +use tokio::io::ReadHalf; +use tokio::io::WriteHalf; use tokio::sync::Mutex; use tokio::sync::MutexGuard; use tokio::sync::watch; @@ -53,9 +56,6 @@ use tokio::task::JoinHandle; use tokio::task::JoinSet; use tokio::time::Duration; use tokio::time::Instant; -use tokio_util::codec::Decoder; -use tokio_util::codec::Encoder; -use tokio_util::codec::Framed; use tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio_util::net::Listener; use tokio_util::sync::CancellationToken; @@ -68,6 +68,10 @@ use crate::clock::RealClock; use crate::config; use crate::metrics; +mod framed; +use framed::FrameReader; +use framed::FrameWrite; + /// Use to prevent [futures::Stream] objects using the wrong next() method by /// accident. Bascially, we want to use [tokio_stream::StreamExt::next] since it /// is cancel safe. 
However, there is another trait, [futures::StreamExt::next], @@ -115,59 +119,24 @@ enum Frame { Message(u64, M), } -#[derive(thiserror::Error, Debug)] -enum FrameError { - #[error(transparent)] - Bincode(#[from] bincode::Error), - #[error("framer: {0}")] - Framer(E), - #[error("eof")] - Eof, -} - -impl Frame { - async fn send(&self, framer: &mut Framed) -> Result<(), FrameError> - where - T: AsyncWrite + std::marker::Unpin, - U: Encoder, - I: From>, - U::Error: From, - { - let data = bincode::serialize(self).expect("unexpected serialization error"); - framer.send(data.into()).await.map_err(FrameError::Framer) - } - - async fn next( - stream: &mut SplitStream>, - ) -> Result> - where - T: AsyncRead + std::marker::Unpin, - U: Decoder, - U::Item: Into>, - { - match tokio_stream::StreamExt::next(stream).await { - Some(Ok(data)) => Ok(bincode::deserialize(&data.into())?), - Some(Err(error)) => Err(FrameError::Framer(error)), - None => Err(FrameError::Eof), - } - } -} - fn serialize_ack(seq: u64) -> Bytes { let mut data = BytesMut::with_capacity(8); - data.put_slice(&seq.to_be_bytes()); + data.put_u64(seq); data.freeze() } -fn deserialize_ack(data: BytesMut) -> Result { - let slice = data.as_ref(); - let array: [u8; 8] = slice.try_into().map_err(|_| slice.len())?; - Ok(u64::from_be_bytes(array)) +fn deserialize_ack(mut data: Bytes) -> Result { + if data.len() != 8 { + return Err(data.len()); + } + Ok(data.get_u64()) } fn build_codec() -> LengthDelimitedCodec { LengthDelimitedCodec::builder() .max_frame_length(config::global::get(config::CODEC_MAX_FRAME_LENGTH)) + .length_field_length(8) + .big_endian() .new_codec() } @@ -247,25 +216,8 @@ impl NetTx { self.deque.is_empty() } - // Send the oldest message in the outbox, but do not remove it from - // the outbox. Return error if the outbox is empty. 
- async fn send_message + Unpin>(&self, sink: &mut T) -> Result<(), String> - where - T::Error: fmt::Display, - { - let data = self - .deque - .front() - .ok_or_else(|| { - format!( - "{}: unexpected: send_message cannot be used when outbox is empty", - self.log_id, - ) - })? - .data - .clone(); - sink.send(data).await.map_err(|e| e.to_string())?; - Ok(()) + fn front_bytes(&self) -> Option { + self.deque.front().map(|msg| msg.data.clone()) } fn front_size(&self) -> Option { @@ -476,8 +428,8 @@ impl NetTx { Disconnected(Box), /// Connected and ready to go. Connected { - sink: SplitSink, Bytes>, - stream: SplitStream>, + reader: FrameReader>, + write_state: WriteState, ()>, }, } @@ -544,14 +496,32 @@ impl NetTx { conn, ), }, + ( + State::Running(Deliveries { outbox, unacked }), + Conn::Connected { + reader, + write_state: WriteState::Idle(writer), + .. + }, + ) if !outbox.is_empty() => { + let body = outbox.front_bytes().unwrap(); + ( + State::Running(Deliveries { outbox, unacked }), + Conn::Connected { + reader, + // Start writing the front message; it stays in the outbox until the send completes: + write_state: WriteState::Writing(FrameWrite::new(writer, body), ()), + }, + ) + } ( State::Running(Deliveries { mut outbox, mut unacked, }), Conn::Connected { - mut sink, - mut stream, + mut reader, + mut write_state, }, ) => { tokio::select! { @@ -567,16 +537,15 @@ impl NetTx { (State::Closing { deliveries: Deliveries{outbox, unacked}, reason: error_msg, - }, Conn::Connected { sink, stream }) + }, Conn::Connected { reader, write_state }) } - // tokio_stream::StreamExt::next is cancel safe. 
- ack_result = tokio_stream::StreamExt::next(&mut stream) => { + ack_result = reader.next() => { match ack_result { - Some(Ok(data)) => { - match deserialize_ack(data) { + Ok(Some(buffer)) => { + match deserialize_ack(buffer) { Ok(ack) => { unacked.prune(ack); - (State::Running(Deliveries { outbox, unacked }), Conn::Connected { sink, stream }) + (State::Running(Deliveries { outbox, unacked }), Conn::Connected { reader, write_state }) } Err(len) => { let error_msg = format!( @@ -591,13 +560,17 @@ impl NetTx { (State::Closing { deliveries: Deliveries{outbox, unacked}, reason: error_msg, - }, Conn::Connected { sink, stream }) + }, Conn::Connected { reader, write_state }) } } - }, - Some(Err(err)) => { + } + Ok(None) => { + // Graceful end of stream: reconnect + (State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default()) + } + Err(err) => { tracing::error!( - "session {}.{}: failed to receiving ack: {}", + "session {}.{}: failed while receiving ack: {}", link.dest(), session_id, err @@ -605,10 +578,6 @@ impl NetTx { // Reconnect and wish the error will go away. (State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default()) } - None => { - // None means connection is closed. Reconnect. - (State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default()) - } } }, // It does matter whether `fn send_message` is cancel safe or not. Since @@ -616,12 +585,12 @@ impl NetTx { // canceled, the message will not be dropped. In the worst case, the same // message would get sent multiple times. But that is okay. The seq order is // still preserved. 
- send_result = outbox.send_message(&mut sink), if !outbox.is_empty() => { + send_result = write_state.send() => { match send_result { Ok(()) => { let message = outbox.pop_front().expect("outbox should not be empty"); unacked.push_back(message); - (State::Running(Deliveries { outbox, unacked }), Conn::Connected { sink, stream }) + (State::Running(Deliveries { outbox, unacked }), Conn::Connected { reader, write_state }) } Err(err) => { tracing::info!( @@ -647,7 +616,7 @@ impl NetTx { outbox, unacked, }); - (running, Conn::Connected { sink, stream }) + (running, Conn::Connected { reader, write_state }) } Err(err) => { let error_msg = format!( @@ -660,14 +629,14 @@ impl NetTx { (State::Closing { deliveries: Deliveries {outbox, unacked}, reason: error_msg, - }, Conn::Connected { sink, stream }) + }, Conn::Connected { reader, write_state }) } } } None => (State::Closing { deliveries: Deliveries{outbox, unacked}, reason: "NetTx is dropped".to_string(), - }, Conn::Connected { sink, stream }), + }, Conn::Connected { reader, write_state }), } }, } @@ -715,11 +684,12 @@ impl NetTx { } else { match link.connect().await { Ok(stream) => { - let framed = Framed::new(stream, build_codec()); - let (mut sink, stream) = futures::StreamExt::split(framed); - let data = bincode::serialize(&Frame::::Init(session_id)) - .expect("unexpected serialization error"); - let initialized = sink.send(data.into()).await.is_ok(); + let frame = + bincode::serialize(&Frame::::Init(session_id)).unwrap(); + + let mut write = FrameWrite::new(stream, frame.into()); + let initialized = write.send().await.is_ok(); + let stream = write.complete(); metrics::CHANNEL_CONNECTIONS.add( 1, @@ -742,7 +712,14 @@ impl NetTx { }), if initialized { backoff.reset(); - Conn::Connected { sink, stream } + let (reader, writer) = tokio::io::split(stream); + Conn::Connected { + reader: FrameReader::new( + reader, + config::global::get(config::CODEC_MAX_FRAME_LENGTH), + ), + write_state: WriteState::Idle(writer), + } } else { 
Conn::reconnect(backoff) }, @@ -822,11 +799,18 @@ impl NetTx { } if let Conn::Connected { - mut sink, - stream: _, + write_state: WriteState::Writing(mut frame_writer, ()), + .. } = conn { - if let Err(err) = sink.flush().await { + if let Err(err) = frame_writer.send().await { + tracing::info!( + "session {}.{}: write error: {}", + link.dest(), + session_id, + err + ); + } else if let Err(err) = frame_writer.complete().flush().await { tracing::info!( "session {}.{}: flush error: {}", link.dest(), @@ -1000,20 +984,47 @@ pub enum ClientError { Serialize(ChannelAddr, bincode::ErrorKind), } +#[derive(EnumAsInner)] +enum WriteState { + /// No frame being written. + Idle(W), + /// Currently writing a frame, with associated T-typed value. + Writing(FrameWrite, T), + + /// Internal state to manage completions. + Broken, +} + +impl WriteState { + async fn send(&mut self) -> io::Result { + match self { + Self::Idle(_) => futures::future::pending().await, + Self::Writing(fw, value) => { + fw.send().await?; + let Ok((fw, value)) = replace(self, Self::Broken).into_writing() else { + panic!("illegal state"); + }; + *self = Self::Idle(fw.complete()); + Ok(value) + } + Self::Broken => panic!("illegal state"), + } + } +} + struct ServerConn { - sink: SplitSink, Bytes>, - stream: SplitStream>, + reader: FrameReader>, + write_state: WriteState, u64>, source: ChannelAddr, dest: ChannelAddr, } impl ServerConn { fn new(stream: S, source: ChannelAddr, dest: ChannelAddr) -> Self { - let framed = Framed::new(stream, build_codec()); - let (sink, stream) = futures::StreamExt::split(framed); + let (reader, writer) = tokio::io::split(stream); Self { - sink, - stream, + reader: FrameReader::new(reader, config::global::get(config::CODEC_MAX_FRAME_LENGTH)), + write_state: WriteState::Idle(writer), source, dest, } @@ -1022,7 +1033,10 @@ impl ServerConn { impl ServerConn { async fn handshake(&mut self) -> Result { - let Frame::Init(session_id) = Frame::::next(&mut self.stream).await? 
else { + let Some(frame) = self.reader.next().await? else { + anyhow::bail!("end of stream before first frame from {}", self.source); + }; + let Frame::Init(session_id) = bincode::deserialize::>(&frame)? else { anyhow::bail!("unexpected initial frame from {}", self.source); }; Ok(session_id) @@ -1043,71 +1057,81 @@ impl ServerConn { let ack_msg_interval = config::global::get(config::MESSAGE_ACK_EVERY_N_MESSAGES); let (mut final_next, final_result) = loop { + if self.write_state.is_idle() + && (next.ack + ack_msg_interval <= next.seq + || (next.ack < next.seq && last_ack_time.elapsed() > ack_time_interval)) + { + let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() + else { + panic!("illegal state"); + }; + self.write_state = WriteState::Writing( + FrameWrite::new(writer, serialize_ack(next.seq - 1)), + next.seq, + ); + } + tokio::select! { - // tokio_stream::StreamExt::next is cancel safe. - rcv_result = tokio_stream::StreamExt::next(&mut self.stream) => { - match rcv_result { - Some(Ok(data)) => { - match bincode::deserialize(&data) { - Ok(Frame::Init(_)) => { - break (next, Err(anyhow::anyhow!("unexpected init frame from {}", self.source))) - }, - // Ignore retransmits. - Ok(Frame::Message(seq, _)) if seq < next.seq => (), - // The following segment ensures exactly-once semantics. - // That means No out-of-order delivery and no duplicate delivery. - Ok(Frame::Message(seq, message)) => { - // received seq should be equal to next seq. Else error out! - if seq > next.seq { - tracing::error!("out-of-sequence message from {}", self.source); - let next_seq = next.seq; - break (next, Err(anyhow::anyhow!("out-of-sequence message from {}, expected seq {}, got {}", self.source, next_seq, seq))) - } - match tx.send(message).await { - Ok(()) => { - // In channel's contract, "delivered" means the message - // is sent to the NetRx object. Therefore, we could bump - // `next_seq` as far as the message is put on the mspc - // channel. 
- // - // Note that when/how the messages in NetRx are processed - // is not covered by channel's contract. For example, - // the message might never be taken out of netRx, but - // channel still considers those messages delivered. - next.seq = seq+1; - } - Err(err) => { - break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error relaying message to mspc channel for {:?}", self.source))) - } - } - }, - Err(err) => break ( - next, - Err::<(), anyhow::Error>(err.into()).context( - format!( - "error deserializing into Frame with M = {} for data from {:?}", - type_name::(), - self.source, - ) - ) - ), + bytes = self.reader.next() => { + let frame = match bytes { + Ok(bytes) => bytes.map(|buf| bincode::deserialize(&buf)).transpose(), + Err(e) => Err(e.into()), + }; + match frame { + Ok(Some(Frame::Init(_))) => { + break (next, Err(anyhow::anyhow!("unexpected init frame from {}", self.source))) + }, + // Ignore retransmits. + Ok(Some(Frame::Message(seq, _))) if seq < next.seq => (), + // The following segment ensures exactly-once semantics. + // That means No out-of-order delivery and no duplicate delivery. + Ok(Some(Frame::Message(seq, message))) => { + // received seq should be equal to next seq. Else error out! + if seq > next.seq { + tracing::error!("out-of-sequence message from {}", self.source); + let next_seq = next.seq; + break (next, Err(anyhow::anyhow!("out-of-sequence message from {}, expected seq {}, got {}", self.source, next_seq, seq))) } - } - Some(Err(err)) => { - break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error receiving peer message from {:?}", self.source))) - } + match tx.send(message).await { + Ok(()) => { + // In channel's contract, "delivered" means the message + // is sent to the NetRx object. Therefore, we could bump + // `next_seq` as far as the message is put on the mspc + // channel. + // + // Note that when/how the messages in NetRx are processed + // is not covered by channel's contract. 
For example, + // the message might never be taken out of netRx, but + // channel still considers those messages delivered. + next.seq = seq+1; + } + Err(err) => { + break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error relaying message to mspc channel for {:?}", self.source))) + } + } + }, - None => break (next, Ok(())) + Ok(None) => break (next, Ok(())), + + Err(err) => break ( + next, + Err::<(), anyhow::Error>(err.into()).context( + format!( + "error reading into Frame with M = {} for data from {:?}", + type_name::(), + self.source, + ) + ) + ), } } // It does matter whether send_ack is cancel safe. If it is not, // the same seq might get acked multiple times. But that is okay. - ack_result = Self::send_ack(&mut self.sink, next.seq), if next.ack + ack_msg_interval <= next.seq || - (next.ack < next.seq && last_ack_time.elapsed() > ack_time_interval) => { + ack_result = self.write_state.send() => { match ack_result { - Ok(()) => { + Ok(acked_seq) => { last_ack_time = RealClock.now(); - next.ack = next.seq; + next.ack = acked_seq; } Err(err) => { break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error acking peer message from {:?}", self.source))) @@ -1120,11 +1144,23 @@ impl ServerConn { _ = cancel_token.cancelled() => break (next, Ok(())) } }; + // Flush any ongoing write. 
+ if self.write_state.is_writing() { + let _ = self.write_state.send().await; + } // best effort: "flush" any remaining ack before closing this session - if final_next.ack < final_next.seq { - match Self::send_ack(&mut self.sink, final_next.seq).await { - Ok(()) => { - final_next.ack = final_next.seq; + if self.write_state.is_idle() && final_next.ack < final_next.seq { + let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() else { + panic!("illegal state"); + }; + self.write_state = WriteState::Writing( + FrameWrite::new(writer, serialize_ack(final_next.seq - 1)), + final_next.seq, + ); + + match self.write_state.send().await { + Ok(acked_seq) => { + final_next.ack = acked_seq; } Err(e) => { tracing::warn!( @@ -1144,14 +1180,6 @@ impl ServerConn { } (final_next, final_result) } - - async fn send_ack( - sink: &mut SplitSink, bytes::Bytes>, - next_seq: u64, - ) -> Result<(), std::io::Error> { - let serialized = serialize_ack(next_seq - 1); - futures::SinkExt::send(sink, serialized).await - } } /// An MVar is a primitive that combines synchronization and the exchange @@ -1994,11 +2022,18 @@ mod tests { #[cfg(target_os = "linux")] // uses abstract names use anyhow::Result; + use futures::Sink; + use futures::SinkExt; + use futures::stream::SplitSink; + use futures::stream::SplitStream; use rand::Rng; use rand::SeedableRng; use rand::distributions::Alphanumeric; use timed_test::async_timed_test; use tokio::io::DuplexStream; + use tokio_util::codec::Decoder; + use tokio_util::codec::Encoder; + use tokio_util::codec::Framed; use super::*; @@ -2416,7 +2451,7 @@ mod tests { tracing::debug!("MockLink relays a msg from client. msg: {:?}", msg); } } else { - let result = deserialize_ack(data.clone()); + let result = deserialize_ack(data.clone().into()); if let Ok(seq) = result { tracing::debug!("MockLink relays an ack from server. 
seq: {:?}", seq); } @@ -2590,7 +2625,8 @@ mod tests { tokio_stream::StreamExt::next(framed) .await .unwrap() - .unwrap(), + .unwrap() + .into(), ) .unwrap(); @@ -2611,6 +2647,11 @@ mod tests { { let (handle, mut framed, mut rx, _cancel_token) = serve(&manager).await; + let handle = tokio::spawn(async move { + let result = handle.await.unwrap(); + result + }); + write_stream( &mut framed, session_id, @@ -2642,6 +2683,12 @@ mod tests { // Now, create a new connection with the same session. { let (handle, mut framed, mut rx, cancel_token) = serve(&manager).await; + let handle = tokio::spawn(async move { + let result = handle.await.unwrap(); + eprintln!("handle joined with: {:?}", result); + result + }); + write_stream( &mut framed, session_id, @@ -2690,7 +2737,8 @@ mod tests { tokio_stream::StreamExt::next(&mut framed) .await .unwrap() - .unwrap(), + .unwrap() + .into(), ) .unwrap(); assert_eq!(acked, i); @@ -2761,7 +2809,12 @@ mod tests { loc: u32, ) { let expected = Frame::Message(expect.0, expect.1); - let frame = Frame::::next(stream).await.unwrap(); + let data = tokio_stream::StreamExt::next(stream) + .await + .unwrap() + .unwrap(); + let frame: Frame = bincode::deserialize(data.as_ref()).unwrap(); + assert_eq!(frame, expected, "from ln={loc}"); } @@ -2772,7 +2825,12 @@ mod tests { loc: u32, ) -> u64 { let session_id = { - let frame = Frame::::next(stream).await.unwrap(); + let data = tokio_stream::StreamExt::next(stream) + .await + .unwrap() + .unwrap(); + let frame: Frame = bincode::deserialize(data.as_ref()).unwrap(); + match frame { Frame::Init(session_id) => session_id, _ => panic!("the 1st frame is not Init: {:?}. from ln={loc}", frame), diff --git a/hyperactor/src/channel/net/framed.rs b/hyperactor/src/channel/net/framed.rs new file mode 100644 index 000000000..cf13869d2 --- /dev/null +++ b/hyperactor/src/channel/net/framed.rs @@ -0,0 +1,203 @@ +//! This module implements a cancellation-safe zero-copy framer for network channels. 
+ +use std::io; +use std::mem::take; + +use bytes::Buf; +use bytes::BufMut; +use bytes::Bytes; +use bytes::BytesMut; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; + +/// A FrameReader reads frames from an underlying [`AsyncRead`]. +pub struct FrameReader { + reader: R, + max_frame_length: usize, + state: FrameReaderState, +} + +enum FrameReaderState { + /// Accumulating 8-byte length prefix. + ReadLen { buf: BytesMut }, // buf.len() <= 8 + /// Accumulating body of exactly `len` bytes. + ReadBody { len: usize, buf: BytesMut }, // buf.len() <= len +} + +impl FrameReader { + /// Create a new framer for `reader`. Frames exceeding `max_frame_length` + /// in length result in an irrecoverable reader error. + pub fn new(reader: R, max_frame_length: usize) -> Self { + Self { + reader, + max_frame_length, + state: FrameReaderState::ReadLen { + buf: BytesMut::with_capacity(8), + }, + } + } + + /// Read the next frame from the underlying reader. If the frame exceeds + /// the configured maximum length, `next` returns an `io::ErrorKind::InvalidData` + /// error. + /// + /// The method is cancellation safe in the sense that, if it is used in a branch + /// of a `tokio::select!` block, frames are never dropped. + pub async fn next(&mut self) -> io::Result> { + loop { + match &mut self.state { + FrameReaderState::ReadLen { buf } if buf.len() < 8 => { + let n = self.reader.read_buf(buf).await?; + + // https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf + // "This reader has reached its “end of file” and will likely no longer + // be able to produce bytes. Note that this does not mean that the reader + // will always no longer be able to produce bytes." + // + // In practice, this means EOF. + if n == 0 { + if buf.is_empty() { + // We ended on a frame boundary. 
End of stream: + return Ok(None); + } else { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + } + } + + FrameReaderState::ReadLen { buf } => { + let len = buf.get_u64() as usize; + if len > self.max_frame_length { + return Err(io::ErrorKind::InvalidData.into()); + } + self.state = FrameReaderState::ReadBody { + len, + buf: BytesMut::with_capacity(len), + }; + } + + FrameReaderState::ReadBody { len, buf } if buf.len() < *len => { + let n = self.reader.read_buf(buf).await?; + if n == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + } + + FrameReaderState::ReadBody { len, buf } if buf.len() == *len => { + let frame = take(buf).freeze(); + self.state = FrameReaderState::ReadLen { + buf: BytesMut::with_capacity(8), + }; + return Ok(Some(frame)); + } + _ => panic!("impossible state"), + } + } + } +} + +/// A Writer for message frames. FrameWrite requires the user to drive +/// the underlying state machines through (possibly) successive calls to +/// `send`, retaining cancellation safety. The FrameWrite owns the underlying +/// writer until the frame has been written to completion. +pub struct FrameWrite { + writer: W, + state: FrameWriteState, +} +enum FrameWriteState { + /// Writing frame length. + WriteLen { len_buf: Bytes, body: Bytes }, + /// Writing the frame body. + WriteBody { body: Bytes }, +} + +impl FrameWrite { + /// Create a new frame writer, writing `body` to `writer`. + pub fn new(writer: W, body: Bytes) -> Self { + let mut len_buf = BytesMut::with_capacity(8); + len_buf.put_u64(body.len() as u64); + let len_buf = len_buf.freeze(); + Self { + writer, + state: FrameWriteState::WriteLen { len_buf, body }, + } + } + + /// Drive the underlying state machine. The frame is written when this + /// `send` returns successfully. + /// + /// This method is cancellation safe in the sense that each invocation to `send` + /// preserves progress (the future can be safely dropped at any time). 
Thus, the + /// user can drive the state machine by calling `send` multiple times, dropping the + /// returned futures at any time. Upon completion, the frame is guaranteed to be + /// written, unless an error was encountered, in which case the underlying stream + /// is in an undefined state. + pub async fn send(&mut self) -> io::Result<()> { + loop { + match &mut self.state { + FrameWriteState::WriteLen { len_buf, .. } if !len_buf.is_empty() => { + self.writer.write_all_buf(len_buf).await?; + } + FrameWriteState::WriteLen { body, .. } => { + self.state = FrameWriteState::WriteBody { + body: body.clone(), // cheap, but let's get rid of it + } + } + FrameWriteState::WriteBody { body } if !body.is_empty() => { + self.writer.write_all_buf(body).await?; + } + FrameWriteState::WriteBody { .. } => { + return Ok(()); + } + } + } + } + + /// Complete the write, returning ownership of the underlying writer. + /// This should only be called after successful sends; at other times + /// the underlying stream is in an undefined state. + pub fn complete(self) -> W { + let Self { writer, .. } = self; + writer + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + use rand::thread_rng; + + use super::*; + + fn random_buffer(max_len: usize) -> Bytes { + let mut rng = thread_rng(); + let len = rng.gen_range(0..max_len); + let mut buf = vec![0u8; len]; + rng.fill(buf.as_mut_slice()); + Bytes::from(buf) + } + + #[tokio::test] + async fn test_framer_roundtrip() { + const MAX_LEN: usize = 1024; + + let (reader, writer) = tokio::io::duplex(MAX_LEN + 8); + let mut reader = FrameReader::new(reader, MAX_LEN); + + let mut writer = Some(writer); + + for _ in 0..1024 { + let body = random_buffer(MAX_LEN); + let mut frame_write = FrameWrite::new(writer.take().unwrap(), body.clone()); + frame_write.send().await.unwrap(); + writer = Some(frame_write.complete()); + + let frame = reader.next().await.unwrap().unwrap(); + assert_eq!(frame, body); + } + } + + // todo: test cancellation, frame size +}