Skip to content

Commit 9405a59

Browse files
committed
[serde_multipart] framed multipart encoding
This change implements a framed multipart encoding for serde_multipart messages. This represents a common and convenient transport encoding for these messages. We implement this as a buffer with a specialized vectored IO implementation, allowing the transport to use vectored IO if available. Differential Revision: [D80262357](https://our.internmc.facebook.com/intern/diff/D80262357/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D80262357/)! ghstack-source-id: 303015121 Pull Request resolved: #876
1 parent 8f95049 commit 9405a59

File tree

2 files changed

+168
-2
lines changed

2 files changed

+168
-2
lines changed

serde_multipart/src/lib.rs

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@
2828
2929
#![feature(min_specialization)]
3030
#![feature(assert_matches)]
31+
#![feature(vec_deque_pop_if)]
3132

3233
use std::cell::UnsafeCell;
34+
use std::cmp::min;
35+
use std::collections::VecDeque;
36+
use std::io::IoSlice;
3337
use std::ptr::NonNull;
3438

3539
use bincode::Options;
@@ -47,6 +51,8 @@ use serde::Deserialize;
4751
use serde::Serialize;
4852

4953
/// A multi-part message, comprising a message body and a list of parts.
54+
/// Messages only contain references to underlying byte buffers and are
55+
/// cheaply cloned.
5056
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
5157
pub struct Message {
5258
body: Part,
@@ -89,6 +95,108 @@ impl Message {
8995
pub fn into_inner(self) -> (Part, Vec<Part>) {
9096
(self.body, self.parts)
9197
}
98+
99+
/// Efficiently frames a message containing the body and all of its parts
100+
/// using a simple frame-length encoding:
101+
///
102+
/// ```
103+
/// +--------------------+-------------------+--------------------+-------------------+ ... +
104+
/// | body_len (u64 BE) | body bytes | part1_len (u64 BE) | part1 bytes | |
105+
/// +--------------------+-------------------+--------------------+-------------------+ +
106+
/// repeat
107+
/// for
108+
/// each part
109+
/// ```
110+
pub fn framed(self) -> impl Buf {
111+
let (body, parts) = self.into_inner();
112+
let mut buffers = Vec::with_capacity(2 + 2 * parts.len());
113+
114+
let body = body.into_inner();
115+
buffers.push(Bytes::from_owner(body.len().to_be_bytes()));
116+
buffers.push(body);
117+
118+
for part in parts {
119+
let part = part.into_inner();
120+
buffers.push(Bytes::from_owner(part.len().to_be_bytes()));
121+
buffers.push(part);
122+
}
123+
124+
ConcatBuf::from_buffers(buffers)
125+
}
126+
127+
/// Reassembles a message from a framed encoding.
128+
pub fn from_framed(mut buf: Bytes) -> Result<Self, std::io::Error> {
129+
let body = Self::split_part(&mut buf)?.into();
130+
let mut parts = Vec::new();
131+
while buf.len() > 0 {
132+
parts.push(Self::split_part(&mut buf)?.into());
133+
}
134+
Ok(Self { body, parts })
135+
}
136+
137+
/// Splits one length-prefixed frame off the front of `buf` and returns its
/// payload, leaving `buf` positioned at the start of the next frame.
///
/// A frame is an 8-byte big-endian `u64` length followed by that many
/// payload bytes.
///
/// # Errors
///
/// Returns [`std::io::ErrorKind::UnexpectedEof`] if `buf` holds fewer than
/// 8 bytes, or if the declared payload length exceeds the bytes that remain
/// after the prefix. In the latter case the prefix has already been consumed
/// from `buf`.
fn split_part(buf: &mut Bytes) -> Result<Bytes, std::io::Error> {
    /// Width of the length prefix on the wire.
    const PREFIX_LEN: usize = std::mem::size_of::<u64>();

    if buf.len() < PREFIX_LEN {
        return Err(std::io::ErrorKind::UnexpectedEof.into());
    }

    // Convert with a checked conversion rather than `as`: on targets where
    // `usize` is narrower than 64 bits, `as` would silently truncate an
    // oversized length to a small value and mis-split the stream. A length
    // that does not fit in `usize` can never be satisfied by an in-memory
    // buffer, so it is reported the same way as any other short frame.
    let declared = buf.get_u64();
    let at = match usize::try_from(declared) {
        Ok(at) if at <= buf.len() => at,
        _ => return Err(std::io::ErrorKind::UnexpectedEof.into()),
    };

    Ok(buf.split_to(at))
}
147+
}
148+
149+
struct ConcatBuf {
150+
buffers: VecDeque<Bytes>,
151+
}
152+
153+
impl ConcatBuf {
154+
/// Construct a new concatenated buffer.
155+
fn from_buffers(buffers: Vec<Bytes>) -> Self {
156+
let mut buffers: VecDeque<Bytes> = buffers.into();
157+
buffers.retain(|buf| !buf.is_empty());
158+
Self { buffers }
159+
}
160+
}
161+
162+
impl Buf for ConcatBuf {
163+
fn remaining(&self) -> usize {
164+
self.buffers.iter().map(|buf| buf.remaining()).sum()
165+
}
166+
167+
fn chunk(&self) -> &[u8] {
168+
match self.buffers.front() {
169+
Some(buf) => buf.chunk(),
170+
None => &[],
171+
}
172+
}
173+
174+
fn advance(&mut self, mut cnt: usize) {
175+
while cnt > 0 {
176+
let Some(buf) = self.buffers.front_mut() else {
177+
panic!("advanced beyond the buffer size");
178+
};
179+
180+
if cnt >= buf.remaining() {
181+
cnt -= buf.remaining();
182+
self.buffers.pop_front();
183+
continue;
184+
}
185+
186+
buf.advance(cnt);
187+
cnt = 0;
188+
}
189+
}
190+
191+
// We implement our own chunks_vectored here, as the default implementation
192+
// does not do any vectoring (returning only a single IoSlice at a time).
193+
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
194+
let n = min(dst.len(), self.buffers.len());
195+
for i in 0..n {
196+
dst[i] = IoSlice::new(self.buffers[i].chunk());
197+
}
198+
n
199+
}
92200
}
93201

94202
/// An unsafe cell of a [`BytesMut`]. This is used to implement an io::Writer
@@ -206,12 +314,19 @@ mod tests {
206314
where
207315
T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug,
208316
{
317+
// Test plain serialization roundtrip:
209318
let message = serialize_bincode(&value).unwrap();
210319
assert_eq!(message.num_parts(), expected_parts);
211-
let deserialized_value = deserialize_bincode(message).unwrap();
320+
let deserialized_value = deserialize_bincode(message.clone()).unwrap();
212321
assert_eq!(value, deserialized_value);
213322

214-
// Test normal bincode passthrough:
323+
// Framing roundtrip:
324+
let mut framed = message.clone().framed();
325+
let framed = framed.copy_to_bytes(framed.remaining());
326+
let unframed_message = Message::from_framed(framed).unwrap();
327+
assert_eq!(message, unframed_message);
328+
329+
// Bincode passthrough:
215330
let bincode_serialized = bincode::serialize(&value).unwrap();
216331
let bincode_deserialized = bincode::deserialize(&bincode_serialized).unwrap();
217332
assert_eq!(value, bincode_deserialized);
@@ -311,4 +426,50 @@ mod tests {
311426
let err = deserialize_bincode::<Vec<Part>>(message).unwrap_err();
312427
assert_matches!(*err, bincode::ErrorKind::Custom(message) if message == "multipart underrun while decoding");
313428
}
429+
430+
/// Exercises `ConcatBuf` cursor movement across buffer boundaries (including
/// an empty buffer in the input) and a full drain via `copy_to_bytes`.
#[test]
fn test_concat_buf() {
    let buffers: Vec<Bytes> =
        Vec::from(["hello", "world", "1", "", "xyz", "xyzd"].map(Bytes::from));

    // Step through the stream piecewise.
    let mut cursor = ConcatBuf::from_buffers(buffers.clone());
    assert_eq!(cursor.remaining(), 18);

    // Stay inside the first buffer.
    cursor.advance(2);
    assert_eq!(cursor.remaining(), 16);
    assert_eq!(cursor.chunk(), &b"llo"[..]);

    // Cross from the first buffer into the second.
    cursor.advance(4);
    assert_eq!(cursor.chunk(), &b"orld"[..]);

    // Finish the second buffer, consume "1", and skip over the empty one.
    cursor.advance(5);
    assert_eq!(cursor.chunk(), &b"xyz"[..]);

    // Drain a fresh instance in one go.
    let mut whole = ConcatBuf::from_buffers(buffers);
    let total = whole.remaining();
    let drained = whole.copy_to_bytes(total);
    assert_eq!(&*drained, &b"helloworld1xyzxyzd"[..]);
}
456+
457+
/// Round-trips a message (including an empty part) through the framed
/// encoding and checks that decoding yields the original message.
#[test]
fn test_framing() {
    let parts = Vec::from(["world", "1", "", "xyz", "xyzd"].map(Part::from));
    let original = Message {
        body: Part::from("hello"),
        parts,
    };

    // Flatten the framed representation into one contiguous buffer.
    let mut encoded = original.clone().framed();
    let wire = encoded.copy_to_bytes(encoded.remaining());

    let decoded = Message::from_framed(wire).unwrap();
    assert_eq!(decoded, original);
}
314475
}

serde_multipart/src/part.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ impl Part {
3131
pub fn into_inner(self) -> Bytes {
3232
self.0
3333
}
34+
35+
/// Returns a clone of the underlying byte buffer, leaving this part intact.
36+
pub fn to_bytes(&self) -> Bytes {
37+
self.0.clone()
38+
}
3439
}
3540

3641
impl<T: Into<Bytes>> From<T> for Part {

0 commit comments

Comments
 (0)