Skip to content

[serde_multipart] serializable Message #852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: gh/mariusae/38/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion serde_multipart/src/de/bincode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ where

fn deserialize_unit_struct<V>(
self,
name: &'static str,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
Expand Down
84 changes: 54 additions & 30 deletions serde_multipart/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,32 @@ use std::ptr::NonNull;
use bincode::Options;
use bytes::Buf;
use bytes::BufMut;
use bytes::Bytes;
use bytes::buf::UninitSlice;

mod de;
mod part;
mod ser;
use bytes::Bytes;
use bytes::BytesMut;
use part::Part;
use serde::Deserialize;
use serde::Serialize;

/// A multi-part message, comprising a message body and a list of parts.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
body: Bytes,
body: Part,
parts: Vec<Part>,
}

impl Message {
/// Returns a new message with the given body and parts.
pub fn from_body_and_parts(body: Bytes, parts: Vec<Part>) -> Self {
pub fn from_body_and_parts(body: Part, parts: Vec<Part>) -> Self {
Self { body, parts }
}

/// The body of the message.
pub fn body(&self) -> &Bytes {
pub fn body(&self) -> &Part {
&self.body
}

Expand All @@ -66,19 +69,24 @@ impl Message {
&self.parts
}

/// Returns the total number of parts (body + number of parts) in the message.
/// Returns the total number of parts (excluding the body) in the message.
pub fn num_parts(&self) -> usize {
self.parts.len()
}

/// Returns the total size (in bytes) of the message.
pub fn len(&self) -> usize {
1 + self.parts.len()
self.body.len() + self.parts.iter().map(|part| part.len()).sum::<usize>()
}

/// Returns whether the message is empty. It is always false, since the body
/// is always defined.
pub fn is_empty(&self) -> bool {
false // there is always a body
self.body.is_empty() && self.parts.iter().all(|part| part.is_empty())
}

/// Convert this message into its constituent components.
pub fn into_inner(self) -> (Bytes, Vec<Part>) {
pub fn into_inner(self) -> (Part, Vec<Part>) {
(self.body, self.parts)
}
}
Expand Down Expand Up @@ -144,30 +152,35 @@ unsafe impl BufMut for UnsafeBufCellRef {
///
/// Serialize uses the same codec options as [`bincode::serialize`] / [`bincode::deserialize`].
/// These are currently not customizable unless an explicit specialization is also provided.
pub fn serialize<S: ?Sized + serde::Serialize>(value: &S) -> Result<Message, bincode::Error> {
pub fn serialize_bincode<S: ?Sized + serde::Serialize>(
value: &S,
) -> Result<Message, bincode::Error> {
let buffer = UnsafeBufCell::from_bytes_mut(BytesMut::new());
// SAFETY: we know here that, once the below "value.serialize()" is done, there are no more
// extant references to this buffer; we are thus safe to reclaim the buffer into the message
let buffer_writer = unsafe { buffer.borrow_unchecked() };
let serializer = bincode::Serializer::new(buffer_writer.writer(), options());
let mut serializer: part::BincodeSerializer = ser::bincode::Serializer::new(serializer);
let buffer_borrow = unsafe { buffer.borrow_unchecked() };
let mut serializer: part::BincodeSerializer =
ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options()));
value.serialize(&mut serializer)?;
Ok(Message {
body: buffer.into_inner().freeze(),
body: Part(buffer.into_inner().freeze()),
parts: serializer.into_parts(),
})
}

/// Deserialize a message serialized by `[serialize]`, stitching together the original
/// message without copying the underlying buffers.
pub fn deserialize<'a, T>(message: Message) -> Result<T, bincode::Error>
pub fn deserialize_bincode<'a, T>(message: Message) -> Result<T, bincode::Error>
where
T: serde::Deserialize<'a>,
{
let (body, parts) = message.into_inner();
let bincode_deserializer = bincode::Deserializer::with_reader(body.reader(), options());
let mut deserializer = part::BincodeDeserializer::new(bincode_deserializer, parts.into());
let mut deserializer = part::BincodeDeserializer::new(
bincode::Deserializer::with_reader(body.into_inner().reader(), options()),
parts.into(),
);
let value = T::deserialize(&mut deserializer)?;
// Check that all parts were consumed:
deserializer.end()?;
Ok(value)
}
Expand All @@ -193,9 +206,9 @@ mod tests {
where
T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
{
let message = serialize(&value).unwrap();
assert_eq!(message.len(), expected_parts);
let deserialized_value = deserialize(message).unwrap();
let message = serialize_bincode(&value).unwrap();
assert_eq!(message.num_parts(), expected_parts);
let deserialized_value = deserialize_bincode(message).unwrap();
assert_eq!(value, deserialized_value);

// Test normal bincode passthrough:
Expand All @@ -206,13 +219,13 @@ mod tests {

#[test]
fn test_specialized_serializer_basic() {
test_roundtrip(Part::from("hello"), 2);
test_roundtrip(Part::from("hello"), 1);
}

#[test]
fn test_specialized_serializer_compound() {
test_roundtrip(vec![Part::from("hello"), Part::from("world")], 3);
test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 3);
test_roundtrip(vec![Part::from("hello"), Part::from("world")], 2);
test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 2);
test_roundtrip(
{
#[derive(Serialize, Deserialize, Debug, PartialEq)]
Expand Down Expand Up @@ -242,7 +255,7 @@ mod tests {
],
}
},
8,
7,
);
test_roundtrip(
{
Expand All @@ -262,29 +275,40 @@ mod tests {
field5: 2,
}
},
3,
2,
);
}

#[test]
fn test_recursive_message() {
let message = serialize_bincode(&[Part::from("hello"), Part::from("world")]).unwrap();
let message_message = serialize_bincode(&message).unwrap();

// message.body + message.parts (x2):
assert_eq!(message_message.num_parts(), 3);
}

#[test]
fn test_malformed_messages() {
let message = Message {
body: Bytes::from_static(b"hello"),
body: Part::from("hello"),
parts: vec![Part::from("world")],
};
let err = deserialize::<String>(message).unwrap_err();
let err = deserialize_bincode::<String>(message).unwrap_err();

// Normal bincode errors work:
assert_matches!(*err, bincode::ErrorKind::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof);

let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap();
let mut message =
serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
message.parts.push(Part::from("foo"));
let err = deserialize::<Vec<Part>>(message).unwrap_err();
let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart overrun while decoding");

let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap();
let mut message =
serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
let _dropped_message = message.parts.pop().unwrap();
let err = deserialize::<Vec<Part>>(message).unwrap_err();
let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding");
}
}
19 changes: 18 additions & 1 deletion serde_multipart/src/part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/

use std::ops::Deref;

use bytes::Bytes;
use bytes::buf::Reader as BufReader;
use bytes::buf::Writer as BufWriter;
Expand All @@ -22,14 +24,29 @@ use crate::ser;
/// serialization implementation that is specialized for the multipart codecs in
/// this crate, skipping copying the bytes whenever possible.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Part(Bytes);
pub struct Part(pub(crate) Bytes);

impl Part {
/// Consumes the part, returning its underlying byte buffer.
pub fn into_inner(self) -> Bytes {
self.0
}
}

impl<T: Into<Bytes>> From<T> for Part {
fn from(bytes: T) -> Self {
Self(bytes.into())
}
}

impl Deref for Part {
type Target = Bytes;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Serialize for Part {
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
<Part as PartSerializer<S>>::serialize(self, s)
Expand Down
Loading