Skip to content

Commit 8783501

Browse files
committed
[serde_multipart] serializable Message
We want to be able to (efficiently) serialize `Message`, passing through its constituent parts; and also to treat the message body as its own part. In this change, we derive `Serialize` for `Message`, and convert its body into a `Part`. Differential Revision: [D80179921](https://our.internmc.facebook.com/intern/diff/D80179921/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D80179921/)! ghstack-source-id: 302736715 Pull Request resolved: #852
1 parent e57aadf commit 8783501

File tree

3 files changed

+73
-32
lines changed

3 files changed

+73
-32
lines changed

serde_multipart/src/de/bincode.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ where
227227

228228
fn deserialize_unit_struct<V>(
229229
self,
230-
name: &'static str,
230+
_name: &'static str,
231231
visitor: V,
232232
) -> Result<V::Value, Self::Error>
233233
where

serde_multipart/src/lib.rs

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,32 @@ use std::ptr::NonNull;
3535
use bincode::Options;
3636
use bytes::Buf;
3737
use bytes::BufMut;
38-
use bytes::Bytes;
3938
use bytes::buf::UninitSlice;
4039

4140
mod de;
4241
mod part;
4342
mod ser;
43+
use bytes::Bytes;
4444
use bytes::BytesMut;
4545
use part::Part;
46+
use serde::Deserialize;
47+
use serde::Serialize;
4648

4749
/// A multi-part message, comprising a message body and a list of parts.
50+
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
4851
pub struct Message {
49-
body: Bytes,
52+
body: Part,
5053
parts: Vec<Part>,
5154
}
5255

5356
impl Message {
5457
/// Returns a new message with the given body and parts.
55-
pub fn from_body_and_parts(body: Bytes, parts: Vec<Part>) -> Self {
58+
pub fn from_body_and_parts(body: Part, parts: Vec<Part>) -> Self {
5659
Self { body, parts }
5760
}
5861

5962
/// The body of the message.
60-
pub fn body(&self) -> &Bytes {
63+
pub fn body(&self) -> &Part {
6164
&self.body
6265
}
6366

@@ -66,19 +69,24 @@ impl Message {
6669
&self.parts
6770
}
6871

69-
/// Returns the total number of parts (body + number of parts) in the message.
72+
/// Returns the total number of parts (excluding the body) in the message.
73+
pub fn num_parts(&self) -> usize {
74+
self.parts.len()
75+
}
76+
77+
/// Returns the total size (in bytes) of the message.
7078
pub fn len(&self) -> usize {
71-
1 + self.parts.len()
79+
self.body.len() + self.parts.iter().map(|part| part.len()).sum::<usize>()
7280
}
7381

7482
/// Returns whether the message is empty. It is always false, since the body
7583
/// is always defined.
7684
pub fn is_empty(&self) -> bool {
77-
false // there is always a body
85+
self.body.is_empty() && self.parts.iter().all(|part| part.is_empty())
7886
}
7987

8088
/// Convert this message into its constituent components.
81-
pub fn into_inner(self) -> (Bytes, Vec<Part>) {
89+
pub fn into_inner(self) -> (Part, Vec<Part>) {
8290
(self.body, self.parts)
8391
}
8492
}
@@ -144,30 +152,35 @@ unsafe impl BufMut for UnsafeBufCellRef {
144152
///
145153
/// Serialize uses the same codec options as [`bincode::serialize`] / [`bincode::deserialize`].
146154
/// These are currently not customizable unless an explicit specialization is also provided.
147-
pub fn serialize<S: ?Sized + serde::Serialize>(value: &S) -> Result<Message, bincode::Error> {
155+
pub fn serialize_bincode<S: ?Sized + serde::Serialize>(
156+
value: &S,
157+
) -> Result<Message, bincode::Error> {
148158
let buffer = UnsafeBufCell::from_bytes_mut(BytesMut::new());
149159
// SAFETY: we know here that, once the below "value.serialize()" is done, there are no more
150160
// extant references to this buffer; we are thus safe to reclaim the buffer into the message
151-
let buffer_writer = unsafe { buffer.borrow_unchecked() };
152-
let serializer = bincode::Serializer::new(buffer_writer.writer(), options());
153-
let mut serializer: part::BincodeSerializer = ser::bincode::Serializer::new(serializer);
161+
let buffer_borrow = unsafe { buffer.borrow_unchecked() };
162+
let mut serializer: part::BincodeSerializer =
163+
ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options()));
154164
value.serialize(&mut serializer)?;
155165
Ok(Message {
156-
body: buffer.into_inner().freeze(),
166+
body: Part(buffer.into_inner().freeze()),
157167
parts: serializer.into_parts(),
158168
})
159169
}
160170

161171
/// Deserialize a message serialized by `[serialize]`, stitching together the original
162172
/// message without copying the underlying buffers.
163-
pub fn deserialize<'a, T>(message: Message) -> Result<T, bincode::Error>
173+
pub fn deserialize_bincode<'a, T>(message: Message) -> Result<T, bincode::Error>
164174
where
165175
T: serde::Deserialize<'a>,
166176
{
167177
let (body, parts) = message.into_inner();
168-
let bincode_deserializer = bincode::Deserializer::with_reader(body.reader(), options());
169-
let mut deserializer = part::BincodeDeserializer::new(bincode_deserializer, parts.into());
178+
let mut deserializer = part::BincodeDeserializer::new(
179+
bincode::Deserializer::with_reader(body.into_inner().reader(), options()),
180+
parts.into(),
181+
);
170182
let value = T::deserialize(&mut deserializer)?;
183+
// Check that all parts were consumed:
171184
deserializer.end()?;
172185
Ok(value)
173186
}
@@ -193,9 +206,9 @@ mod tests {
193206
where
194207
T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
195208
{
196-
let message = serialize(&value).unwrap();
197-
assert_eq!(message.len(), expected_parts);
198-
let deserialized_value = deserialize(message).unwrap();
209+
let message = serialize_bincode(&value).unwrap();
210+
assert_eq!(message.num_parts(), expected_parts);
211+
let deserialized_value = deserialize_bincode(message).unwrap();
199212
assert_eq!(value, deserialized_value);
200213

201214
// Test normal bincode passthrough:
@@ -206,13 +219,13 @@ mod tests {
206219

207220
#[test]
208221
fn test_specialized_serializer_basic() {
209-
test_roundtrip(Part::from("hello"), 2);
222+
test_roundtrip(Part::from("hello"), 1);
210223
}
211224

212225
#[test]
213226
fn test_specialized_serializer_compound() {
214-
test_roundtrip(vec![Part::from("hello"), Part::from("world")], 3);
215-
test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 3);
227+
test_roundtrip(vec![Part::from("hello"), Part::from("world")], 2);
228+
test_roundtrip((Part::from("hello"), 1, 2, 3, Part::from("world")), 2);
216229
test_roundtrip(
217230
{
218231
#[derive(Serialize, Deserialize, Debug, PartialEq)]
@@ -242,7 +255,7 @@ mod tests {
242255
],
243256
}
244257
},
245-
8,
258+
7,
246259
);
247260
test_roundtrip(
248261
{
@@ -262,29 +275,40 @@ mod tests {
262275
field5: 2,
263276
}
264277
},
265-
3,
278+
2,
266279
);
267280
}
268281

282+
#[test]
283+
fn test_recursive_message() {
284+
let message = serialize_bincode(&[Part::from("hello"), Part::from("world")]).unwrap();
285+
let message_message = serialize_bincode(&message).unwrap();
286+
287+
// message.body + message.parts (x2):
288+
assert_eq!(message_message.num_parts(), 3);
289+
}
290+
269291
#[test]
270292
fn test_malformed_messages() {
271293
let message = Message {
272-
body: Bytes::from_static(b"hello"),
294+
body: Part::from("hello"),
273295
parts: vec![Part::from("world")],
274296
};
275-
let err = deserialize::<String>(message).unwrap_err();
297+
let err = deserialize_bincode::<String>(message).unwrap_err();
276298

277299
// Normal bincode errors work:
278300
assert_matches!(*err, bincode::ErrorKind::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof);
279301

280-
let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap();
302+
let mut message =
303+
serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
281304
message.parts.push(Part::from("foo"));
282-
let err = deserialize::<Vec<Part>>(message).unwrap_err();
305+
let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
283306
assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart overrun while decoding");
284307

285-
let mut message = serialize(&vec![Part::from("hello"), Part::from("world")]).unwrap();
308+
let mut message =
309+
serialize_bincode(&vec![Part::from("hello"), Part::from("world")]).unwrap();
286310
let _dropped_message = message.parts.pop().unwrap();
287-
let err = deserialize::<Vec<Part>>(message).unwrap_err();
311+
let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
288312
assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding");
289313
}
290314
}

serde_multipart/src/part.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use std::ops::Deref;
10+
911
use bytes::Bytes;
1012
use bytes::buf::Reader as BufReader;
1113
use bytes::buf::Writer as BufWriter;
@@ -22,14 +24,29 @@ use crate::ser;
2224
/// serialization implementation that is specialized for the multipart codecs in
2325
/// this crate, skipping copying the bytes whenever possible.
2426
#[derive(Clone, Debug, PartialEq, Eq)]
25-
pub struct Part(Bytes);
27+
pub struct Part(pub(crate) Bytes);
28+
29+
impl Part {
30+
/// Consumes the part, returning its underlying byte buffer.
31+
pub fn into_inner(self) -> Bytes {
32+
self.0
33+
}
34+
}
2635

2736
impl<T: Into<Bytes>> From<T> for Part {
2837
fn from(bytes: T) -> Self {
2938
Self(bytes.into())
3039
}
3140
}
3241

42+
impl Deref for Part {
43+
type Target = Bytes;
44+
45+
fn deref(&self) -> &Self::Target {
46+
&self.0
47+
}
48+
}
49+
3350
impl Serialize for Part {
3451
fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
3552
<Part as PartSerializer<S>>::serialize(self, s)

0 commit comments

Comments
 (0)