diff --git a/serde_multipart/Cargo.toml b/serde_multipart/Cargo.toml new file mode 100644 index 000000000..3d9c4cfab --- /dev/null +++ b/serde_multipart/Cargo.toml @@ -0,0 +1,15 @@ +# @generated by autocargo from //monarch/serde_multipart:serde_multipart + +[package] +name = "serde_multipart" +version = "0.0.0" +authors = ["Facebook "] +edition = "2021" +description = "multipart encoding for serde" +repository = "https://github.com/pytorch-labs/monarch/" +license = "BSD-3-Clause" + +[dependencies] +bincode = "1.3.3" +bytes = { version = "1.10", features = ["serde"] } +serde = { version = "1.0.219", features = ["derive", "rc"] } diff --git a/serde_multipart/src/de.rs b/serde_multipart/src/de.rs new file mode 100644 index 000000000..89d0bf9a1 --- /dev/null +++ b/serde_multipart/src/de.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +pub mod bincode; diff --git a/serde_multipart/src/de/bincode.rs b/serde_multipart/src/de/bincode.rs new file mode 100644 index 000000000..044347726 --- /dev/null +++ b/serde_multipart/src/de/bincode.rs @@ -0,0 +1,409 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::collections::VecDeque; +use std::io::Read; + +use bincode::BincodeRead; +use bincode::Error; +use bincode::ErrorKind; +use bincode::Options; +use serde::de::IntoDeserializer; + +use crate::part::Part; + +/// Multipart deserializer for bincode. This passes through to the underlying bincode +/// deserializer, but dequeues serialized parts when they are needed by [`Part::deserialize`]. +pub struct Deserializer { + de: bincode::Deserializer, + parts: VecDeque, +} + +impl Deserializer +where + O: Options, +{ + pub(crate) fn new(de: bincode::Deserializer, parts: VecDeque) -> Self { + Self { de, parts } + } + + pub(crate) fn deserialize_part(&mut self) -> Result { + self.parts.pop_front().ok_or_else(|| { + ErrorKind::Custom("multipart underrun while decoding".to_string()).into() + }) + } + + pub(crate) fn end(self) -> Result<(), Error> { + if self.parts.is_empty() { + Ok(()) + } else { + Err(ErrorKind::Custom("multipart overrun while decoding".to_string()).into()) + } + } +} + +// Passthrough to the underlying bincode deserializer; we only override through specialization. +impl<'de, 'a, R, O> Deserializer +where + R: BincodeRead<'de>, + O: Options, +{ + fn deserialize_len(&mut self) -> Result { + // A visitor that only expects visiting a u64: + struct LenVisitor; + + impl<'de> serde::de::Visitor<'de> for LenVisitor { + type Value = usize; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a 64-bit length") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as usize) + } + } + + use serde::Deserializer; + self.de.deserialize_u64(LenVisitor) + } +} + +impl<'de, 'a, R, O> serde::Deserializer<'de> for &'a mut Deserializer +where + R: BincodeRead<'de>, + O: Options, +{ + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + // Not supported + self.de.deserialize_any(visitor) + } + + fn deserialize_bool(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_bool(visitor) + } + + fn deserialize_i8(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_i8(visitor) + } + + fn deserialize_i16(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_i16(visitor) + } + + fn deserialize_i32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_i32(visitor) + } + + fn deserialize_i64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_i64(visitor) + } + + fn deserialize_u8(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_u8(visitor) + } + + fn deserialize_u16(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_u16(visitor) + } + + fn deserialize_u32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_u32(visitor) + } + + fn deserialize_u64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_u64(visitor) + } + + fn deserialize_f32(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_f32(visitor) + } + + fn deserialize_f64(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_f64(visitor) + } + + fn deserialize_char(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_char(visitor) + } + + fn deserialize_str(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_str(visitor) + } + + fn deserialize_string(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_string(visitor) + } + + fn deserialize_bytes(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_bytes(visitor) + } + + fn deserialize_byte_buf(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_byte_buf(visitor) + } + + fn deserialize_unit(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_unit(visitor) + } + + // Below are compound types, which need to recurse into this handler. + + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let value: u8 = serde::de::Deserialize::deserialize(&mut *self)?; + match value { + 0 => visitor.visit_none(), + 1 => visitor.visit_some(&mut *self), + v => Err(ErrorKind::InvalidTagEncoding(v as usize).into()), + } + } + + fn deserialize_unit_struct( + self, + name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_unit() + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let len = self.deserialize_len()?; + self.deserialize_tuple(len, visitor) + } + + fn deserialize_tuple(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + struct Access<'a, R: Read + 'a, O: Options + 'a> { + deserializer: &'a mut Deserializer, + len: usize, + } + + impl<'de, 'a, 'b: 'a, R: BincodeRead<'de> + 'b, O: Options> serde::de::SeqAccess<'de> + for Access<'a, R, O> + { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result, Error> + where + T: serde::de::DeserializeSeed<'de>, + { + if self.len > 0 { + self.len -= 1; + let value = + serde::de::DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.len) + } + } + + visitor.visit_seq(Access { + deserializer: self, + len, + }) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + len: usize, + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_tuple(len, visitor) + } + + fn deserialize_map(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_map(visitor) + } + + fn deserialize_struct( + self, + _name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + self.deserialize_tuple(fields.len(), visitor) + } + + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + impl<'de, 'a, R: 'a, O> serde::de::EnumAccess<'de> for &'a mut Deserializer + where + R: BincodeRead<'de>, + O: Options, + { + type Error = Error; + type Variant = Self; + + fn variant_seed(self, seed: V) -> Result<(V::Value, Self::Variant), Error> + where + V: serde::de::DeserializeSeed<'de>, + { + let idx: u32 = serde::de::Deserialize::deserialize(&mut *self)?; + let val: Result<_, Error> = + seed.deserialize(IntoDeserializer::into_deserializer(idx)); + Ok((val?, self)) + } + } + + visitor.visit_enum(self) + } + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_identifier(visitor) + } + + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.de.deserialize_ignored_any(visitor) + } +} + +impl<'de, 'a, R, O> serde::de::VariantAccess<'de> for &'a mut Deserializer +where + R: BincodeRead<'de>, + O: Options, +{ + type Error = Error; + + fn unit_variant(self) -> Result<(), Error> { + Ok(()) + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: serde::de::DeserializeSeed<'de>, + { + serde::de::DeserializeSeed::deserialize(seed, self) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + serde::de::Deserializer::deserialize_tuple(self, len, visitor) + } + + fn struct_variant( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + serde::de::Deserializer::deserialize_tuple(self, fields.len(), visitor) + } +} diff --git a/serde_multipart/src/lib.rs b/serde_multipart/src/lib.rs new file mode 100644 index 000000000..63b4477f2 --- /dev/null +++ b/serde_multipart/src/lib.rs @@ -0,0 +1,290 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Serde codec for multipart messages. +//! +//! Using [`serialize`] / [`deserialize`], fields typed [`Part`] are extracted +//! from the main payload and appended to a list of `parts`. Each part is backed by +//! [`bytes::Bytes`] for cheap, zero-copy sharing. +//! +//! On decode, the body and its parts are reassembled into the original value +//! without copying. +//! +//! The on-the-wire form is a [`Message`] (body + parts). Your transport sends +//! and receives [`Message`]s; the codec reconstructs the value, enabling +//! efficient network I/O without compacting data into a single buffer. +//! +//! Implementation note: this crate uses Rust's min_specialization feature to enable +//! the use of [`Part`]s with any Serde serializer or deserializer. This feature +//! is fairly restrictive, and thus the API offered by [`serialize`] / [`deserialize`] +//! is not customizable. If customization is needed, you need to add specialization +//! implementations for these codecs. See [`part::PartSerializer`] and [`part::PartDeserializer`] +//! for details. + +#![feature(min_specialization)] +#![feature(assert_matches)] + +use std::cell::UnsafeCell; +use std::ptr::NonNull; + +use bincode::Options; +use bytes::Buf; +use bytes::BufMut; +use bytes::Bytes; +use bytes::buf::UninitSlice; + +mod de; +mod part; +mod ser; +use bytes::BytesMut; +use part::Part; + +/// A multi-part message, comprising a message body and a list of parts. +pub struct Message { + body: Bytes, + parts: Vec, +} + +impl Message { + /// Returns a new message with the given body and parts. + pub fn from_body_and_parts(body: Bytes, parts: Vec) -> Self { + Self { body, parts } + } + + /// The body of the message. + pub fn body(&self) -> &Bytes { + &self.body + } + + /// The list of parts of the message. + pub fn parts(&self) -> &[Part] { + &self.parts + } + + /// Returns the total number of parts (body + number of parts) in the message. + pub fn len(&self) -> usize { + 1 + self.parts.len() + } + + /// Returns whether the message is empty. It is always false, since the body + /// is always defined. + pub fn is_empty(&self) -> bool { + false // there is always a body + } + + /// Convert this message into its constituent components. + pub fn into_inner(self) -> (Bytes, Vec) { + (self.body, self.parts) + } +} + +/// An unsafe cell of a [`BytesMut`]. This is used to implement an io::Writer +/// for the serializer without exposing lifetime parameters (which cannot be) +/// specialized. +struct UnsafeBufCell { + buf: UnsafeCell, +} + +impl UnsafeBufCell { + /// Create a new cell from a [`BytesMut`]. + fn from_bytes_mut(bytes: BytesMut) -> Self { + Self { + buf: UnsafeCell::new(bytes), + } + } + + /// Convert this cell into its underlying [`BytesMut`]. + fn into_inner(self) -> BytesMut { + self.buf.into_inner() + } + + /// Borrow the cell, without lifetime checks. The caller must guarantee that + /// the returned cell cannot be used after the cell is dropped (usually through + /// [`UnsafeBufCell::into_inner`]). + unsafe fn borrow_unchecked(&self) -> UnsafeBufCellRef { + let ptr = + // SAFETY: the user is providing the necessary invariants + unsafe { NonNull::new_unchecked(self.buf.get()) }; + UnsafeBufCellRef { ptr } + } +} + +/// A borrowed reference to an [`UnsafeBufCell`]. +struct UnsafeBufCellRef { + ptr: NonNull, +} + +/// SAFETY: we're extending the implementation of the underlying [`BytesMut`]; +/// adding an additional layer of danger by disregarding lifetimes. +unsafe impl BufMut for UnsafeBufCellRef { + fn remaining_mut(&self) -> usize { + // SAFETY: extending the implementation of the underlying [`BytesMut`] + unsafe { self.ptr.as_ref().remaining_mut() } + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + // SAFETY: extending the implementation of the underlying [`BytesMut`] + unsafe { self.ptr.as_mut().advance_mut(cnt) } + } + + fn chunk_mut(&mut self) -> &mut UninitSlice { + // SAFETY: extending the implementation of the underlying [`BytesMut`] + unsafe { self.ptr.as_mut().chunk_mut() } + } +} + +/// Serialize the provided value into a multipart message. The value is encoded using an +/// extended version of [`bincode`] that skips serializing [`Part`]s, which are instead +/// held directly by the returned message. +/// +/// Serialize uses the same codec options as [`bincode::serialize`] / [`bincode::deserialize`]. +/// These are currently not customizable unless an explicit specialization is also provided. +pub fn serialize(value: &S) -> Result { + let buffer = UnsafeBufCell::from_bytes_mut(BytesMut::new()); + // SAFETY: we know here that, once the below "value.serialize()" is done, there are no more + // extant references to this buffer; we are thus safe to reclaim the buffer into the message + let buffer_writer = unsafe { buffer.borrow_unchecked() }; + let serializer = bincode::Serializer::new(buffer_writer.writer(), options()); + let mut serializer: part::BincodeSerializer = ser::bincode::Serializer::new(serializer); + value.serialize(&mut serializer)?; + Ok(Message { + body: buffer.into_inner().freeze(), + parts: serializer.into_parts(), + }) +} + +/// Deserialize a message serialized by `[serialize]`, stitching together the original +/// message without copying the underlying buffers. +pub fn deserialize<'a, T>(message: Message) -> Result +where + T: serde::Deserialize<'a>, +{ + let (body, parts) = message.into_inner(); + let bincode_deserializer = bincode::Deserializer::with_reader(body.reader(), options()); + let mut deserializer = part::BincodeDeserializer::new(bincode_deserializer, parts.into()); + let value = T::deserialize(&mut deserializer)?; + deserializer.end()?; + Ok(value) +} + +/// Construct the set of options used by the specialized serializer and deserializer. +fn options() -> part::BincodeOptionsType { + bincode::DefaultOptions::new() + .with_fixint_encoding() + .allow_trailing_bytes() +} + +#[cfg(test)] +mod tests { + use std::assert_matches::assert_matches; + + use serde::Deserialize; + use serde::Serialize; + use serde::de::DeserializeOwned; + + use super::*; + + fn test_roundtrip(value: T, expected_parts: usize) + where + T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug, + { + let message = serialize(&value).unwrap(); + assert_eq!(message.len(), expected_parts); + let deserialized_value = deserialize(message).unwrap(); + assert_eq!(value, deserialized_value); + + // Test normal bincode passthrough: + let bincode_serialized = bincode::serialize(&value).unwrap(); + let bincode_deserialized = bincode::deserialize(&bincode_serialized).unwrap(); + assert_eq!(value, bincode_deserialized); + } + + #[test] + fn test_specialized_serializer_basic() { + test_roundtrip(Part::from("hello"), 2); + } + + #[test] + fn test_specialized_serializer_compound() { + test_roundtrip(vec![Part::from("hello"), Part::from("world")], 3); + test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 3); + test_roundtrip( + { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct U { + parts: Vec, + } + + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct T { + field2: String, + field3: Part, + field4: Part, + field5: Vec, + } + + T { + field2: "hello".to_string(), + field3: Part::from("hello"), + field4: Part::from("world"), + field5: vec![ + U { + parts: vec![Part::from("hello"), Part::from("world")], + }, + U { + parts: vec![Part::from("five"), Part::from("six"), Part::from("seven")], + }, + ], + } + }, + 8, + ); + test_roundtrip( + { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + struct T { + field1: u64, + field2: String, + field3: Part, + field4: Part, + field5: u64, + } + T { + field1: 1, + field2: "hello".to_string(), + field3: Part::from("hello"), + field4: Part::from("world"), + field5: 2, + } + }, + 3, + ); + } + + #[test] + fn test_malformed_messages() { + let message = Message { + body: Bytes::from_static(b"hello"), + parts: vec![Part::from("world")], + }; + let err = deserialize::(message).unwrap_err(); + + // Normal bincode errors work: + assert_matches!(*err, bincode::ErrorKind::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof); + + let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap(); + message.parts.push(Part::from("foo")); + let err = deserialize::>(message).unwrap_err(); + assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart overrun while decoding"); + + let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap(); + let _dropped_message = message.parts.pop().unwrap(); + let err = deserialize::>(message).unwrap_err(); + assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding"); + } +} diff --git a/serde_multipart/src/part.rs b/serde_multipart/src/part.rs new file mode 100644 index 000000000..2d231e6cb --- /dev/null +++ b/serde_multipart/src/part.rs @@ -0,0 +1,104 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use bytes::Bytes; +use bytes::buf::Reader as BufReader; +use bytes::buf::Writer as BufWriter; +use serde::Deserialize; +use serde::Serialize; + +use crate::UnsafeBufCellRef; +use crate::de; +use crate::ser; + +/// Part represents a single part of a multipart message. Its type is simple: +/// it is just a newtype of the byte buffer [`Bytes`], which permits zero copy +/// shared ownership of the underlying buffers. Part itself provides a customized +/// serialization implementation that is specialized for the multipart codecs in +/// this crate, skipping copying the bytes whenever possible. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Part(Bytes); + +impl> From for Part { + fn from(bytes: T) -> Self { + Self(bytes.into()) + } +} + +impl Serialize for Part { + fn serialize(&self, s: S) -> Result { + >::serialize(self, s) + } +} + +impl<'de> Deserialize<'de> for Part { + fn deserialize>(d: D) -> Result { + >::deserialize(d) + } +} + +/// PartSerializer is the trait that selects serialization strategy based on the +/// the serializer's type. +pub trait PartSerializer { + fn serialize(this: &Part, s: S) -> Result; +} + +/// By default, we use the underlying byte serializer, which copies the underlying bytes +/// into the serialization buffer. +impl PartSerializer for Part { + default fn serialize(this: &Part, s: S) -> Result { + // Normal serializer: contiguous byte chunk, but requires copy. + this.0.serialize(s) + } +} + +/// The options type used by the underlying bincode codec. We capture this here to make sure +/// we consistently use the type, which is required to correctly specialize the multipart codec. +pub(crate) type BincodeOptionsType = bincode::config::WithOtherTrailing< + bincode::config::WithOtherIntEncoding, + bincode::config::AllowTrailing, +>; + +/// The serializer type used by the underlying bincode codec. We capture this here to make sure +/// we consistently use the type, which is required to correctly specialize the multipart codec. +pub(crate) type BincodeSerializer = + ser::bincode::Serializer, BincodeOptionsType>; + +/// Specialized implementaiton for our multipart serializer. +impl<'a> PartSerializer<&'a mut BincodeSerializer> for Part { + fn serialize(this: &Part, s: &'a mut BincodeSerializer) -> Result<(), bincode::Error> { + s.serialize_part(this); + Ok(()) + } +} + +/// PartDeserializer is the trait that selects serialization strategy based on the +/// the deserializer's type. +trait PartDeserializer<'de, S: serde::Deserializer<'de>>: Sized { + fn deserialize(this: S) -> Result; +} + +/// By default, we use the underlying byte deserializer, which copies the serialized bytes +/// into the value directly. +impl<'de, D: serde::Deserializer<'de>> PartDeserializer<'de, D> for Part { + default fn deserialize(deserializer: D) -> Result { + Ok(Part(Bytes::deserialize(deserializer)?)) + } +} + +/// The deserializer type used by the underlying bincode codec. We capture this here to make sure +/// we consistently use the type, which is required to correctly specialize the multipart codec. +pub(crate) type BincodeDeserializer = + de::bincode::Deserializer>, BincodeOptionsType>; + +/// Specialized implementation for our multipart deserializer. +impl<'de, 'a> PartDeserializer<'de, &'a mut BincodeDeserializer> for Part { + fn deserialize(deserializer: &'a mut BincodeDeserializer) -> Result { + deserializer.deserialize_part() + } +} diff --git a/serde_multipart/src/ser.rs b/serde_multipart/src/ser.rs new file mode 100644 index 000000000..89d0bf9a1 --- /dev/null +++ b/serde_multipart/src/ser.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +pub mod bincode; diff --git a/serde_multipart/src/ser/bincode.rs b/serde_multipart/src/ser/bincode.rs new file mode 100644 index 000000000..828f02a82 --- /dev/null +++ b/serde_multipart/src/ser/bincode.rs @@ -0,0 +1,350 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::io::Write; + +use ::bincode::Error; +use ::bincode::Options; +use serde::Serialize; +use serde::ser; + +use crate::Part; + +/// Multipart serializer for bincode. This passes through serialization to bincode, +/// but also records the parts encoded by [`Part::serialize`]. +pub struct Serializer { + ser: ::bincode::Serializer, + parts: Vec, +} + +impl Serializer { + pub(crate) fn new(ser: ::bincode::Serializer) -> Self { + Self { + ser, + parts: Vec::new(), + } + } + + /// Serialize a part by appending it to the parts list. + pub(crate) fn serialize_part(&mut self, part: &Part) { + self.parts.push(part.clone()); + } + + pub(crate) fn into_parts(self) -> Vec { + self.parts + } +} + +pub struct Compound<'a, W, O: Options> { + ser: &'a mut Serializer, +} + +impl<'a, W: Write, O: Options> ser::Serializer for &'a mut Serializer { + type Ok = (); + type Error = Error; + + type SerializeSeq = Compound<'a, W, O>; + type SerializeTuple = Compound<'a, W, O>; + type SerializeTupleStruct = Compound<'a, W, O>; + type SerializeTupleVariant = Compound<'a, W, O>; + type SerializeMap = Compound<'a, W, O>; + type SerializeStruct = Compound<'a, W, O>; + type SerializeStructVariant = Compound<'a, W, O>; + + fn serialize_bool(self, v: bool) -> Result { + self.ser.serialize_bool(v) + } + + fn serialize_i8(self, v: i8) -> Result { + self.ser.serialize_i8(v) + } + + fn serialize_i16(self, v: i16) -> Result { + self.ser.serialize_i16(v) + } + + fn serialize_i32(self, v: i32) -> Result { + self.ser.serialize_i32(v) + } + + fn serialize_i64(self, v: i64) -> Result { + self.ser.serialize_i64(v) + } + + fn serialize_u8(self, v: u8) -> Result { + self.ser.serialize_u8(v) + } + + fn serialize_u16(self, v: u16) -> Result { + self.ser.serialize_u16(v) + } + + fn serialize_u32(self, v: u32) -> Result { + self.ser.serialize_u32(v) + } + + fn serialize_u64(self, v: u64) -> Result { + self.ser.serialize_u64(v) + } + + fn serialize_f32(self, v: f32) -> Result { + self.ser.serialize_f32(v) + } + + fn serialize_f64(self, v: f64) -> Result { + self.ser.serialize_f64(v) + } + + fn serialize_char(self, v: char) -> Result { + self.ser.serialize_char(v) + } + + fn serialize_str(self, v: &str) -> Result { + self.ser.serialize_str(v) + } + + fn serialize_bytes(self, v: &[u8]) -> Result { + self.ser.serialize_bytes(v) + } + + fn serialize_none(self) -> Result { + self.ser.serialize_none() + } + + // The following are compounds which require special care to recurse into + // our implementation: + + fn serialize_some(self, v: &T) -> Result + where + T: ?Sized + Serialize, + { + self.ser.serialize_u8(1)?; + v.serialize(self) + } + + fn serialize_unit(self) -> Result { + Ok(()) + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + Ok(()) + } + + fn serialize_unit_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + ) -> Result { + self.serialize_u32(variant_index)?; + Ok(()) + } + + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + value.serialize(self) + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + variant_index: u32, + _variant: &'static str, + value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + self.ser.serialize_u32(variant_index)?; + value.serialize(self) + } + + fn serialize_seq(self, len: Option) -> Result { + let _ = self.ser.serialize_seq(len)?; + Ok(Compound { ser: self }) + } + + fn serialize_tuple(self, len: usize) -> Result { + let _ = self.ser.serialize_tuple(len)?; + Ok(Compound { ser: self }) + } + + fn serialize_tuple_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + let _ = self.ser.serialize_tuple_struct(name, len)?; + Ok(Compound { ser: self }) + } + + fn serialize_tuple_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + let _ = self + .ser + .serialize_tuple_variant(name, variant_index, variant, len)?; + Ok(Compound { ser: self }) + } + + fn serialize_map(self, len: Option) -> Result { + let _ = self.ser.serialize_map(len)?; + Ok(Compound { ser: self }) + } + + fn serialize_struct( + self, + name: &'static str, + len: usize, + ) -> Result { + let _ = self.ser.serialize_struct(name, len)?; + Ok(Compound { ser: self }) + } + + fn serialize_struct_variant( + self, + name: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result { + let _ = self + .ser + .serialize_struct_variant(name, variant_index, variant, len)?; + Ok(Compound { ser: self }) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeSeq for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeTuple for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_element(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeTupleStruct for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeTupleVariant for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeMap for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_key(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn serialize_value(&mut self, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeStruct for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +} + +impl<'a, W: Write, O: Options> ser::SerializeStructVariant for Compound<'a, W, O> { + type Ok = (); + type Error = Error; + + fn serialize_field(&mut self, _key: &'static str, value: &T) -> Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + value.serialize(&mut *self.ser) + } + + fn end(self) -> Result { + Ok(()) + } +}