2
2
3
3
use std:: str:: FromStr ;
4
4
5
- use async_std:: io:: BufReader ;
6
- use async_std:: io:: Read ;
5
+ use async_std:: io:: { BufReader , Read , Write } ;
7
6
use async_std:: prelude:: * ;
8
7
use http_types:: headers:: { HeaderName , HeaderValue , CONTENT_LENGTH , HOST , TRANSFER_ENCODING } ;
9
8
use http_types:: { ensure, ensure_eq, format_err} ;
@@ -18,11 +17,11 @@ const LF: u8 = b'\n';
18
17
const HTTP_1_1_VERSION : u8 = 1 ;
19
18
20
19
/// Decode an HTTP request on the server.
21
- pub ( crate ) async fn decode < R > ( reader : R ) -> http_types:: Result < Option < Request > >
20
+ pub ( crate ) async fn decode < IO > ( mut io : IO ) -> http_types:: Result < Option < Request > >
22
21
where
23
- R : Read + Unpin + Send + Sync + ' static ,
22
+ IO : Read + Write + Clone + Send + Sync + Unpin + ' static ,
24
23
{
25
- let mut reader = BufReader :: new ( reader ) ;
24
+ let mut reader = BufReader :: new ( io . clone ( ) ) ;
26
25
let mut buf = Vec :: new ( ) ;
27
26
let mut headers = [ httparse:: EMPTY_HEADER ; MAX_HEADERS ] ;
28
27
let mut httparse_req = httparse:: Request :: new ( & mut headers) ;
82
81
}
83
82
84
83
set_url_and_port_from_host_header ( & mut req) ?;
84
+ handle_100_continue ( & req, & mut io) . await ?;
85
85
86
86
let content_length = req. header ( & CONTENT_LENGTH ) ;
87
87
let transfer_encoding = req. header ( & TRANSFER_ENCODING ) ;
@@ -130,22 +130,58 @@ fn set_url_and_port_from_host_header(req: &mut Request) -> http_types::Result<()
130
130
Ok ( ( ) )
131
131
}
132
132
133
+ async fn handle_100_continue < IO : Write + Unpin > (
134
+ req : & Request ,
135
+ io : & mut IO ,
136
+ ) -> http_types:: Result < ( ) > {
137
+ let expect_header_value = req
138
+ . header ( & HeaderName :: from_str ( "expect" ) . unwrap ( ) )
139
+ . and_then ( |v| v. last ( ) )
140
+ . map ( |v| v. as_str ( ) ) ;
141
+
142
+ if let Some ( "100-continue" ) = expect_header_value {
143
+ io. write_all ( "HTTP/1.1 100 Continue\r \n " . as_bytes ( ) ) . await ?;
144
+ }
145
+
146
+ Ok ( ( ) )
147
+ }
148
+
133
149
#[ cfg( test) ]
134
150
mod tests {
135
151
use super :: * ;
136
152
137
- fn request_with_host_header ( host : & str ) -> Request {
138
- let mut req = Request :: new (
139
- Method :: from_str ( "GET ") . unwrap ( ) ,
140
- url :: Url :: parse ( "http://_" )
141
- . unwrap ( )
142
- . join ( "/some/path" )
143
- . unwrap ( ) ,
144
- ) ;
153
+ # [ test ]
154
+ fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
155
+ let request = Request :: new ( Method :: Get , url :: Url :: parse ( "x: ") . unwrap ( ) ) ;
156
+ let mut io = async_std :: io :: Cursor :: new ( vec ! [ ] ) ;
157
+ let result = async_std :: task :: block_on ( handle_100_continue ( & request , & mut io ) ) ;
158
+ assert_eq ! ( std :: str :: from_utf8 ( & io . into_inner ( ) ) . unwrap ( ) , "" ) ;
159
+ assert ! ( result . is_ok ( ) ) ;
160
+ }
145
161
146
- req. insert_header ( HOST , host) . unwrap ( ) ;
162
+ #[ test]
163
+ fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
164
+ let mut request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
165
+ request. append_header ( "expect" , "100-continue" ) . unwrap ( ) ;
166
+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
167
+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
168
+ assert_eq ! (
169
+ std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) ,
170
+ "HTTP/1.1 100 Continue\r \n "
171
+ ) ;
172
+ assert ! ( result. is_ok( ) ) ;
173
+ }
147
174
148
- req
175
+ #[ test]
176
+ fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
177
+ let mut request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
178
+ request
179
+ . append_header ( "expect" , "110-extensions-not-allowed" )
180
+ . unwrap ( ) ;
181
+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
182
+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
183
+ assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
184
+ assert ! ( result. is_ok( ) ) ;
149
185
}
150
186
151
187
#[ test]
@@ -209,4 +245,18 @@ mod tests {
209
245
let mut request = request_with_host_header ( " " ) ;
210
246
assert ! ( set_url_and_port_from_host_header( & mut request) . is_err( ) ) ;
211
247
}
248
+
249
+ fn request_with_host_header ( host : & str ) -> Request {
250
+ let mut req = Request :: new (
251
+ Method :: Get ,
252
+ url:: Url :: parse ( "http://_" )
253
+ . unwrap ( )
254
+ . join ( "/some/path" )
255
+ . unwrap ( ) ,
256
+ ) ;
257
+
258
+ req. insert_header ( HOST , host) . unwrap ( ) ;
259
+
260
+ req
261
+ }
212
262
}
0 commit comments