Skip to content

Commit cd5dc61

Browse files
authored
Merge pull request #18 from yoshuawuyts/server
Fix for server issues in `http-types` integration
2 parents 49ab4b1 + d2b0238 commit cd5dc61

File tree

1 file changed

+82
-21
lines changed

1 file changed

+82
-21
lines changed

src/server.rs

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
//! Process HTTP connections on the server.
22
33
use async_std::future::{timeout, Future, TimeoutError};
4-
use async_std::io::{self, BufReader};
4+
use async_std::io::{self, BufRead, BufReader};
55
use async_std::io::{Read, Write};
66
use async_std::prelude::*;
77
use async_std::task::{Context, Poll};
88
use futures_core::ready;
9-
use http_types::{HttpVersion, Method, Request, Response};
10-
use std::convert::TryFrom;
9+
use http_types::{Method, Request, Response};
10+
use std::fmt;
11+
use std::str::FromStr;
1112
use std::time::Duration;
1213

1314
use std::pin::Pin;
@@ -28,25 +29,43 @@ where
2829
// TODO: make configurable
2930
let timeout_duration = Duration::from_secs(10);
3031
const MAX_REQUESTS: usize = 200;
31-
32-
let req = decode(reader).await?;
3332
let mut num_requests = 0;
34-
if let Some(mut req) = req {
33+
34+
// Decode a request. This may be the first of many since
35+
// the connection is Keep-Alive by default
36+
let decoded = decode(reader).await?;
37+
// Decode returns one of three things;
38+
// * A request with its body reader set to the underlying TCP stream
39+
// * A request with an empty body AND the underlying stream
40+
// * No request (because of the stream closed) and no underlying stream
41+
if let Some(mut decoded) = decoded {
3542
loop {
3643
num_requests += 1;
3744
if num_requests > MAX_REQUESTS {
45+
// We've exceeded the max number of requests per connection
3846
return Ok(());
3947
}
4048

49+
// Pass the request to the user defined request handler callback.
50+
// Encode the response we get back.
4151
// TODO: what to do when the callback returns Err
42-
let mut res = encode(callback(&mut req).await?).await?;
43-
let stream = req.into_body();
52+
let mut res = encode(callback(decoded.mut_request()).await?).await?;
53+
54+
// If we have reference to the stream, unwrap it. Otherwise,
55+
// get the underlying stream from the request
56+
let to_decode = decoded.into_reader();
57+
58+
// Copy the response into the writer
4459
io::copy(&mut res, &mut writer).await?;
45-
req = match timeout(timeout_duration, decode(stream)).await {
60+
61+
// Decode a new request, timing out if this takes longer than the
62+
// timeout duration.
63+
decoded = match timeout(timeout_duration, decode(to_decode)).await {
4664
Ok(Ok(Some(r))) => r,
4765
Ok(Ok(None)) | Err(TimeoutError { .. }) => break, /* EOF or timeout */
4866
Ok(Err(e)) => return Err(e),
4967
};
68+
// Loop back with the new request and stream and start again
5069
}
5170
}
5271

@@ -147,8 +166,11 @@ pub async fn encode(res: Response) -> io::Result<Encoder> {
147166
Ok(Encoder::new(buf, res))
148167
}
149168

169+
/// The number returned from httparse when the request is HTTP 1.1
170+
const HTTP_1_1_VERSION: u8 = 1;
171+
150172
/// Decode an HTTP request on the server.
151-
pub async fn decode<R>(reader: R) -> Result<Option<Request>, Exception>
173+
pub async fn decode<R>(reader: R) -> Result<Option<DecodedRequest>, Exception>
152174
where
153175
R: Read + Unpin + Send + 'static,
154176
{
@@ -184,11 +206,10 @@ where
184206
let uri = httparse_req.path.ok_or_else(|| "No uri found")?;
185207
let uri = url::Url::parse(uri)?;
186208
let version = httparse_req.version.ok_or_else(|| "No version found")?;
187-
let version = match version {
188-
1 => HttpVersion::HTTP1_1,
189-
_ => return Err("Unsupported HTTP version".into()),
190-
};
191-
let mut req = Request::new(version, Method::try_from(method)?, uri);
209+
if version != HTTP_1_1_VERSION {
210+
return Err("Unsupported HTTP version".into());
211+
}
212+
let mut req = Request::new(Method::from_str(method)?, uri);
192213
for header in httparse_req.headers.iter() {
193214
req = req.set_header(header.name, std::str::from_utf8(header.value)?)?;
194215
}
@@ -203,14 +224,54 @@ where
203224
.ok()
204225
.and_then(|s| s.parse::<usize>().ok());
205226

206-
if let Some(_len) = length {
207-
// TODO: set len
208-
req = req.set_body(reader);
227+
if let Some(len) = length {
228+
req = req.set_body_reader(reader);
229+
req = req.set_len(len);
230+
231+
Ok(Some(DecodedRequest::WithBody(req)))
209232
} else {
210233
return Err("Invalid value for Content-Length".into());
211234
}
212-
};
235+
} else {
236+
Ok(Some(DecodedRequest::WithoutBody(req, Box::new(reader))))
237+
}
238+
}
213239

214-
// Return the request.
215-
Ok(Some(req))
240+
/// A decoded response
241+
///
242+
/// Either a request with body stream OR a request without a
243+
/// a body stream paired with the underlying stream
244+
pub enum DecodedRequest {
245+
WithBody(Request),
246+
WithoutBody(Request, Box<dyn BufRead + Unpin + Send + 'static>),
247+
}
248+
249+
impl fmt::Debug for DecodedRequest {
250+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251+
match self {
252+
DecodedRequest::WithBody(_) => write!(f, "WithBody"),
253+
DecodedRequest::WithoutBody(_, _) => write!(f, "WithoutBody"),
254+
}
255+
}
256+
}
257+
258+
impl DecodedRequest {
259+
/// Get a mutable reference to the request
260+
fn mut_request(&mut self) -> &mut Request {
261+
match self {
262+
DecodedRequest::WithBody(r) => r,
263+
DecodedRequest::WithoutBody(r, _) => r,
264+
}
265+
}
266+
267+
/// Consume self and get access to the underlying reader
268+
///
269+
/// When the request has a body, the underlying reader is the body.
270+
/// When it does not, the underlying body has been passed alongside the request.
271+
fn into_reader(self) -> Box<dyn BufRead + Unpin + Send + 'static> {
272+
match self {
273+
DecodedRequest::WithBody(r) => r.into_body_reader(),
274+
DecodedRequest::WithoutBody(_, s) => s,
275+
}
276+
}
216277
}

0 commit comments

Comments
 (0)