@@ -4,9 +4,9 @@ use std::str::FromStr;
4
4
5
5
use async_std:: io:: { BufReader , Read , Write } ;
6
6
use async_std:: prelude:: * ;
7
- use http_types:: headers:: { CONTENT_LENGTH , EXPECT , HOST , TRANSFER_ENCODING } ;
7
+ use http_types:: headers:: { CONTENT_LENGTH , EXPECT , TRANSFER_ENCODING } ;
8
8
use http_types:: { ensure, ensure_eq, format_err} ;
9
- use http_types:: { Body , Method , Request } ;
9
+ use http_types:: { Body , Method , Request , Url } ;
10
10
11
11
use crate :: chunked:: ChunkedDecoder ;
12
12
use crate :: { MAX_HEADERS , MAX_HEAD_LENGTH } ;
56
56
let method = httparse_req. method ;
57
57
let method = method. ok_or_else ( || format_err ! ( "No method found" ) ) ?;
58
58
59
- let path = httparse_req. path ;
60
- let path = path. ok_or_else ( || format_err ! ( "No uri found" ) ) ?;
61
-
62
59
let version = httparse_req. version ;
63
60
let version = version. ok_or_else ( || format_err ! ( "No version found" ) ) ?;
64
61
@@ -69,16 +66,14 @@ where
69
66
version
70
67
) ;
71
68
72
- let mut req = Request :: new (
73
- Method :: from_str ( method) ?,
74
- url:: Url :: parse ( "http://_" ) . unwrap ( ) . join ( path) ?,
75
- ) ;
69
+ let url = url_from_httparse_req ( & httparse_req) ?;
70
+
71
+ let mut req = Request :: new ( Method :: from_str ( method) ?, url) ;
76
72
77
73
for header in httparse_req. headers . iter ( ) {
78
74
req. insert_header ( header. name , std:: str:: from_utf8 ( header. value ) ?) ;
79
75
}
80
76
81
- set_url_and_port_from_host_header ( & mut req) ?;
82
77
handle_100_continue ( & req, & mut io) . await ?;
83
78
84
79
let content_length = req. header ( CONTENT_LENGTH ) ;
@@ -109,29 +104,31 @@ where
109
104
Ok ( Some ( req) )
110
105
}
111
106
112
- fn set_url_and_port_from_host_header ( req : & mut Request ) -> http_types:: Result < ( ) > {
107
+ fn url_from_httparse_req ( req : & httparse:: Request < ' _ , ' _ > ) -> http_types:: Result < Url > {
108
+ let path = req. path . ok_or_else ( || format_err ! ( "No uri found" ) ) ?;
113
109
let host = req
114
- . header ( HOST )
115
- . map ( |header| header. last ( ) ) // There must only exactly one Host header, so this is permissive
116
- . ok_or_else ( || format_err ! ( "Mandatory Host header missing" ) ) ? // https://tools.ietf.org/html/rfc7230#section-5.4
117
- . to_string ( ) ;
118
-
119
- if !req. url ( ) . cannot_be_a_base ( ) {
120
- if let Some ( colon) = host. find ( ":" ) {
121
- req. url_mut ( ) . set_host ( Some ( & host[ 0 ..colon] ) ) ?;
122
- req. url_mut ( )
123
- . set_port ( host[ colon + 1 ..] . parse ( ) . ok ( ) )
124
- . unwrap ( ) ;
125
- } else {
126
- req. url_mut ( ) . set_host ( Some ( & host) ) ?;
127
- }
110
+ . headers
111
+ . iter ( )
112
+ . filter ( |x| x. name . eq_ignore_ascii_case ( "host" ) )
113
+ . next ( )
114
+ . ok_or_else ( || format_err ! ( "Mandatory Host header missing" ) ) ?
115
+ . value ;
116
+
117
+ let host = std:: str:: from_utf8 ( host) ?;
118
+
119
+ if path. starts_with ( "http://" ) || path. starts_with ( "https://" ) {
120
+ Ok ( Url :: parse ( path) ?)
121
+ } else if path. starts_with ( "/" ) {
122
+ Ok ( Url :: parse ( & format ! ( "http://{}/" , host) ) ?. join ( path) ?)
123
+ } else if req. method . unwrap ( ) . eq_ignore_ascii_case ( "connect" ) {
124
+ Ok ( Url :: parse ( & format ! ( "http://{}/" , path) ) ?)
125
+ } else {
126
+ Err ( format_err ! ( "unexpected uri format" ) )
128
127
}
129
-
130
- Ok ( ( ) )
131
128
}
132
129
133
130
const EXPECT_HEADER_VALUE : & str = "100-continue" ;
134
- const EXPECT_RESPONSE : & [ u8 ] = b"HTTP/1.1 100 Continue\r \n " ;
131
+ const EXPECT_RESPONSE : & [ u8 ] = b"HTTP/1.1 100 Continue\r \n \r \n " ;
135
132
136
133
async fn handle_100_continue < IO > ( req : & Request , io : & mut IO ) -> http_types:: Result < ( ) >
137
134
where
@@ -148,103 +145,96 @@ where
148
145
mod tests {
149
146
use super :: * ;
150
147
151
- #[ test]
152
- fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
153
- let request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
154
- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
155
- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
156
- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
157
- assert ! ( result. is_ok( ) ) ;
148
+ fn httparse_req ( buf : & str , f : impl Fn ( httparse:: Request < ' _ , ' _ > ) ) {
149
+ let mut headers = [ httparse:: EMPTY_HEADER ; MAX_HEADERS ] ;
150
+ let mut res = httparse:: Request :: new ( & mut headers[ ..] ) ;
151
+ res. parse ( buf. as_bytes ( ) ) . unwrap ( ) ;
152
+ f ( res)
158
153
}
159
154
160
155
#[ test]
161
- fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
162
- let mut request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
163
- request. append_header ( "expect" , "100-continue" ) ;
164
- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
165
- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
166
- assert_eq ! (
167
- std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) ,
168
- "HTTP/1.1 100 Continue\r \n "
156
+ fn url_for_connect ( ) {
157
+ httparse_req (
158
+ "CONNECT server.example.com:443 HTTP/1.1\r \n Host: server.example.com:443\r \n " ,
159
+ |req| {
160
+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
161
+ assert_eq ! ( url. as_str( ) , "http://server.example.com:443/" ) ;
162
+ } ,
169
163
) ;
170
- assert ! ( result. is_ok( ) ) ;
171
164
}
172
165
173
166
#[ test]
174
- fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
175
- let mut request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
176
- request. append_header ( "expect" , "110-extensions-not-allowed" ) ;
177
- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
178
- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
179
- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
180
- assert ! ( result. is_ok( ) ) ;
167
+ fn url_for_host_plus_path ( ) {
168
+ httparse_req (
169
+ "GET /some/resource HTTP/1.1\r \n Host: server.example.com:443\r \n " ,
170
+ |req| {
171
+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
172
+ assert_eq ! ( url. as_str( ) , "http://server.example.com:443/some/resource" ) ;
173
+ } ,
174
+ )
181
175
}
182
176
183
177
#[ test]
184
- fn test_setting_host_with_no_port ( ) {
185
- let mut request = request_with_host_header ( "subdomain.mydomain.tld" ) ;
186
- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
187
- assert_eq ! (
188
- request. url( ) ,
189
- & url:: Url :: parse( "http://subdomain.mydomain.tld/some/path" ) . unwrap( )
190
- ) ;
178
+ fn url_for_host_plus_absolute_url ( ) {
179
+ httparse_req (
180
+ "GET http://domain.com/some/resource HTTP/1.1\r \n Host: server.example.com\r \n " ,
181
+ |req| {
182
+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
183
+ assert_eq ! ( url. as_str( ) , "http://domain.com/some/resource" ) ; // host header MUST be ignored according to spec
184
+ } ,
185
+ )
191
186
}
192
187
193
188
#[ test]
194
- fn test_setting_host_with_a_port ( ) {
195
- let mut request = request_with_host_header ( "subdomain.mydomain.tld:8080" ) ;
196
- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
197
- assert_eq ! (
198
- request. url( ) ,
199
- & url:: Url :: parse( "http://subdomain.mydomain.tld:8080/some/path" ) . unwrap( )
200
- ) ;
189
+ fn url_for_conflicting_connect ( ) {
190
+ httparse_req (
191
+ "CONNECT server.example.com:443 HTTP/1.1\r \n Host: conflicting.host\r \n " ,
192
+ |req| {
193
+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
194
+ assert_eq ! ( url. as_str( ) , "http://server.example.com:443/" ) ;
195
+ } ,
196
+ )
201
197
}
202
198
203
199
#[ test]
204
- fn test_setting_host_with_an_ip_and_port ( ) {
205
- let mut request = request_with_host_header ( "12.34.56.78:90" ) ;
206
- set_url_and_port_from_host_header ( & mut request ) . unwrap ( ) ;
207
- assert_eq ! (
208
- request . url ( ) ,
209
- & url :: Url :: parse ( "http://12.34.56.78:90/some/path" ) . unwrap ( )
210
- ) ;
200
+ fn url_for_malformed_resource_path ( ) {
201
+ httparse_req (
202
+ "GET not-a-url HTTP/1.1 \r \n Host: server.example.com \r \n " ,
203
+ |req| {
204
+ assert ! ( url_from_httparse_req ( & req ) . is_err ( ) ) ;
205
+ } ,
206
+ )
211
207
}
212
208
213
209
#[ test]
214
- fn test_malformed_nonnumeric_port_is_ignored ( ) {
215
- let mut request = request_with_host_header ( "hello.world:uh-oh" ) ;
216
- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
217
- assert_eq ! (
218
- request. url( ) ,
219
- & url:: Url :: parse( "http://hello.world/some/path" ) . unwrap( )
220
- ) ;
210
+ fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
211
+ let request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
212
+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
213
+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
214
+ assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
215
+ assert ! ( result. is_ok( ) ) ;
221
216
}
222
217
223
218
#[ test]
224
- fn test_malformed_trailing_colon_is_ignored ( ) {
225
- let mut request = request_with_host_header ( "edge.cases:" ) ;
226
- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
219
+ fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
220
+ let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
221
+ request. append_header ( "expect" , "100-continue" ) ;
222
+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
223
+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
227
224
assert_eq ! (
228
- request . url ( ) ,
229
- & url :: Url :: parse ( "http://edge.cases/some/path" ) . unwrap ( )
225
+ std :: str :: from_utf8 ( & io . into_inner ( ) ) . unwrap ( ) ,
226
+ "HTTP/1.1 100 Continue \r \n \r \n "
230
227
) ;
228
+ assert ! ( result. is_ok( ) ) ;
231
229
}
232
230
233
231
#[ test]
234
- fn test_malformed_leading_colon_is_invalid_host_value ( ) {
235
- let mut request = request_with_host_header ( ":300" ) ;
236
- assert ! ( set_url_and_port_from_host_header( & mut request) . is_err( ) ) ;
237
- }
238
-
239
- #[ test]
240
- fn test_malformed_invalid_url_host_is_invalid_host_header_value ( ) {
241
- let mut request = request_with_host_header ( " " ) ;
242
- assert ! ( set_url_and_port_from_host_header( & mut request) . is_err( ) ) ;
243
- }
244
-
245
- fn request_with_host_header ( host : & str ) -> Request {
246
- let mut req = Request :: new ( Method :: Get , url:: Url :: parse ( "http://_/some/path" ) . unwrap ( ) ) ;
247
- req. insert_header ( HOST , host) ;
248
- req
232
+ fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
233
+ let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
234
+ request. append_header ( "expect" , "110-extensions-not-allowed" ) ;
235
+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
236
+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
237
+ assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
238
+ assert ! ( result. is_ok( ) ) ;
249
239
}
250
240
}
0 commit comments