diff --git a/serde_multipart/src/de/bincode.rs b/serde_multipart/src/de/bincode.rs index 044347726..9c04c6c57 100644 --- a/serde_multipart/src/de/bincode.rs +++ b/serde_multipart/src/de/bincode.rs @@ -227,7 +227,7 @@ where fn deserialize_unit_struct( self, - name: &'static str, + _name: &'static str, visitor: V, ) -> Result where diff --git a/serde_multipart/src/lib.rs b/serde_multipart/src/lib.rs index 63b4477f2..35e4e3973 100644 --- a/serde_multipart/src/lib.rs +++ b/serde_multipart/src/lib.rs @@ -35,29 +35,32 @@ use std::ptr::NonNull; use bincode::Options; use bytes::Buf; use bytes::BufMut; -use bytes::Bytes; use bytes::buf::UninitSlice; mod de; mod part; mod ser; +use bytes::Bytes; use bytes::BytesMut; use part::Part; +use serde::Deserialize; +use serde::Serialize; /// A multi-part message, comprising a message body and a list of parts. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Message { - body: Bytes, + body: Part, parts: Vec, } impl Message { /// Returns a new message with the given body and parts. - pub fn from_body_and_parts(body: Bytes, parts: Vec) -> Self { + pub fn from_body_and_parts(body: Part, parts: Vec) -> Self { Self { body, parts } } /// The body of the message. - pub fn body(&self) -> &Bytes { + pub fn body(&self) -> &Part { &self.body } @@ -66,19 +69,24 @@ impl Message { &self.parts } - /// Returns the total number of parts (body + number of parts) in the message. + /// Returns the total number of parts (excluding the body) in the message. + pub fn num_parts(&self) -> usize { + self.parts.len() + } + + /// Returns the total size (in bytes) of the message. pub fn len(&self) -> usize { - 1 + self.parts.len() + self.body.len() + self.parts.iter().map(|part| part.len()).sum::() } /// Returns whether the message is empty. It is always false, since the body /// is always defined. pub fn is_empty(&self) -> bool { - false // there is always a body + self.body.is_empty() && self.parts.iter().all(|part| part.is_empty()) } /// Convert this message into its constituent components. - pub fn into_inner(self) -> (Bytes, Vec) { + pub fn into_inner(self) -> (Part, Vec) { (self.body, self.parts) } } @@ -144,30 +152,35 @@ unsafe impl BufMut for UnsafeBufCellRef { /// /// Serialize uses the same codec options as [`bincode::serialize`] / [`bincode::deserialize`]. /// These are currently not customizable unless an explicit specialization is also provided. -pub fn serialize(value: &S) -> Result { +pub fn serialize_bincode( + value: &S, +) -> Result { let buffer = UnsafeBufCell::from_bytes_mut(BytesMut::new()); // SAFETY: we know here that, once the below "value.serialize()" is done, there are no more // extant references to this buffer; we are thus safe to reclaim the buffer into the message - let buffer_writer = unsafe { buffer.borrow_unchecked() }; - let serializer = bincode::Serializer::new(buffer_writer.writer(), options()); - let mut serializer: part::BincodeSerializer = ser::bincode::Serializer::new(serializer); + let buffer_borrow = unsafe { buffer.borrow_unchecked() }; + let mut serializer: part::BincodeSerializer = + ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options())); value.serialize(&mut serializer)?; Ok(Message { - body: buffer.into_inner().freeze(), + body: Part(buffer.into_inner().freeze()), parts: serializer.into_parts(), }) } /// Deserialize a message serialized by `[serialize]`, stitching together the original /// message without copying the underlying buffers. -pub fn deserialize<'a, T>(message: Message) -> Result +pub fn deserialize_bincode<'a, T>(message: Message) -> Result where T: serde::Deserialize<'a>, { let (body, parts) = message.into_inner(); - let bincode_deserializer = bincode::Deserializer::with_reader(body.reader(), options()); - let mut deserializer = part::BincodeDeserializer::new(bincode_deserializer, parts.into()); + let mut deserializer = part::BincodeDeserializer::new( + bincode::Deserializer::with_reader(body.into_inner().reader(), options()), + parts.into(), + ); let value = T::deserialize(&mut deserializer)?; + // Check that all parts were consumed: deserializer.end()?; Ok(value) } @@ -193,9 +206,9 @@ mod tests { where T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug, { - let message = serialize(&value).unwrap(); - assert_eq!(message.len(), expected_parts); - let deserialized_value = deserialize(message).unwrap(); + let message = serialize_bincode(&value).unwrap(); + assert_eq!(message.num_parts(), expected_parts); + let deserialized_value = deserialize_bincode(message).unwrap(); assert_eq!(value, deserialized_value); // Test normal bincode passthrough: @@ -206,13 +219,13 @@ mod tests { #[test] fn test_specialized_serializer_basic() { - test_roundtrip(Part::from("hello"), 2); + test_roundtrip(Part::from("hello"), 1); } #[test] fn test_specialized_serializer_compound() { - test_roundtrip(vec![Part::from("hello"), Part::from("world")], 3); - test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 3); + test_roundtrip(vec![Part::from("hello"), Part::from("world")], 2); + test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 2); test_roundtrip( { #[derive(Serialize, Deserialize, Debug, PartialEq)] @@ -242,7 +255,7 @@ mod tests { ], } }, - 8, + 7, ); test_roundtrip( { @@ -262,29 +275,40 @@ mod tests { field5: 2, } }, - 3, + 2, ); } + #[test] + fn test_recursive_message() { + let message = serialize_bincode(&[Part::from("hello"), Part::from("world")]).unwrap(); + let message_message = serialize_bincode(&message).unwrap(); + + // message.body + message.parts (x2): + assert_eq!(message_message.num_parts(), 3); + } + #[test] fn test_malformed_messages() { let message = Message { - body: Bytes::from_static(b"hello"), + body: Part::from("hello"), parts: vec![Part::from("world")], }; - let err = deserialize::(message).unwrap_err(); + let err = deserialize_bincode::(message).unwrap_err(); // Normal bincode errors work: assert_matches!(*err, bincode::ErrorKind::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof); - let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap(); + let mut message = + serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap(); message.parts.push(Part::from("foo")); - let err = deserialize::>(message).unwrap_err(); + let err = deserialize_bincode::>(message).unwrap_err(); assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart overrun while decoding"); - let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap(); + let mut message = + serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap(); let _dropped_message = message.parts.pop().unwrap(); - let err = deserialize::>(message).unwrap_err(); + let err = deserialize_bincode::>(message).unwrap_err(); assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding"); } } diff --git a/serde_multipart/src/part.rs b/serde_multipart/src/part.rs index 2d231e6cb..ed78c528c 100644 --- a/serde_multipart/src/part.rs +++ b/serde_multipart/src/part.rs @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +use std::ops::Deref; + use bytes::Bytes; use bytes::buf::Reader as BufReader; use bytes::buf::Writer as BufWriter; @@ -22,7 +24,14 @@ use crate::ser; /// serialization implementation that is specialized for the multipart codecs in /// this crate, skipping copying the bytes whenever possible. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct Part(Bytes); +pub struct Part(pub(crate) Bytes); + +impl Part { + /// Consumes the part, returning its underlying byte buffer. + pub fn into_inner(self) -> Bytes { + self.0 + } +} impl> From for Part { fn from(bytes: T) -> Self { @@ -30,6 +39,14 @@ impl> From for Part { } } +impl Deref for Part { + type Target = Bytes; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl Serialize for Part { fn serialize(&self, s: S) -> Result { >::serialize(self, s)