1
1
//! Process HTTP connections on the server.
2
2
3
3
use async_std:: future:: { timeout, Future , TimeoutError } ;
4
- use async_std:: io:: { self , BufReader } ;
4
+ use async_std:: io:: { self , BufRead , BufReader } ;
5
5
use async_std:: io:: { Read , Write } ;
6
6
use async_std:: prelude:: * ;
7
7
use async_std:: task:: { Context , Poll } ;
8
8
use futures_core:: ready;
9
- use http_types:: { HttpVersion , Method , Request , Response } ;
10
- use std:: convert:: TryFrom ;
9
+ use http_types:: { Method , Request , Response } ;
10
+ use std:: fmt;
11
+ use std:: str:: FromStr ;
11
12
use std:: time:: Duration ;
12
13
13
14
use std:: pin:: Pin ;
@@ -28,25 +29,43 @@ where
28
29
// TODO: make configurable
29
30
let timeout_duration = Duration :: from_secs ( 10 ) ;
30
31
const MAX_REQUESTS : usize = 200 ;
31
-
32
- let req = decode ( reader) . await ?;
33
32
let mut num_requests = 0 ;
34
- if let Some ( mut req) = req {
33
+
34
+ // Decode a request. This may be the first of many since
35
+ // the connection is Keep-Alive by default
36
+ let decoded = decode ( reader) . await ?;
37
+ // Decode returns one of three things;
38
+ // * A request with its body reader set to the underlying TCP stream
39
+ // * A request with an empty body AND the underlying stream
40
+ // * No request (because of the stream closed) and no underlying stream
41
+ if let Some ( mut decoded) = decoded {
35
42
loop {
36
43
num_requests += 1 ;
37
44
if num_requests > MAX_REQUESTS {
45
+ // We've exceeded the max number of requests per connection
38
46
return Ok ( ( ) ) ;
39
47
}
40
48
49
+ // Pass the request to the user defined request handler callback.
50
+ // Encode the response we get back.
41
51
// TODO: what to do when the callback returns Err
42
- let mut res = encode ( callback ( & mut req) . await ?) . await ?;
43
- let stream = req. into_body ( ) ;
52
+ let mut res = encode ( callback ( decoded. mut_request ( ) ) . await ?) . await ?;
53
+
54
+ // If we have reference to the stream, unwrap it. Otherwise,
55
+ // get the underlying stream from the request
56
+ let to_decode = decoded. into_reader ( ) ;
57
+
58
+ // Copy the response into the writer
44
59
io:: copy ( & mut res, & mut writer) . await ?;
45
- req = match timeout ( timeout_duration, decode ( stream) ) . await {
60
+
61
+ // Decode a new request, timing out if this takes longer than the
62
+ // timeout duration.
63
+ decoded = match timeout ( timeout_duration, decode ( to_decode) ) . await {
46
64
Ok ( Ok ( Some ( r) ) ) => r,
47
65
Ok ( Ok ( None ) ) | Err ( TimeoutError { .. } ) => break , /* EOF or timeout */
48
66
Ok ( Err ( e) ) => return Err ( e) ,
49
67
} ;
68
+ // Loop back with the new request and stream and start again
50
69
}
51
70
}
52
71
@@ -147,8 +166,11 @@ pub async fn encode(res: Response) -> io::Result<Encoder> {
147
166
Ok ( Encoder :: new ( buf, res) )
148
167
}
149
168
169
+ /// The number returned from httparse when the request is HTTP 1.1
170
+ const HTTP_1_1_VERSION : u8 = 1 ;
171
+
150
172
/// Decode an HTTP request on the server.
151
- pub async fn decode < R > ( reader : R ) -> Result < Option < Request > , Exception >
173
+ pub async fn decode < R > ( reader : R ) -> Result < Option < DecodedRequest > , Exception >
152
174
where
153
175
R : Read + Unpin + Send + ' static ,
154
176
{
@@ -184,11 +206,10 @@ where
184
206
let uri = httparse_req. path . ok_or_else ( || "No uri found" ) ?;
185
207
let uri = url:: Url :: parse ( uri) ?;
186
208
let version = httparse_req. version . ok_or_else ( || "No version found" ) ?;
187
- let version = match version {
188
- 1 => HttpVersion :: HTTP1_1 ,
189
- _ => return Err ( "Unsupported HTTP version" . into ( ) ) ,
190
- } ;
191
- let mut req = Request :: new ( version, Method :: try_from ( method) ?, uri) ;
209
+ if version != HTTP_1_1_VERSION {
210
+ return Err ( "Unsupported HTTP version" . into ( ) ) ;
211
+ }
212
+ let mut req = Request :: new ( Method :: from_str ( method) ?, uri) ;
192
213
for header in httparse_req. headers . iter ( ) {
193
214
req = req. set_header ( header. name , std:: str:: from_utf8 ( header. value ) ?) ?;
194
215
}
@@ -203,14 +224,54 @@ where
203
224
. ok ( )
204
225
. and_then ( |s| s. parse :: < usize > ( ) . ok ( ) ) ;
205
226
206
- if let Some ( _len) = length {
207
- // TODO: set len
208
- req = req. set_body ( reader) ;
227
+ if let Some ( len) = length {
228
+ req = req. set_body_reader ( reader) ;
229
+ req = req. set_len ( len) ;
230
+
231
+ Ok ( Some ( DecodedRequest :: WithBody ( req) ) )
209
232
} else {
210
233
return Err ( "Invalid value for Content-Length" . into ( ) ) ;
211
234
}
212
- } ;
235
+ } else {
236
+ Ok ( Some ( DecodedRequest :: WithoutBody ( req, Box :: new ( reader) ) ) )
237
+ }
238
+ }
213
239
214
- // Return the request.
215
- Ok ( Some ( req) )
240
+ /// A decoded response
241
+ ///
242
+ /// Either a request with body stream OR a request without a
243
+ /// a body stream paired with the underlying stream
244
+ pub enum DecodedRequest {
245
+ WithBody ( Request ) ,
246
+ WithoutBody ( Request , Box < dyn BufRead + Unpin + Send + ' static > ) ,
247
+ }
248
+
249
+ impl fmt:: Debug for DecodedRequest {
250
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
251
+ match self {
252
+ DecodedRequest :: WithBody ( _) => write ! ( f, "WithBody" ) ,
253
+ DecodedRequest :: WithoutBody ( _, _) => write ! ( f, "WithoutBody" ) ,
254
+ }
255
+ }
256
+ }
257
+
258
+ impl DecodedRequest {
259
+ /// Get a mutable reference to the request
260
+ fn mut_request ( & mut self ) -> & mut Request {
261
+ match self {
262
+ DecodedRequest :: WithBody ( r) => r,
263
+ DecodedRequest :: WithoutBody ( r, _) => r,
264
+ }
265
+ }
266
+
267
+ /// Consume self and get access to the underlying reader
268
+ ///
269
+ /// When the request has a body, the underlying reader is the body.
270
+ /// When it does not, the underlying body has been passed alongside the request.
271
+ fn into_reader ( self ) -> Box < dyn BufRead + Unpin + Send + ' static > {
272
+ match self {
273
+ DecodedRequest :: WithBody ( r) => r. into_body_reader ( ) ,
274
+ DecodedRequest :: WithoutBody ( _, s) => s,
275
+ }
276
+ }
216
277
}
0 commit comments