28
28
29
29
#![ feature( min_specialization) ]
30
30
#![ feature( assert_matches) ]
31
+ #![ feature( vec_deque_pop_if) ]
31
32
32
33
use std:: cell:: UnsafeCell ;
34
+ use std:: cmp:: min;
35
+ use std:: collections:: VecDeque ;
36
+ use std:: io:: IoSlice ;
33
37
use std:: ptr:: NonNull ;
34
38
35
39
use bincode:: Options ;
@@ -47,6 +51,8 @@ use serde::Deserialize;
47
51
use serde:: Serialize ;
48
52
49
53
/// A multi-part message, comprising a message body and a list of parts.
54
+ /// Messages only contain references to underlying byte buffers and are
55
+ /// cheaply cloned.
50
56
#[ derive( Debug , Clone , PartialEq , Eq , Serialize , Deserialize ) ]
51
57
pub struct Message {
52
58
body : Part ,
@@ -89,6 +95,108 @@ impl Message {
89
95
pub fn into_inner ( self ) -> ( Part , Vec < Part > ) {
90
96
( self . body , self . parts )
91
97
}
98
+
99
+ /// Efficiently frames a message containing the body and all of its parts
100
+ /// using a simple frame-length encoding:
101
+ ///
102
+ /// ```text
103
+ /// +--------------------+-------------------+--------------------+-------------------+ ... +
104
+ /// | body_len (u64 BE) | body bytes | part1_len (u64 BE) | part1 bytes | |
105
+ /// +--------------------+-------------------+--------------------+-------------------+ +
106
+ /// repeat
107
+ /// for
108
+ /// each part
109
+ /// ```
110
+ pub fn framed ( self ) -> impl Buf {
111
+ let ( body, parts) = self . into_inner ( ) ;
112
+ let mut buffers = Vec :: with_capacity ( 2 + 2 * parts. len ( ) ) ;
113
+
114
+ let body = body. into_inner ( ) ;
115
+ buffers. push ( Bytes :: from_owner ( body. len ( ) . to_be_bytes ( ) ) ) ;
116
+ buffers. push ( body) ;
117
+
118
+ for part in parts {
119
+ let part = part. into_inner ( ) ;
120
+ buffers. push ( Bytes :: from_owner ( part. len ( ) . to_be_bytes ( ) ) ) ;
121
+ buffers. push ( part) ;
122
+ }
123
+
124
+ ConcatBuf :: from_buffers ( buffers)
125
+ }
126
+
127
+ /// Reassembles a message from a framed encoding.
128
+ pub fn from_framed ( mut buf : Bytes ) -> Result < Self , std:: io:: Error > {
129
+ let body = Self :: split_part ( & mut buf) ?. into ( ) ;
130
+ let mut parts = Vec :: new ( ) ;
131
+ while buf. len ( ) > 0 {
132
+ parts. push ( Self :: split_part ( & mut buf) ?. into ( ) ) ;
133
+ }
134
+ Ok ( Self { body, parts } )
135
+ }
136
+
137
+ fn split_part ( buf : & mut Bytes ) -> Result < Bytes , std:: io:: Error > {
138
+ if buf. len ( ) < 8 {
139
+ return Err ( std:: io:: ErrorKind :: UnexpectedEof . into ( ) ) ;
140
+ }
141
+ let at = buf. get_u64 ( ) as usize ;
142
+ if buf. len ( ) < at {
143
+ return Err ( std:: io:: ErrorKind :: UnexpectedEof . into ( ) ) ;
144
+ }
145
+ Ok ( buf. split_to ( at) )
146
+ }
147
+ }
148
+
149
+ struct ConcatBuf {
150
+ buffers : VecDeque < Bytes > ,
151
+ }
152
+
153
+ impl ConcatBuf {
154
+ /// Construct a new concatenated buffer.
155
+ fn from_buffers ( buffers : Vec < Bytes > ) -> Self {
156
+ let mut buffers: VecDeque < Bytes > = buffers. into ( ) ;
157
+ buffers. retain ( |buf| !buf. is_empty ( ) ) ;
158
+ Self { buffers }
159
+ }
160
+ }
161
+
162
+ impl Buf for ConcatBuf {
163
+ fn remaining ( & self ) -> usize {
164
+ self . buffers . iter ( ) . map ( |buf| buf. remaining ( ) ) . sum ( )
165
+ }
166
+
167
+ fn chunk ( & self ) -> & [ u8 ] {
168
+ match self . buffers . front ( ) {
169
+ Some ( buf) => buf. chunk ( ) ,
170
+ None => & [ ] ,
171
+ }
172
+ }
173
+
174
+ fn advance ( & mut self , mut cnt : usize ) {
175
+ while cnt > 0 {
176
+ let Some ( buf) = self . buffers . front_mut ( ) else {
177
+ panic ! ( "advanced beyond the buffer size" ) ;
178
+ } ;
179
+
180
+ if cnt >= buf. remaining ( ) {
181
+ cnt -= buf. remaining ( ) ;
182
+ self . buffers . pop_front ( ) ;
183
+ continue ;
184
+ }
185
+
186
+ buf. advance ( cnt) ;
187
+ cnt = 0 ;
188
+ }
189
+ }
190
+
191
+ // We implement our own chunks_vectored here, as the default implementation
192
+ // does not do any vectoring (returning only a single IoSlice at a time).
193
+ fn chunks_vectored < ' a > ( & ' a self , dst : & mut [ IoSlice < ' a > ] ) -> usize {
194
+ let n = min ( dst. len ( ) , self . buffers . len ( ) ) ;
195
+ for i in 0 ..n {
196
+ dst[ i] = IoSlice :: new ( self . buffers [ i] . chunk ( ) ) ;
197
+ }
198
+ n
199
+ }
92
200
}
93
201
94
202
/// An unsafe cell of a [`BytesMut`]. This is used to implement an io::Writer
@@ -206,12 +314,19 @@ mod tests {
206
314
where
207
315
T : Serialize + DeserializeOwned + PartialEq + std:: fmt:: Debug ,
208
316
{
317
+ // Test plain serialization roundtrip:
209
318
let message = serialize_bincode ( & value) . unwrap ( ) ;
210
319
assert_eq ! ( message. num_parts( ) , expected_parts) ;
211
- let deserialized_value = deserialize_bincode ( message) . unwrap ( ) ;
320
+ let deserialized_value = deserialize_bincode ( message. clone ( ) ) . unwrap ( ) ;
212
321
assert_eq ! ( value, deserialized_value) ;
213
322
214
- // Test normal bincode passthrough:
323
+ // Framing roundtrip:
324
+ let mut framed = message. clone ( ) . framed ( ) ;
325
+ let framed = framed. copy_to_bytes ( framed. remaining ( ) ) ;
326
+ let unframed_message = Message :: from_framed ( framed) . unwrap ( ) ;
327
+ assert_eq ! ( message, unframed_message) ;
328
+
329
+ // Bincode passthrough:
215
330
let bincode_serialized = bincode:: serialize ( & value) . unwrap ( ) ;
216
331
let bincode_deserialized = bincode:: deserialize ( & bincode_serialized) . unwrap ( ) ;
217
332
assert_eq ! ( value, bincode_deserialized) ;
@@ -311,4 +426,50 @@ mod tests {
311
426
let err = deserialize_bincode :: < Vec < Part > > ( message) . unwrap_err ( ) ;
312
427
assert_matches ! ( * err, bincode:: ErrorKind :: Custom ( message) if message == "multipart underrun while decoding" ) ;
313
428
}
429
+
430
+ #[ test]
431
+ fn test_concat_buf ( ) {
432
+ let buffers = vec ! [
433
+ Bytes :: from( "hello" ) ,
434
+ Bytes :: from( "world" ) ,
435
+ Bytes :: from( "1" ) ,
436
+ Bytes :: from( "" ) ,
437
+ Bytes :: from( "xyz" ) ,
438
+ Bytes :: from( "xyzd" ) ,
439
+ ] ;
440
+
441
+ let mut concat = ConcatBuf :: from_buffers ( buffers. clone ( ) ) ;
442
+
443
+ assert_eq ! ( concat. remaining( ) , 18 ) ;
444
+ concat. advance ( 2 ) ;
445
+ assert_eq ! ( concat. remaining( ) , 16 ) ;
446
+ assert_eq ! ( concat. chunk( ) , & b"llo" [ ..] ) ;
447
+ concat. advance ( 4 ) ;
448
+ assert_eq ! ( concat. chunk( ) , & b"orld" [ ..] ) ;
449
+ concat. advance ( 5 ) ;
450
+ assert_eq ! ( concat. chunk( ) , & b"xyz" [ ..] ) ;
451
+
452
+ let mut concat = ConcatBuf :: from_buffers ( buffers) ;
453
+ let bytes = concat. copy_to_bytes ( concat. remaining ( ) ) ;
454
+ assert_eq ! ( & * bytes, & b"helloworld1xyzxyzd" [ ..] ) ;
455
+ }
456
+
457
+ #[ test]
458
+ fn test_framing ( ) {
459
+ let message = Message {
460
+ body : Part :: from ( "hello" ) ,
461
+ parts : vec ! [
462
+ Part :: from( "world" ) ,
463
+ Part :: from( "1" ) ,
464
+ Part :: from( "" ) ,
465
+ Part :: from( "xyz" ) ,
466
+ Part :: from( "xyzd" ) ,
467
+ ]
468
+ . into ( ) ,
469
+ } ;
470
+
471
+ let mut framed = message. clone ( ) . framed ( ) ;
472
+ let framed = framed. copy_to_bytes ( framed. remaining ( ) ) ;
473
+ assert_eq ! ( Message :: from_framed( framed) . unwrap( ) , message) ;
474
+ }
314
475
}
0 commit comments