diff --git a/serde_multipart/src/lib.rs b/serde_multipart/src/lib.rs index 35e4e3973..01cd6ccba 100644 --- a/serde_multipart/src/lib.rs +++ b/serde_multipart/src/lib.rs @@ -28,8 +28,12 @@ #![feature(min_specialization)] #![feature(assert_matches)] +#![feature(vec_deque_pop_if)] use std::cell::UnsafeCell; +use std::cmp::min; +use std::collections::VecDeque; +use std::io::IoSlice; use std::ptr::NonNull; use bincode::Options; @@ -47,6 +51,8 @@ use serde::Deserialize; use serde::Serialize; /// A multi-part message, comprising a message body and a list of parts. +/// Messages only contain references to underlying byte buffers and are +/// cheaply cloned. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Message { body: Part, @@ -89,6 +95,108 @@ impl Message { pub fn into_inner(self) -> (Part, Vec) { (self.body, self.parts) } + + /// Efficiently frames a message containing the body and all of its parts + /// using a simple frame-length encoding: + /// + /// ```text + /// +--------------------+-------------------+--------------------+-------------------+ ... + + /// | body_len (u64 BE) | body bytes | part1_len (u64 BE) | part1 bytes | | + /// +--------------------+-------------------+--------------------+-------------------+ + + /// repeat + /// for + /// each part + /// ``` + pub fn framed(self) -> impl Buf { + let (body, parts) = self.into_inner(); + let mut buffers = Vec::with_capacity(2 + 2 * parts.len()); + + let body = body.into_inner(); + buffers.push(Bytes::from_owner(body.len().to_be_bytes())); + buffers.push(body); + + for part in parts { + let part = part.into_inner(); + buffers.push(Bytes::from_owner(part.len().to_be_bytes())); + buffers.push(part); + } + + ConcatBuf::from_buffers(buffers) + } + + /// Reassembles a message from a framed encoding. + pub fn from_framed(mut buf: Bytes) -> Result { + let body = Self::split_part(&mut buf)?.into(); + let mut parts = Vec::new(); + while buf.len() > 0 { + parts.push(Self::split_part(&mut buf)?.into()); + } + Ok(Self { body, parts }) + } + + fn split_part(buf: &mut Bytes) -> Result { + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + let at = buf.get_u64() as usize; + if buf.len() < at { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + Ok(buf.split_to(at)) + } +} + +struct ConcatBuf { + buffers: VecDeque, +} + +impl ConcatBuf { + /// Construct a new concatenated buffer. + fn from_buffers(buffers: Vec) -> Self { + let mut buffers: VecDeque = buffers.into(); + buffers.retain(|buf| !buf.is_empty()); + Self { buffers } + } +} + +impl Buf for ConcatBuf { + fn remaining(&self) -> usize { + self.buffers.iter().map(|buf| buf.remaining()).sum() + } + + fn chunk(&self) -> &[u8] { + match self.buffers.front() { + Some(buf) => buf.chunk(), + None => &[], + } + } + + fn advance(&mut self, mut cnt: usize) { + while cnt > 0 { + let Some(buf) = self.buffers.front_mut() else { + panic!("advanced beyond the buffer size"); + }; + + if cnt >= buf.remaining() { + cnt -= buf.remaining(); + self.buffers.pop_front(); + continue; + } + + buf.advance(cnt); + cnt = 0; + } + } + + // We implement our own chunks_vectored here, as the default implementation + // does not do any vectoring (returning only a single IoSlice at a time). + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + let n = min(dst.len(), self.buffers.len()); + for i in 0..n { + dst[i] = IoSlice::new(self.buffers[i].chunk()); + } + n + } } /// An unsafe cell of a [`BytesMut`]. This is used to implement an io::Writer @@ -206,12 +314,19 @@ mod tests { where T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug, { + // Test plain serialization roundtrip: let message = serialize_bincode(&value).unwrap(); assert_eq!(message.num_parts(), expected_parts); - let deserialized_value = deserialize_bincode(message).unwrap(); + let deserialized_value = deserialize_bincode(message.clone()).unwrap(); assert_eq!(value, deserialized_value); - // Test normal bincode passthrough: + // Framing roundtrip: + let mut framed = message.clone().framed(); + let framed = framed.copy_to_bytes(framed.remaining()); + let unframed_message = Message::from_framed(framed).unwrap(); + assert_eq!(message, unframed_message); + + // Bincode passthrough: let bincode_serialized = bincode::serialize(&value).unwrap(); let bincode_deserialized = bincode::deserialize(&bincode_serialized).unwrap(); assert_eq!(value, bincode_deserialized); @@ -311,4 +426,50 @@ mod tests { let err = deserialize_bincode::>(message).unwrap_err(); assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding"); } + + #[test] + fn test_concat_buf() { + let buffers = vec![ + Bytes::from("hello"), + Bytes::from("world"), + Bytes::from("1"), + Bytes::from(""), + Bytes::from("xyz"), + Bytes::from("xyzd"), + ]; + + let mut concat = ConcatBuf::from_buffers(buffers.clone()); + + assert_eq!(concat.remaining(), 18); + concat.advance(2); + assert_eq!(concat.remaining(), 16); + assert_eq!(concat.chunk(), &b"llo"[..]); + concat.advance(4); + assert_eq!(concat.chunk(), &b"orld"[..]); + concat.advance(5); + assert_eq!(concat.chunk(), &b"xyz"[..]); + + let mut concat = ConcatBuf::from_buffers(buffers); + let bytes = concat.copy_to_bytes(concat.remaining()); + assert_eq!(&*bytes, &b"helloworld1xyzxyzd"[..]); + } + + #[test] + fn test_framing() { + let message = Message { + body: Part::from("hello"), + parts: vec![ + Part::from("world"), + Part::from("1"), + Part::from(""), + Part::from("xyz"), + Part::from("xyzd"), + ] + .into(), + }; + + let mut framed = message.clone().framed(); + let framed = framed.copy_to_bytes(framed.remaining()); + assert_eq!(Message::from_framed(framed).unwrap(), message); + } } diff --git a/serde_multipart/src/part.rs b/serde_multipart/src/part.rs index ed78c528c..2f3628b10 100644 --- a/serde_multipart/src/part.rs +++ b/serde_multipart/src/part.rs @@ -31,6 +31,11 @@ impl Part { pub fn into_inner(self) -> Bytes { self.0 } + + /// Returns a reference to the underlying byte buffer. + pub fn to_bytes(&self) -> Bytes { + self.0.clone() + } } impl> From for Part {