@@ -4,12 +4,13 @@ use std::str::FromStr;
44
55use async_std:: io:: { BufReader , Read , Write } ;
66use async_std:: prelude:: * ;
7- use http_types:: headers:: { CONTENT_LENGTH , EXPECT , HOST , TRANSFER_ENCODING } ;
7+ use http_types:: headers:: { CONTENT_LENGTH , EXPECT , TRANSFER_ENCODING } ;
88use http_types:: { ensure, ensure_eq, format_err} ;
99use http_types:: { Body , Method , Request } ;
1010
1111use crate :: chunked:: ChunkedDecoder ;
1212use crate :: { MAX_HEADERS , MAX_HEAD_LENGTH } ;
13+ use url:: Url ;
1314
1415const LF : u8 = b'\n' ;
1516
5657 let method = httparse_req. method ;
5758 let method = method. ok_or_else ( || format_err ! ( "No method found" ) ) ?;
5859
59- let path = httparse_req. path ;
60- let path = path. ok_or_else ( || format_err ! ( "No uri found" ) ) ?;
61-
6260 let version = httparse_req. version ;
6361 let version = version. ok_or_else ( || format_err ! ( "No version found" ) ) ?;
6462
@@ -69,16 +67,14 @@ where
6967 version
7068 ) ;
7169
72- let mut req = Request :: new (
73- Method :: from_str ( method) ?,
74- url:: Url :: parse ( "http://_" ) . unwrap ( ) . join ( path) ?,
75- ) ;
70+ let url = url_from_httparse_req ( & httparse_req) ?;
71+
72+ let mut req = Request :: new ( Method :: from_str ( method) ?, url) ;
7673
7774 for header in httparse_req. headers . iter ( ) {
7875 req. insert_header ( header. name , std:: str:: from_utf8 ( header. value ) ?) ;
7976 }
8077
81- set_url_and_port_from_host_header ( & mut req) ?;
8278 handle_100_continue ( & req, & mut io) . await ?;
8379
8480 let content_length = req. header ( CONTENT_LENGTH ) ;
@@ -109,23 +105,27 @@ where
109105 Ok ( Some ( req) )
110106}
111107
112- fn set_url_and_port_from_host_header ( req : & mut Request ) -> http_types:: Result < ( ) > {
108+ fn url_from_httparse_req ( req : & httparse:: Request < ' _ , ' _ > ) -> http_types:: Result < Url > {
109+ let path = req. path . ok_or_else ( || format_err ! ( "No uri found" ) ) ?;
113110 let host = req
114- . header ( HOST )
115- . map ( |header| header. last ( ) ) // There must only exactly one Host header, so this is permissive
116- . ok_or_else ( || format_err ! ( "Mandatory Host header missing" ) ) ? // https://tools.ietf.org/html/rfc7230#section-5.4
117- . to_string ( ) ;
118-
119- if let Some ( colon) = host. find ( ":" ) {
120- req. url_mut ( ) . set_host ( Some ( & host[ 0 ..colon] ) ) ?;
121- req. url_mut ( )
122- . set_port ( host[ colon + 1 ..] . parse ( ) . ok ( ) )
123- . unwrap ( ) ;
111+ . headers
112+ . iter ( )
113+ . filter ( |x| x. name . eq_ignore_ascii_case ( "host" ) )
114+ . next ( )
115+ . ok_or_else ( || format_err ! ( "Mandatory Host header missing" ) ) ?
116+ . value ;
117+
118+ let host = std:: str:: from_utf8 ( host) ?;
119+
120+ if path. starts_with ( "http://" ) || path. starts_with ( "https://" ) {
121+ Ok ( Url :: parse ( path) ?)
122+ } else if path. starts_with ( "/" ) {
123+ Ok ( Url :: parse ( & format ! ( "http://{}/" , host) ) ?. join ( path) ?)
124+ } else if req. method . unwrap ( ) . eq_ignore_ascii_case ( "connect" ) {
125+ Ok ( Url :: parse ( & format ! ( "http://{}/" , path) ) ?)
124126 } else {
125- req . url_mut ( ) . set_host ( Some ( & host ) ) ? ;
127+ Err ( format_err ! ( "unexpected uri format" ) )
126128 }
127-
128- Ok ( ( ) )
129129}
130130
131131const EXPECT_HEADER_VALUE : & str = "100-continue" ;
@@ -146,103 +146,96 @@ where
146146mod tests {
147147 use super :: * ;
148148
149- #[ test]
150- fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
151- let request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
152- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
153- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
154- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
155- assert ! ( result. is_ok( ) ) ;
149+ fn httparse_req ( buf : & str , f : impl Fn ( httparse:: Request < ' _ , ' _ > ) ) {
150+ let mut headers = [ httparse:: EMPTY_HEADER ; MAX_HEADERS ] ;
151+ let mut res = httparse:: Request :: new ( & mut headers[ ..] ) ;
152+ res. parse ( buf. as_bytes ( ) ) . unwrap ( ) ;
153+ f ( res)
156154 }
157155
158156 #[ test]
159- fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
160- let mut request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
161- request. append_header ( "expect" , "100-continue" ) ;
162- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
163- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
164- assert_eq ! (
165- std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) ,
166- "HTTP/1.1 100 Continue\r \n "
157+ fn url_for_connect ( ) {
158+ httparse_req (
159+ "CONNECT server.example.com:443 HTTP/1.1\r \n Host: server.example.com:443\r \n " ,
160+ |req| {
161+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
162+ assert_eq ! ( url. as_str( ) , "http://server.example.com:443/" ) ;
163+ } ,
167164 ) ;
168- assert ! ( result. is_ok( ) ) ;
169165 }
170166
171167 #[ test]
172- fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
173- let mut request = Request :: new ( Method :: Get , url:: Url :: parse ( "x:" ) . unwrap ( ) ) ;
174- request. append_header ( "expect" , "110-extensions-not-allowed" ) ;
175- let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
176- let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
177- assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
178- assert ! ( result. is_ok( ) ) ;
168+ fn url_for_host_plus_path ( ) {
169+ httparse_req (
170+ "GET /some/resource HTTP/1.1\r \n Host: server.example.com:443\r \n " ,
171+ |req| {
172+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
173+ assert_eq ! ( url. as_str( ) , "http://server.example.com:443/some/resource" ) ;
174+ } ,
175+ )
179176 }
180177
181178 #[ test]
182- fn test_setting_host_with_no_port ( ) {
183- let mut request = request_with_host_header ( "subdomain.mydomain.tld" ) ;
184- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
185- assert_eq ! (
186- request. url( ) ,
187- & url:: Url :: parse( "http://subdomain.mydomain.tld/some/path" ) . unwrap( )
188- ) ;
179+ fn url_for_host_plus_absolute_url ( ) {
180+ httparse_req (
181+ "GET http://domain.com/some/resource HTTP/1.1\r \n Host: server.example.com\r \n " ,
182+ |req| {
183+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
184+ assert_eq ! ( url. as_str( ) , "http://domain.com/some/resource" ) ; // host header MUST be ignored according to spec
185+ } ,
186+ )
189187 }
190188
191189 #[ test]
192- fn test_setting_host_with_a_port ( ) {
193- let mut request = request_with_host_header ( "subdomain.mydomain.tld:8080" ) ;
194- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
195- assert_eq ! (
196- request. url( ) ,
197- & url:: Url :: parse( "http://subdomain.mydomain.tld:8080/some/path" ) . unwrap( )
198- ) ;
190+ fn url_for_conflicting_connect ( ) {
191+ httparse_req (
192+ "CONNECT server.example.com:443 HTTP/1.1\r \n Host: conflicting.host\r \n " ,
193+ |req| {
194+ let url = url_from_httparse_req ( & req) . unwrap ( ) ;
195+ assert_eq ! ( url. as_str( ) , "http://server.example.com:443/" ) ;
196+ } ,
197+ )
199198 }
200199
201200 #[ test]
202- fn test_setting_host_with_an_ip_and_port ( ) {
203- let mut request = request_with_host_header ( "12.34.56.78:90" ) ;
204- set_url_and_port_from_host_header ( & mut request ) . unwrap ( ) ;
205- assert_eq ! (
206- request . url ( ) ,
207- & url :: Url :: parse ( "http://12.34.56.78:90/some/path" ) . unwrap ( )
208- ) ;
201+ fn url_for_malformed_resource_path ( ) {
202+ httparse_req (
203+ "GET not-a-url HTTP/1.1 \r \n Host: server.example.com \r \n " ,
204+ |req| {
205+ assert ! ( url_from_httparse_req ( & req ) . is_err ( ) ) ;
206+ } ,
207+ )
209208 }
210209
211210 #[ test]
212- fn test_malformed_nonnumeric_port_is_ignored ( ) {
213- let mut request = request_with_host_header ( "hello.world:uh-oh" ) ;
214- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
215- assert_eq ! (
216- request. url( ) ,
217- & url:: Url :: parse( "http://hello.world/some/path" ) . unwrap( )
218- ) ;
211+ fn handle_100_continue_does_nothing_with_no_expect_header ( ) {
212+ let request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
213+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
214+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
215+ assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
216+ assert ! ( result. is_ok( ) ) ;
219217 }
220218
221219 #[ test]
222- fn test_malformed_trailing_colon_is_ignored ( ) {
223- let mut request = request_with_host_header ( "edge.cases:" ) ;
224- set_url_and_port_from_host_header ( & mut request) . unwrap ( ) ;
220+ fn handle_100_continue_sends_header_if_expects_is_exactly_right ( ) {
221+ let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
222+ request. append_header ( "expect" , "100-continue" ) ;
223+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
224+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
225225 assert_eq ! (
226- request . url ( ) ,
227- & url :: Url :: parse ( "http://edge.cases/some/path" ) . unwrap ( )
226+ std :: str :: from_utf8 ( & io . into_inner ( ) ) . unwrap ( ) ,
227+ "HTTP/1.1 100 Continue \r \n "
228228 ) ;
229+ assert ! ( result. is_ok( ) ) ;
229230 }
230231
231232 #[ test]
232- fn test_malformed_leading_colon_is_invalid_host_value ( ) {
233- let mut request = request_with_host_header ( ":300" ) ;
234- assert ! ( set_url_and_port_from_host_header( & mut request) . is_err( ) ) ;
235- }
236-
237- #[ test]
238- fn test_malformed_invalid_url_host_is_invalid_host_header_value ( ) {
239- let mut request = request_with_host_header ( " " ) ;
240- assert ! ( set_url_and_port_from_host_header( & mut request) . is_err( ) ) ;
241- }
242-
243- fn request_with_host_header ( host : & str ) -> Request {
244- let mut req = Request :: new ( Method :: Get , url:: Url :: parse ( "http://_/some/path" ) . unwrap ( ) ) ;
245- req. insert_header ( HOST , host) ;
246- req
233+ fn handle_100_continue_does_nothing_if_expects_header_is_wrong ( ) {
234+ let mut request = Request :: new ( Method :: Get , Url :: parse ( "x:" ) . unwrap ( ) ) ;
235+ request. append_header ( "expect" , "110-extensions-not-allowed" ) ;
236+ let mut io = async_std:: io:: Cursor :: new ( vec ! [ ] ) ;
237+ let result = async_std:: task:: block_on ( handle_100_continue ( & request, & mut io) ) ;
238+ assert_eq ! ( std:: str :: from_utf8( & io. into_inner( ) ) . unwrap( ) , "" ) ;
239+ assert ! ( result. is_ok( ) ) ;
247240 }
248241}
0 commit comments