diff --git a/hyperactor/src/channel/net.rs b/hyperactor/src/channel/net.rs index 3557af98f..42b62e4dc 100644 --- a/hyperactor/src/channel/net.rs +++ b/hyperactor/src/channel/net.rs @@ -18,12 +18,14 @@ //! ``` //! //! Thus, each socket connection is a sequence of such framed messages. + use std::any::type_name; use std::collections::VecDeque; use std::fmt; use std::fmt::Debug; use std::future::Future; use std::io; +use std::mem::replace; use std::mem::take; use std::net::ToSocketAddrs; use std::pin::Pin; @@ -32,19 +34,20 @@ use std::task::Poll; use backoff::ExponentialBackoffBuilder; use backoff::backoff::Backoff; +use bytes::Buf; use bytes::BufMut; use bytes::Bytes; use bytes::BytesMut; use dashmap::DashMap; use dashmap::mapref::entry::Entry; use enum_as_inner::EnumAsInner; -use futures::Sink; -use futures::SinkExt; -use futures::stream::SplitSink; -use futures::stream::SplitStream; use serde::de::Error; use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; +use tokio::io::ReadHalf; +use tokio::io::WriteHalf; use tokio::sync::Mutex; use tokio::sync::MutexGuard; use tokio::sync::watch; @@ -53,9 +56,6 @@ use tokio::task::JoinHandle; use tokio::task::JoinSet; use tokio::time::Duration; use tokio::time::Instant; -use tokio_util::codec::Decoder; -use tokio_util::codec::Encoder; -use tokio_util::codec::Framed; use tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio_util::net::Listener; use tokio_util::sync::CancellationToken; @@ -68,6 +68,10 @@ use crate::clock::RealClock; use crate::config; use crate::metrics; +mod framed; +use framed::FrameReader; +use framed::FrameWrite; + /// Use to prevent [futures::Stream] objects using the wrong next() method by /// accident. Bascially, we want to use [tokio_stream::StreamExt::next] since it /// is cancel safe. However, there is another trait, [futures::StreamExt::next], @@ -115,59 +119,24 @@ enum Frame { Message(u64, M), } -#[derive(thiserror::Error, Debug)] -enum FrameError { - #[error(transparent)] - Bincode(#[from] bincode::Error), - #[error("framer: {0}")] - Framer(E), - #[error("eof")] - Eof, -} - -impl Frame { - async fn send(&self, framer: &mut Framed) -> Result<(), FrameError> - where - T: AsyncWrite + std::marker::Unpin, - U: Encoder, - I: From>, - U::Error: From, - { - let data = bincode::serialize(self).expect("unexpected serialization error"); - framer.send(data.into()).await.map_err(FrameError::Framer) - } - - async fn next( - stream: &mut SplitStream>, - ) -> Result> - where - T: AsyncRead + std::marker::Unpin, - U: Decoder, - U::Item: Into>, - { - match tokio_stream::StreamExt::next(stream).await { - Some(Ok(data)) => Ok(bincode::deserialize(&data.into())?), - Some(Err(error)) => Err(FrameError::Framer(error)), - None => Err(FrameError::Eof), - } - } -} - fn serialize_ack(seq: u64) -> Bytes { let mut data = BytesMut::with_capacity(8); - data.put_slice(&seq.to_be_bytes()); + data.put_u64(seq); data.freeze() } -fn deserialize_ack(data: BytesMut) -> Result { - let slice = data.as_ref(); - let array: [u8; 8] = slice.try_into().map_err(|_| slice.len())?; - Ok(u64::from_be_bytes(array)) +fn deserialize_ack(mut data: Bytes) -> Result { + if data.len() != 8 { + return Err(data.len()); + } + Ok(data.get_u64()) } fn build_codec() -> LengthDelimitedCodec { LengthDelimitedCodec::builder() .max_frame_length(config::global::get(config::CODEC_MAX_FRAME_LENGTH)) + .length_field_length(8) + .big_endian() .new_codec() } @@ -247,25 +216,8 @@ impl NetTx { self.deque.is_empty() } - // Send the oldest message in the outbox, but do not remove it from - // the outbox. Return error if the outbox is empty. - async fn send_message + Unpin>(&self, sink: &mut T) -> Result<(), String> - where - T::Error: fmt::Display, - { - let data = self - .deque - .front() - .ok_or_else(|| { - format!( - "{}: unexpected: send_message cannot be used when outbox is empty", - self.log_id, - ) - })? - .data - .clone(); - sink.send(data).await.map_err(|e| e.to_string())?; - Ok(()) + fn front_bytes(&self) -> Option { + self.deque.front().map(|msg| msg.data.clone()) } fn front_size(&self) -> Option { @@ -476,8 +428,8 @@ impl NetTx { Disconnected(Box), /// Connected and ready to go. Connected { - sink: SplitSink, Bytes>, - stream: SplitStream>, + reader: FrameReader>, + write_state: WriteState, ()>, }, } @@ -544,14 +496,32 @@ impl NetTx { conn, ), }, + ( + State::Running(Deliveries { outbox, unacked }), + Conn::Connected { + reader, + write_state: WriteState::Idle(writer), + .. + }, + ) if !outbox.is_empty() => { + let body = outbox.front_bytes().unwrap(); + ( + State::Running(Deliveries { outbox, unacked }), + Conn::Connected { + reader, + // Dequeue the next message to be sent: + write_state: WriteState::Writing(FrameWrite::new(writer, body), ()), + }, + ) + } ( State::Running(Deliveries { mut outbox, mut unacked, }), Conn::Connected { - mut sink, - mut stream, + mut reader, + mut write_state, }, ) => { tokio::select! { @@ -567,16 +537,15 @@ impl NetTx { (State::Closing { deliveries: Deliveries{outbox, unacked}, reason: error_msg, - }, Conn::Connected { sink, stream }) + }, Conn::Connected { reader, write_state }) } - // tokio_stream::StreamExt::next is cancel safe. - ack_result = tokio_stream::StreamExt::next(&mut stream) => { + ack_result = reader.next() => { match ack_result { - Some(Ok(data)) => { - match deserialize_ack(data) { + Ok(Some(buffer)) => { + match deserialize_ack(buffer) { Ok(ack) => { unacked.prune(ack); - (State::Running(Deliveries { outbox, unacked }), Conn::Connected { sink, stream }) + (State::Running(Deliveries { outbox, unacked }), Conn::Connected { reader, write_state }) } Err(len) => { let error_msg = format!( @@ -591,13 +560,17 @@ impl NetTx { (State::Closing { deliveries: Deliveries{outbox, unacked}, reason: error_msg, - }, Conn::Connected { sink, stream }) + }, Conn::Connected { reader, write_state }) } } - }, - Some(Err(err)) => { + } + Ok(None) => { + // Graceful of stream: reconnect + (State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default()) + } + Err(err) => { tracing::error!( - "session {}.{}: failed to receiving ack: {}", + "session {}.{}: failed while receiving ack: {}", link.dest(), session_id, err @@ -605,23 +578,17 @@ impl NetTx { // Reconnect and wish the error will go away. (State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default()) } - None => { - // None means connection is closed. Reconnect. - (State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default()) - } } }, - // It does matter whether `fn send_message` is cancel safe or not. Since - // `fn send_message` does not remove the message from outbox, when it is - // canceled, the message will not be dropped. In the worst case, the same - // message would get sent multiple times. But that is okay. The seq order is - // still preserved. - send_result = outbox.send_message(&mut sink), if !outbox.is_empty() => { + + // We have to be careful to manage outgoing write states, so that we never write + // partial frames in the presence cancellation. + send_result = write_state.send() => { match send_result { Ok(()) => { let message = outbox.pop_front().expect("outbox should not be empty"); unacked.push_back(message); - (State::Running(Deliveries { outbox, unacked }), Conn::Connected { sink, stream }) + (State::Running(Deliveries { outbox, unacked }), Conn::Connected { reader, write_state }) } Err(err) => { tracing::info!( @@ -647,7 +614,7 @@ impl NetTx { outbox, unacked, }); - (running, Conn::Connected { sink, stream }) + (running, Conn::Connected { reader, write_state }) } Err(err) => { let error_msg = format!( @@ -660,14 +627,14 @@ impl NetTx { (State::Closing { deliveries: Deliveries {outbox, unacked}, reason: error_msg, - }, Conn::Connected { sink, stream }) + }, Conn::Connected { reader, write_state }) } } } None => (State::Closing { deliveries: Deliveries{outbox, unacked}, reason: "NetTx is dropped".to_string(), - }, Conn::Connected { sink, stream }), + }, Conn::Connected { reader, write_state }), } }, } @@ -715,11 +682,12 @@ impl NetTx { } else { match link.connect().await { Ok(stream) => { - let framed = Framed::new(stream, build_codec()); - let (mut sink, stream) = futures::StreamExt::split(framed); - let data = bincode::serialize(&Frame::::Init(session_id)) - .expect("unexpected serialization error"); - let initialized = sink.send(data.into()).await.is_ok(); + let frame = + bincode::serialize(&Frame::::Init(session_id)).unwrap(); + + let mut write = FrameWrite::new(stream, frame.into()); + let initialized = write.send().await.is_ok(); + let stream = write.complete(); metrics::CHANNEL_CONNECTIONS.add( 1, @@ -742,7 +710,14 @@ impl NetTx { }), if initialized { backoff.reset(); - Conn::Connected { sink, stream } + let (reader, writer) = tokio::io::split(stream); + Conn::Connected { + reader: FrameReader::new( + reader, + config::global::get(config::CODEC_MAX_FRAME_LENGTH), + ), + write_state: WriteState::Idle(writer), + } } else { Conn::reconnect(backoff) }, @@ -822,11 +797,18 @@ impl NetTx { } if let Conn::Connected { - mut sink, - stream: _, + write_state: WriteState::Writing(mut frame_writer, ()), + .. } = conn { - if let Err(err) = sink.flush().await { + if let Err(err) = frame_writer.send().await { + tracing::info!( + "session {}.{}: write error: {}", + link.dest(), + session_id, + err + ); + } else if let Err(err) = frame_writer.complete().flush().await { tracing::info!( "session {}.{}: flush error: {}", link.dest(), @@ -1000,20 +982,47 @@ pub enum ClientError { Serialize(ChannelAddr, bincode::ErrorKind), } +#[derive(EnumAsInner)] +enum WriteState { + /// No frame being written. + Idle(W), + /// Currently writing a frame, with associated T-typed value. + Writing(FrameWrite, T), + + /// Internal state to manage completions. + Broken, +} + +impl WriteState { + async fn send(&mut self) -> io::Result { + match self { + Self::Idle(_) => futures::future::pending().await, + Self::Writing(fw, value) => { + fw.send().await?; + let Ok((fw, value)) = replace(self, Self::Broken).into_writing() else { + panic!("illegal state"); + }; + *self = Self::Idle(fw.complete()); + Ok(value) + } + Self::Broken => panic!("illegal state"), + } + } +} + struct ServerConn { - sink: SplitSink, Bytes>, - stream: SplitStream>, + reader: FrameReader>, + write_state: WriteState, u64>, source: ChannelAddr, dest: ChannelAddr, } impl ServerConn { fn new(stream: S, source: ChannelAddr, dest: ChannelAddr) -> Self { - let framed = Framed::new(stream, build_codec()); - let (sink, stream) = futures::StreamExt::split(framed); + let (reader, writer) = tokio::io::split(stream); Self { - sink, - stream, + reader: FrameReader::new(reader, config::global::get(config::CODEC_MAX_FRAME_LENGTH)), + write_state: WriteState::Idle(writer), source, dest, } @@ -1022,7 +1031,10 @@ impl ServerConn { impl ServerConn { async fn handshake(&mut self) -> Result { - let Frame::Init(session_id) = Frame::::next(&mut self.stream).await? else { + let Some(frame) = self.reader.next().await? else { + anyhow::bail!("end of stream before first frame from {}", self.source); + }; + let Frame::Init(session_id) = bincode::deserialize::>(&frame)? else { anyhow::bail!("unexpected initial frame from {}", self.source); }; Ok(session_id) @@ -1043,71 +1055,81 @@ impl ServerConn { let ack_msg_interval = config::global::get(config::MESSAGE_ACK_EVERY_N_MESSAGES); let (mut final_next, final_result) = loop { + if self.write_state.is_idle() + && (next.ack + ack_msg_interval <= next.seq + || (next.ack < next.seq && last_ack_time.elapsed() > ack_time_interval)) + { + let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() + else { + panic!("illegal state"); + }; + self.write_state = WriteState::Writing( + FrameWrite::new(writer, serialize_ack(next.seq - 1)), + next.seq, + ); + } + tokio::select! { - // tokio_stream::StreamExt::next is cancel safe. - rcv_result = tokio_stream::StreamExt::next(&mut self.stream) => { - match rcv_result { - Some(Ok(data)) => { - match bincode::deserialize(&data) { - Ok(Frame::Init(_)) => { - break (next, Err(anyhow::anyhow!("unexpected init frame from {}", self.source))) - }, - // Ignore retransmits. - Ok(Frame::Message(seq, _)) if seq < next.seq => (), - // The following segment ensures exactly-once semantics. - // That means No out-of-order delivery and no duplicate delivery. - Ok(Frame::Message(seq, message)) => { - // received seq should be equal to next seq. Else error out! - if seq > next.seq { - tracing::error!("out-of-sequence message from {}", self.source); - let next_seq = next.seq; - break (next, Err(anyhow::anyhow!("out-of-sequence message from {}, expected seq {}, got {}", self.source, next_seq, seq))) - } - match tx.send(message).await { - Ok(()) => { - // In channel's contract, "delivered" means the message - // is sent to the NetRx object. Therefore, we could bump - // `next_seq` as far as the message is put on the mspc - // channel. - // - // Note that when/how the messages in NetRx are processed - // is not covered by channel's contract. For example, - // the message might never be taken out of netRx, but - // channel still considers those messages delivered. - next.seq = seq+1; - } - Err(err) => { - break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error relaying message to mspc channel for {:?}", self.source))) - } - } - }, - Err(err) => break ( - next, - Err::<(), anyhow::Error>(err.into()).context( - format!( - "error deserializing into Frame with M = {} for data from {:?}", - type_name::(), - self.source, - ) - ) - ), + bytes = self.reader.next() => { + let frame = match bytes { + Ok(bytes) => bytes.map(|buf| bincode::deserialize(&buf)).transpose(), + Err(e) => Err(e.into()), + }; + match frame { + Ok(Some(Frame::Init(_))) => { + break (next, Err(anyhow::anyhow!("unexpected init frame from {}", self.source))) + }, + // Ignore retransmits. + Ok(Some(Frame::Message(seq, _))) if seq < next.seq => (), + // The following segment ensures exactly-once semantics. + // That means No out-of-order delivery and no duplicate delivery. + Ok(Some(Frame::Message(seq, message))) => { + // received seq should be equal to next seq. Else error out! + if seq > next.seq { + tracing::error!("out-of-sequence message from {}", self.source); + let next_seq = next.seq; + break (next, Err(anyhow::anyhow!("out-of-sequence message from {}, expected seq {}, got {}", self.source, next_seq, seq))) } - } - Some(Err(err)) => { - break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error receiving peer message from {:?}", self.source))) - } + match tx.send(message).await { + Ok(()) => { + // In channel's contract, "delivered" means the message + // is sent to the NetRx object. Therefore, we could bump + // `next_seq` as far as the message is put on the mspc + // channel. + // + // Note that when/how the messages in NetRx are processed + // is not covered by channel's contract. For example, + // the message might never be taken out of netRx, but + // channel still considers those messages delivered. + next.seq = seq+1; + } + Err(err) => { + break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error relaying message to mspc channel for {:?}", self.source))) + } + } + }, + + Ok(None) => break (next, Ok(())), - None => break (next, Ok(())) + Err(err) => break ( + next, + Err::<(), anyhow::Error>(err.into()).context( + format!( + "error reading into Frame with M = {} for data from {:?}", + type_name::(), + self.source, + ) + ) + ), } } - // It does matter whether send_ack is cancel safe. If it is not, - // the same seq might get acked multiple times. But that is okay. - ack_result = Self::send_ack(&mut self.sink, next.seq), if next.ack + ack_msg_interval <= next.seq || - (next.ack < next.seq && last_ack_time.elapsed() > ack_time_interval) => { + // We have to be careful to manage the ack write state here, so that we do not + // write partial acks in the presence of cancellation. + ack_result = self.write_state.send() => { match ack_result { - Ok(()) => { + Ok(acked_seq) => { last_ack_time = RealClock.now(); - next.ack = next.seq; + next.ack = acked_seq; } Err(err) => { break (next, Err::<(), anyhow::Error>(err.into()).context(format!("error acking peer message from {:?}", self.source))) @@ -1120,11 +1142,23 @@ impl ServerConn { _ = cancel_token.cancelled() => break (next, Ok(())) } }; + // Flush any ongoing write. + if self.write_state.is_writing() { + let _ = self.write_state.send().await; + } // best effort: "flush" any remaining ack before closing this session - if final_next.ack < final_next.seq { - match Self::send_ack(&mut self.sink, final_next.seq).await { - Ok(()) => { - final_next.ack = final_next.seq; + if self.write_state.is_idle() && final_next.ack < final_next.seq { + let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() else { + panic!("illegal state"); + }; + self.write_state = WriteState::Writing( + FrameWrite::new(writer, serialize_ack(final_next.seq - 1)), + final_next.seq, + ); + + match self.write_state.send().await { + Ok(acked_seq) => { + final_next.ack = acked_seq; } Err(e) => { tracing::warn!( @@ -1144,14 +1178,6 @@ impl ServerConn { } (final_next, final_result) } - - async fn send_ack( - sink: &mut SplitSink, bytes::Bytes>, - next_seq: u64, - ) -> Result<(), std::io::Error> { - let serialized = serialize_ack(next_seq - 1); - futures::SinkExt::send(sink, serialized).await - } } /// An MVar is a primitive that combines synchronization and the exchange @@ -1994,11 +2020,15 @@ mod tests { #[cfg(target_os = "linux")] // uses abstract names use anyhow::Result; + use futures::SinkExt; + use futures::stream::SplitSink; + use futures::stream::SplitStream; use rand::Rng; use rand::SeedableRng; use rand::distributions::Alphanumeric; use timed_test::async_timed_test; use tokio::io::DuplexStream; + use tokio_util::codec::Framed; use super::*; @@ -2416,7 +2446,7 @@ mod tests { tracing::debug!("MockLink relays a msg from client. msg: {:?}", msg); } } else { - let result = deserialize_ack(data.clone()); + let result = deserialize_ack(data.clone().into()); if let Ok(seq) = result { tracing::debug!("MockLink relays an ack from server. seq: {:?}", seq); } @@ -2590,7 +2620,8 @@ mod tests { tokio_stream::StreamExt::next(framed) .await .unwrap() - .unwrap(), + .unwrap() + .into(), ) .unwrap(); @@ -2642,6 +2673,12 @@ mod tests { // Now, create a new connection with the same session. { let (handle, mut framed, mut rx, cancel_token) = serve(&manager).await; + let handle = tokio::spawn(async move { + let result = handle.await.unwrap(); + eprintln!("handle joined with: {:?}", result); + result + }); + write_stream( &mut framed, session_id, @@ -2690,7 +2727,8 @@ mod tests { tokio_stream::StreamExt::next(&mut framed) .await .unwrap() - .unwrap(), + .unwrap() + .into(), ) .unwrap(); assert_eq!(acked, i); @@ -2761,7 +2799,12 @@ mod tests { loc: u32, ) { let expected = Frame::Message(expect.0, expect.1); - let frame = Frame::::next(stream).await.unwrap(); + let data = tokio_stream::StreamExt::next(stream) + .await + .unwrap() + .unwrap(); + let frame: Frame = bincode::deserialize(data.as_ref()).unwrap(); + assert_eq!(frame, expected, "from ln={loc}"); } @@ -2772,7 +2815,12 @@ mod tests { loc: u32, ) -> u64 { let session_id = { - let frame = Frame::::next(stream).await.unwrap(); + let data = tokio_stream::StreamExt::next(stream) + .await + .unwrap() + .unwrap(); + let frame: Frame = bincode::deserialize(data.as_ref()).unwrap(); + match frame { Frame::Init(session_id) => session_id, _ => panic!("the 1st frame is not Init: {:?}. from ln={loc}", frame), diff --git a/hyperactor/src/channel/net/framed.rs b/hyperactor/src/channel/net/framed.rs new file mode 100644 index 000000000..b8f2ba6a2 --- /dev/null +++ b/hyperactor/src/channel/net/framed.rs @@ -0,0 +1,211 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! This module implements a cancellation-safe zero-copy framer for network channels. + +use std::io; +use std::mem::take; + +use bytes::Buf; +use bytes::BufMut; +use bytes::Bytes; +use bytes::BytesMut; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; + +/// A FrameReader reads frames from an underlying [`AsyncRead`]. +pub struct FrameReader { + reader: R, + max_frame_length: usize, + state: FrameReaderState, +} + +enum FrameReaderState { + /// Accumulating 8-byte length prefix. + ReadLen { buf: BytesMut }, // buf.len() <= 8 + /// Accumulating body of exactly `len` bytes. + ReadBody { len: usize, buf: BytesMut }, // buf.len() <= len +} + +impl FrameReader { + /// Create a new framer for `reader`. Frames exceeding `max_frame_length` + /// in length result in an irrecoverable reader error. + pub fn new(reader: R, max_frame_length: usize) -> Self { + Self { + reader, + max_frame_length, + state: FrameReaderState::ReadLen { + buf: BytesMut::with_capacity(8), + }, + } + } + + /// Read the next frame from the underlying reader. If the frame exceeds + /// the configured maximum length, `next` returns an `io::ErrorKind::InvalidData` + /// error. + /// + /// The method is cancellation safe in the sense that, if it is used in a branch + /// of a `tokio::select!` block, frames are never dropped. + pub async fn next(&mut self) -> io::Result> { + loop { + match &mut self.state { + FrameReaderState::ReadLen { buf } if buf.len() < 8 => { + let n = self.reader.read_buf(buf).await?; + + // https://docs.rs/tokio/latest/tokio/io/trait.AsyncReadExt.html#method.read_buf + // "This reader has reached its “end of file” and will likely no longer + // be able to produce bytes. Note that this does not mean that the reader + // will always no longer be able to produce bytes." + // + // In practice, this means EOF. + if n == 0 { + if buf.is_empty() { + // We ended on a frame boundary. End of stream: + return Ok(None); + } else { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + } + } + + FrameReaderState::ReadLen { buf } => { + let len = buf.get_u64() as usize; + if len > self.max_frame_length { + return Err(io::ErrorKind::InvalidData.into()); + } + self.state = FrameReaderState::ReadBody { + len, + buf: BytesMut::with_capacity(len), + }; + } + + FrameReaderState::ReadBody { len, buf } if buf.len() < *len => { + let n = self.reader.read_buf(buf).await?; + if n == 0 { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + } + + FrameReaderState::ReadBody { len, buf } if buf.len() == *len => { + let frame = take(buf).freeze(); + self.state = FrameReaderState::ReadLen { + buf: BytesMut::with_capacity(8), + }; + return Ok(Some(frame)); + } + _ => panic!("impossible state"), + } + } + } +} + +/// A Writer for message frames. FrameWrite requires the user to drive +/// the underlying state machines through (possibly) successive calls to +/// `send`, retaining cancellation safety. The FrameWrite owns the underlying +/// writer until the frame has been written to completion. +pub struct FrameWrite { + writer: W, + state: FrameWriteState, +} +enum FrameWriteState { + /// Writing frame length. + WriteLen { len_buf: Bytes, body: Bytes }, + /// Writing the frame body. + WriteBody { body: Bytes }, +} + +impl FrameWrite { + /// Create a new frame writer, writing `body` to `writer`. + pub fn new(writer: W, body: Bytes) -> Self { + let mut len_buf = BytesMut::with_capacity(8); + len_buf.put_u64(body.len() as u64); + let len_buf = len_buf.freeze(); + Self { + writer, + state: FrameWriteState::WriteLen { len_buf, body }, + } + } + + /// Drive the underlying state machine. The frame is written when this + /// `send` returns successfully. + /// + /// This method is cancellation safe in the sense that each invocation to `send` + /// preserves progress (the future can be safely dropped at any time). Thus, the + /// user can drive the state machine by calling `send` multiple times, dropping the + /// returned futures at any time. Upon completion, the frame is guaranteed to be + /// written, unless an error was encountered, in which case the underlying stream + /// is in an undefined state. + pub async fn send(&mut self) -> io::Result<()> { + loop { + match &mut self.state { + FrameWriteState::WriteLen { len_buf, .. } if !len_buf.is_empty() => { + self.writer.write_all_buf(len_buf).await?; + } + FrameWriteState::WriteLen { body, .. } => { + self.state = FrameWriteState::WriteBody { + body: body.clone(), // cheap, but let's get rid of it + } + } + FrameWriteState::WriteBody { body } if !body.is_empty() => { + self.writer.write_all_buf(body).await?; + } + FrameWriteState::WriteBody { .. } => { + return Ok(()); + } + } + } + } + + /// Complete the write, returning ownership of the underlying writer. + /// This should only be called after successful sends; at other times + /// the underlying stream is in an undefined state. + pub fn complete(self) -> W { + let Self { writer, .. } = self; + writer + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + use rand::thread_rng; + + use super::*; + + fn random_buffer(max_len: usize) -> Bytes { + let mut rng = thread_rng(); + let len = rng.gen_range(0..max_len); + let mut buf = vec![0u8; len]; + rng.fill(buf.as_mut_slice()); + Bytes::from(buf) + } + + #[tokio::test] + async fn test_framer_roundtrip() { + const MAX_LEN: usize = 1024; + + let (reader, writer) = tokio::io::duplex(MAX_LEN + 8); + let mut reader = FrameReader::new(reader, MAX_LEN); + + let mut writer = Some(writer); + + for _ in 0..1024 { + let body = random_buffer(MAX_LEN); + let mut frame_write = FrameWrite::new(writer.take().unwrap(), body.clone()); + frame_write.send().await.unwrap(); + writer = Some(frame_write.complete()); + + let frame = reader.next().await.unwrap().unwrap(); + assert_eq!(frame, body); + } + } + + // todo: test cancellation, frame size +}