@@ -84,34 +84,37 @@ where
84
84
/// This is returned from [`encode`].
85
85
#[ derive( Debug ) ]
86
86
struct Encoder {
87
- /// Keep track how far we've indexed into the headers + body.
88
- cursor : usize ,
89
87
/// HTTP headers to be sent.
90
- head : Option < Vec < u8 > > ,
91
- /// Check whether we're done sending headers.
92
- head_done : bool ,
93
- /// Response containing the HTTP body to be sent.
94
88
res : Response ,
95
- /// Check whether we're done with the body.
96
- body_done : bool ,
97
- /// Keep track of how many bytes have been read from the body stream.
98
- body_bytes_read : usize ,
89
+ /// The state of the encoding process
90
+ state : EncoderState ,
91
+ }
92
+
93
+ #[ derive( Debug ) ]
94
+ enum EncoderState {
95
+ Start ,
96
+ Head {
97
+ data : Vec < u8 > ,
98
+ head_bytes_read : usize ,
99
+ } ,
100
+ Body {
101
+ body_bytes_read : usize ,
102
+ body_len : usize ,
103
+ } ,
104
+ Chunked ,
105
+ Done ,
99
106
}
100
107
101
108
impl Encoder {
102
109
/// Create a new instance.
103
110
pub ( crate ) fn encode ( res : Response ) -> Self {
104
111
Self {
105
112
res,
106
- head : None ,
107
- cursor : 0 ,
108
- head_done : false ,
109
- body_done : false ,
110
- body_bytes_read : 0 ,
113
+ state : EncoderState :: Start ,
111
114
}
112
115
}
113
116
114
- fn encode_head ( & mut self ) -> io:: Result < ( ) > {
117
+ fn encode_head ( & self ) -> io:: Result < Vec < u8 > > {
115
118
let mut head: Vec < u8 > = vec ! [ ] ;
116
119
117
120
let reason = self . res . status ( ) . canonical_reason ( ) ;
@@ -138,8 +141,7 @@ impl Encoder {
138
141
139
142
std:: io:: Write :: write_fmt ( & mut head, format_args ! ( "\r \n " ) ) ?;
140
143
141
- self . head = Some ( head) ;
142
- Ok ( ( ) )
144
+ Ok ( head)
143
145
}
144
146
}
145
147
@@ -149,105 +151,154 @@ impl Read for Encoder {
149
151
cx : & mut Context < ' _ > ,
150
152
buf : & mut [ u8 ] ,
151
153
) -> Poll < io:: Result < usize > > {
152
- // Encode the headers to a buffer, the first time we poll
153
- if let None = self . head {
154
- self . encode_head ( ) ?;
155
- }
156
-
157
154
// we must keep track how many bytes of the head and body we've read
158
155
// in this call of `poll_read`
159
- let mut head_bytes_read = 0 ;
160
- let mut body_bytes_read = 0 ;
161
-
162
- // Read from the serialized headers, url and methods.
163
- if !self . head_done {
164
- let head = self . head . as_ref ( ) . unwrap ( ) ;
165
- let head_len = head. len ( ) ;
166
- let len = std:: cmp:: min ( head. len ( ) - self . cursor , buf. len ( ) ) ;
167
- let range = self . cursor ..self . cursor + len;
168
- buf[ 0 ..len] . copy_from_slice ( & head[ range] ) ;
169
- self . cursor += len;
170
- if self . cursor == head_len {
171
- self . head_done = true ;
172
- }
173
- head_bytes_read += len;
174
- }
175
-
176
- // Read from the AsyncRead impl on the inner Response struct only if
177
- // done reading from the head.
178
- // We must ensure there's space to write at least 2 bytes into the
179
- // response stream.
180
- if self . head_done && !self . body_done && head_bytes_read <= buf. len ( ) - 2 {
181
- // figure out how many bytes we can read. If a len was set, we need
182
- // to make sure we don't read more than that.
183
- let upper_bound = match self . res . len ( ) {
184
- Some ( len) => {
185
- debug_assert ! ( head_bytes_read == 0 || self . body_bytes_read == 0 ) ;
186
- ( head_bytes_read + len - self . body_bytes_read ) . min ( buf. len ( ) )
156
+ let mut bytes_read = 0 ;
157
+ loop {
158
+ println ! ( "{:?}" , self . state) ;
159
+ match self . state {
160
+ EncoderState :: Start => {
161
+ // Encode the headers to a buffer, the first time we poll
162
+ let head = self . encode_head ( ) ?;
163
+ self . state = EncoderState :: Head {
164
+ data : head,
165
+ head_bytes_read : 0 ,
166
+ } ;
187
167
}
188
- None => buf. len ( ) - 2 ,
189
- } ;
190
-
191
- match self . res . len ( ) {
192
- Some ( len) => {
168
+ EncoderState :: Head {
169
+ ref data,
170
+ mut head_bytes_read,
171
+ } => {
172
+ // Read from the serialized headers, url and methods.
173
+ let head_len = data. len ( ) ;
174
+ let len = std:: cmp:: min ( head_len - head_bytes_read, buf. len ( ) ) ;
175
+ let range = head_bytes_read..head_bytes_read + len;
176
+ buf[ 0 ..len] . copy_from_slice ( & data[ range] ) ;
177
+ bytes_read += len;
178
+ head_bytes_read += len;
179
+
180
+ // If we've read the total length of the head we're done
181
+ // reading the head and can transition to reading the body
182
+ if head_bytes_read == head_len {
183
+ // The response length lets us know if we are encoding
184
+ // our body in chunks or now
185
+ self . state = match self . res . len ( ) {
186
+ Some ( body_len) => EncoderState :: Body {
187
+ body_bytes_read : 0 ,
188
+ body_len,
189
+ } ,
190
+ None => EncoderState :: Chunked ,
191
+ } ;
192
+ }
193
+ }
194
+ EncoderState :: Body {
195
+ mut body_bytes_read,
196
+ body_len,
197
+ } => {
198
+ // Double check that we didn't somehow read more bytes than
199
+ // can fit in our buffer
200
+ debug_assert ! ( bytes_read <= buf. len( ) ) ;
201
+
202
+ // ensure we have at least room for 1 more byte in our buffer
203
+ if bytes_read == buf. len ( ) {
204
+ break ;
205
+ }
206
+ // figure out how many bytes we can read
207
+ let upper_bound = ( bytes_read + body_len - body_bytes_read) . min ( buf. len ( ) ) ;
193
208
// Read bytes, and update internal tracking stuff.
194
- body_bytes_read = ready ! ( Pin :: new( & mut self . res)
195
- . poll_read( cx, & mut buf[ head_bytes_read..upper_bound] ) ) ?;
196
-
209
+ let new_body_bytes_read =
210
+ ready ! ( Pin :: new( & mut self . res)
211
+ . poll_read( cx, & mut buf[ bytes_read..upper_bound] ) ) ?;
212
+ body_bytes_read += new_body_bytes_read;
213
+ bytes_read += new_body_bytes_read;
214
+
215
+ // Double check we did not read more body bytes than the total
216
+ // length of the body
197
217
debug_assert ! (
198
- self . body_bytes_read <= len ,
218
+ body_bytes_read <= body_len ,
199
219
"Too many bytes read. Expected: {}, read: {}" ,
200
- len ,
201
- self . body_bytes_read
220
+ body_len ,
221
+ body_bytes_read
202
222
) ;
203
223
// If we've read the `len` number of bytes or the stream no longer gives bytes, end.
204
- if len == self . body_bytes_read || body_bytes_read == 0 {
205
- self . body_done = true ;
206
- }
224
+ self . state = if body_len == body_bytes_read || body_bytes_read == 0 {
225
+ EncoderState :: Done
226
+ } else {
227
+ EncoderState :: Body {
228
+ body_bytes_read,
229
+ body_len,
230
+ }
231
+ } ;
207
232
}
208
- None => {
209
- let mut chunk_buf = vec ! [ 0 ; buf. len( ) ] ;
233
+ EncoderState :: Chunked => {
234
+ // ensure we have at least room for 1 more byte in our buffer
235
+ if bytes_read == buf. len ( ) {
236
+ break ;
237
+ }
238
+
239
+ // We can read a maximum of the buffer's total size
240
+ // minus what we've already filled the buffer with
241
+ let buffer_remaining = buf. len ( ) - bytes_read;
242
+ // we must allocate a separate buffer for the chunk data
243
+ // since we first need to know its length before writing
244
+ // it into the actual buffer
245
+ let mut chunk_buf = vec ! [ 0 ; buffer_remaining] ;
210
246
// Read bytes from body reader
211
- let chunk_length = ready ! ( Pin :: new ( & mut self . res )
212
- . poll_read( cx, & mut chunk_buf[ 0 ..buf . len ( ) - head_bytes_read ] ) ) ?;
247
+ let chunk_length =
248
+ ready ! ( Pin :: new ( & mut self . res ) . poll_read( cx, & mut chunk_buf) ) ?;
213
249
214
250
// serialize chunk length as hex
215
251
let chunk_length_string = format ! ( "{:X}" , chunk_length) ;
216
252
let chunk_length_bytes = chunk_length_string. as_bytes ( ) ;
217
253
let chunk_length_bytes_len = chunk_length_bytes. len ( ) ;
218
- body_bytes_read += chunk_length_bytes_len;
219
- buf[ head_bytes_read..head_bytes_read + body_bytes_read]
220
- . copy_from_slice ( chunk_length_bytes) ;
221
-
222
- // follow chunk length with CRLF
223
- buf[ head_bytes_read + body_bytes_read] = b'\r' ;
224
- buf[ head_bytes_read + body_bytes_read + 1 ] = b'\n' ;
225
- body_bytes_read += 2 ;
226
-
227
- // copy chunk into buf
228
- buf[ head_bytes_read + body_bytes_read
229
- ..head_bytes_read + body_bytes_read + chunk_length]
230
- . copy_from_slice ( & chunk_buf[ ..chunk_length] ) ;
231
- body_bytes_read += chunk_length;
232
-
233
- // TODO: relax this constraint at some point by adding extra state
234
- let bytes_read = head_bytes_read + body_bytes_read;
235
- assert ! ( buf. len( ) >= bytes_read + 7 , "Buffers should have room for the head, the chunk length, 2 bytes, the chunk, and 4 extra bytes when using chunked encoding" ) ;
254
+ const CRLF_LENGTH : usize = 2 ;
255
+
256
+ // calculate the total size of the chunk including serialized
257
+ // length and the CRLF padding
258
+ let total_chunk_size = bytes_read
259
+ + chunk_length_bytes_len
260
+ + CRLF_LENGTH
261
+ + chunk_length
262
+ + CRLF_LENGTH ;
263
+
264
+ // See if we can write the chunk out in one go
265
+ if total_chunk_size < buffer_remaining {
266
+ // Write the chunk length into the buffer
267
+ buf[ bytes_read..bytes_read + chunk_length_bytes_len]
268
+ . copy_from_slice ( chunk_length_bytes) ;
269
+ bytes_read += chunk_length_bytes_len;
270
+
271
+ // follow chunk length with CRLF
272
+ buf[ bytes_read] = b'\r' ;
273
+ buf[ bytes_read + 1 ] = b'\n' ;
274
+ bytes_read += 2 ;
275
+
276
+ // copy chunk into buf
277
+ buf[ bytes_read..bytes_read + chunk_length]
278
+ . copy_from_slice ( & chunk_buf[ ..chunk_length] ) ;
279
+ bytes_read += chunk_length;
280
+
281
+ // follow chunk with CRLF
282
+ buf[ bytes_read] = b'\r' ;
283
+ buf[ bytes_read + 1 ] = b'\n' ;
284
+ bytes_read += 2 ;
285
+ } else {
286
+ unimplemented ! ( "TODO: handle when buf isn't big enough" ) ;
287
+ }
288
+ const EMPTY_CHUNK : & [ u8 ; 5 ] = b"0\r \n \r \n " ;
236
289
237
- buf[ bytes_read..bytes_read + 7 ] . copy_from_slice ( b" \r \n 0 \r \n \r \n " ) ;
290
+ buf[ bytes_read..bytes_read + EMPTY_CHUNK . len ( ) ] . copy_from_slice ( EMPTY_CHUNK ) ;
238
291
239
292
// if body_bytes_read == 0 {
240
- self . body_done = true ;
241
- body_bytes_read += 7 ;
293
+ bytes_read += 7 ;
294
+ self . state = EncoderState :: Done ;
242
295
// }
243
296
}
297
+ EncoderState :: Done => break ,
244
298
}
245
299
}
300
+ println ! ( "{:?}" , std:: str :: from_utf8( & buf[ 0 ..bytes_read] ) ) ;
246
301
247
- self . body_bytes_read += body_bytes_read; // total body bytes read on all polls
248
-
249
- // Return the total amount of bytes read.
250
- let bytes_read = head_bytes_read + body_bytes_read;
251
302
Poll :: Ready ( Ok ( bytes_read as usize ) )
252
303
}
253
304
}
0 commit comments