Skip to content

[serde_multipart] framed multipart encoding #876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: gh/mariusae/39/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 163 additions & 2 deletions serde_multipart/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@

#![feature(min_specialization)]
#![feature(assert_matches)]
#![feature(vec_deque_pop_if)]

use std::cell::UnsafeCell;
use std::cmp::min;
use std::collections::VecDeque;
use std::io::IoSlice;
use std::ptr::NonNull;

use bincode::Options;
Expand All @@ -47,6 +51,8 @@ use serde::Deserialize;
use serde::Serialize;

/// A multi-part message, comprising a message body and a list of parts.
/// Messages only contain references to underlying byte buffers and are
/// cheaply cloned.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
body: Part,
Expand Down Expand Up @@ -89,6 +95,108 @@ impl Message {
pub fn into_inner(self) -> (Part, Vec<Part>) {
(self.body, self.parts)
}

/// Efficiently frames a message containing the body and all of its parts
/// using a simple frame-length encoding:
///
/// ```text
/// +--------------------+-------------------+--------------------+-------------------+ ... +
/// | body_len (u64 BE)  | body bytes        | part1_len (u64 BE) | part1 bytes       |     |
/// +--------------------+-------------------+--------------------+-------------------+     +
///                                           repeat
///                                             for
///                                          each part
/// ```
///
/// Every length prefix is exactly 8 bytes (a big-endian `u64`) regardless
/// of the width of `usize` on the encoding platform, so the encoding matches
/// what [`Message::from_framed`] reads back with `get_u64`.
///
/// The message is consumed; the returned [`Buf`] is a concatenation of the
/// length prefixes and the payload buffers, in frame order.
pub fn framed(self) -> impl Buf {
    let (body, parts) = self.into_inner();
    // One length prefix plus one payload for the body and for each part.
    let mut buffers = Vec::with_capacity(2 + 2 * parts.len());

    let body = body.into_inner();
    // Widen to u64 before encoding: `usize::to_be_bytes` would only emit
    // 4 bytes on 32-bit targets, breaking the documented wire format.
    buffers.push(Bytes::from_owner((body.len() as u64).to_be_bytes()));
    buffers.push(body);

    for part in parts {
        let part = part.into_inner();
        buffers.push(Bytes::from_owner((part.len() as u64).to_be_bytes()));
        buffers.push(part);
    }

    ConcatBuf::from_buffers(buffers)
}

/// Rebuilds a [`Message`] from its framed encoding.
///
/// The first frame in `buf` becomes the message body; every following
/// frame becomes one part, in order, until `buf` is exhausted.
///
/// # Errors
///
/// Propagates the error from `split_part` when a frame is truncated.
pub fn from_framed(mut buf: Bytes) -> Result<Self, std::io::Error> {
    // The body frame is mandatory, so an empty input fails here.
    let body = Self::split_part(&mut buf)?.into();

    let mut parts = Vec::new();
    while !buf.is_empty() {
        let frame = Self::split_part(&mut buf)?;
        parts.push(frame.into());
    }

    Ok(Self { body, parts })
}

/// Splits one length-prefixed frame off the front of `buf` and returns its
/// payload, leaving `buf` positioned at the start of the next frame.
///
/// A frame is an 8-byte big-endian `u64` length followed by that many
/// payload bytes.
///
/// # Errors
///
/// Returns [`std::io::ErrorKind::UnexpectedEof`] when `buf` holds fewer than
/// 8 bytes, or fewer payload bytes than the length prefix announces
/// (including a length too large to be represented as `usize`).
fn split_part(buf: &mut Bytes) -> Result<Bytes, std::io::Error> {
    const LEN_PREFIX_SIZE: usize = std::mem::size_of::<u64>();

    if buf.len() < LEN_PREFIX_SIZE {
        return Err(std::io::ErrorKind::UnexpectedEof.into());
    }
    // Convert with a checked cast: a plain `as usize` would silently
    // truncate on targets where usize is narrower than 64 bits and could
    // let a bogus length pass the bounds check below. A length that does
    // not fit in usize can never be satisfied by an in-memory buffer, so it
    // is reported the same way as any other short read.
    let at = usize::try_from(buf.get_u64())
        .map_err(|_| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?;
    if buf.len() < at {
        return Err(std::io::ErrorKind::UnexpectedEof.into());
    }
    Ok(buf.split_to(at))
}
}

/// A [`Buf`] that presents a sequence of [`Bytes`] buffers as one logical,
/// concatenated byte stream.
struct ConcatBuf {
    // Buffers still to be consumed, in stream order; the front buffer is the
    // one currently being read.
    buffers: VecDeque<Bytes>,
}

impl ConcatBuf {
    /// Builds a concatenated buffer from `buffers`, preserving their order.
    ///
    /// Empty buffers are discarded up front, so every queued buffer holds at
    /// least one byte.
    fn from_buffers(buffers: Vec<Bytes>) -> Self {
        let buffers: VecDeque<Bytes> = buffers
            .into_iter()
            .filter(|candidate| !candidate.is_empty())
            .collect();
        Self { buffers }
    }
}

impl Buf for ConcatBuf {
fn remaining(&self) -> usize {
self.buffers.iter().map(|buf| buf.remaining()).sum()
}

fn chunk(&self) -> &[u8] {
match self.buffers.front() {
Some(buf) => buf.chunk(),
None => &[],
}
}

fn advance(&mut self, mut cnt: usize) {
while cnt > 0 {
let Some(buf) = self.buffers.front_mut() else {
panic!("advanced beyond the buffer size");
};

if cnt >= buf.remaining() {
cnt -= buf.remaining();
self.buffers.pop_front();
continue;
}

buf.advance(cnt);
cnt = 0;
}
}

// We implement our own chunks_vectored here, as the default implementation
// does not do any vectoring (returning only a single IoSlice at a time).
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
let n = min(dst.len(), self.buffers.len());
for i in 0..n {
dst[i] = IoSlice::new(self.buffers[i].chunk());
}
n
}
}

/// An unsafe cell of a [`BytesMut`]. This is used to implement an io::Writer
Expand Down Expand Up @@ -206,12 +314,19 @@ mod tests {
where
T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
{
// Test plain serialization roundtrip:
let message = serialize_bincode(&value).unwrap();
assert_eq!(message.num_parts(), expected_parts);
let deserialized_value = deserialize_bincode(message).unwrap();
let deserialized_value = deserialize_bincode(message.clone()).unwrap();
assert_eq!(value, deserialized_value);

// Test normal bincode passthrough:
// Framing roundtrip:
let mut framed = message.clone().framed();
let framed = framed.copy_to_bytes(framed.remaining());
let unframed_message = Message::from_framed(framed).unwrap();
assert_eq!(message, unframed_message);

// Bincode passthrough:
let bincode_serialized = bincode::serialize(&value).unwrap();
let bincode_deserialized = bincode::deserialize(&bincode_serialized).unwrap();
assert_eq!(value, bincode_deserialized);
Expand Down Expand Up @@ -311,4 +426,50 @@ mod tests {
let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding");
}

#[test]
fn test_concat_buf() {
    // Six input buffers, one of them empty; 18 bytes of payload in total.
    let chunks = vec![
        Bytes::from("hello"),
        Bytes::from("world"),
        Bytes::from("1"),
        Bytes::from(""),
        Bytes::from("xyz"),
        Bytes::from("xyzd"),
    ];

    // Step through the stream, crossing buffer boundaries along the way.
    let mut cursor = ConcatBuf::from_buffers(chunks.clone());
    assert_eq!(cursor.remaining(), 18);

    cursor.advance(2);
    assert_eq!(cursor.remaining(), 16);
    assert_eq!(cursor.chunk(), &b"llo"[..]);

    // Finishes "hello" and eats the first byte of "world".
    cursor.advance(4);
    assert_eq!(cursor.chunk(), &b"orld"[..]);

    // Finishes "world" and "1"; the empty buffer never shows up.
    cursor.advance(5);
    assert_eq!(cursor.chunk(), &b"xyz"[..]);

    // Draining a fresh instance in one go yields the plain concatenation.
    let mut whole = ConcatBuf::from_buffers(chunks);
    let total = whole.remaining();
    let flattened = whole.copy_to_bytes(total);
    assert_eq!(&*flattened, &b"helloworld1xyzxyzd"[..]);
}

#[test]
fn test_framing() {
    // Includes an empty part to make sure zero-length frames round-trip.
    let parts = vec![
        Part::from("world"),
        Part::from("1"),
        Part::from(""),
        Part::from("xyz"),
        Part::from("xyzd"),
    ];
    let original = Message {
        body: Part::from("hello"),
        parts,
    };

    // Encode into a single contiguous buffer, then decode it again.
    let mut stream = original.clone().framed();
    let encoded_len = stream.remaining();
    let encoded = stream.copy_to_bytes(encoded_len);
    let decoded = Message::from_framed(encoded).unwrap();

    assert_eq!(decoded, original);
}
}
5 changes: 5 additions & 0 deletions serde_multipart/src/part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ impl Part {
pub fn into_inner(self) -> Bytes {
self.0
}

/// Returns a reference to the underlying byte buffer.
pub fn to_bytes(&self) -> Bytes {
self.0.clone()
}
}

impl<T: Into<Bytes>> From<T> for Part {
Expand Down