@@ -18,7 +18,7 @@ use std::{
1818 fmt:: { self , Debug , Formatter } ,
1919 future:: Future ,
2020 io:: ErrorKind ,
21- pin:: Pin ,
21+ pin:: { pin , Pin } ,
2222 str:: FromStr ,
2323 task:: { Context , Poll } ,
2424 time:: { Duration , Instant } ,
@@ -450,7 +450,7 @@ where
450450
451451 let state = this. state . project ( ) ;
452452 match state {
453- StateProj :: StreamClosed => return Poll :: Ready ( Some ( Err ( Error :: StreamClosed ) ) ) ,
453+ StateProj :: StreamClosed => return Poll :: Ready ( None ) ,
454454 // New immediately transitions to Connecting, and exists only
455455 // to ensure that we only connect when polled.
456456 StateProj :: New => {
@@ -517,30 +517,49 @@ where
517517 }
518518 }
519519
520- self . as_mut ( ) . reset_redirects ( ) ;
521- self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
522-
523- return Poll :: Ready ( Some ( Err ( Error :: UnexpectedResponse (
520+ let error = Error :: UnexpectedResponse (
524521 Response :: new ( resp. status ( ) , resp. headers ( ) . clone ( ) ) ,
525522 ErrorBody :: new ( resp. into_body ( ) ) ,
526- ) ) ) ) ;
523+ ) ;
524+
525+ if !* retry {
526+ self . as_mut ( ) . project ( ) . state . set ( State :: StreamClosed ) ;
527+ return Poll :: Ready ( Some ( Err ( error) ) ) ;
528+ }
529+
530+ self . as_mut ( ) . reset_redirects ( ) ;
531+
532+ let duration = self
533+ . as_mut ( )
534+ . project ( )
535+ . retry_strategy
536+ . next_delay ( Instant :: now ( ) ) ;
537+
538+ self . as_mut ( )
539+ . project ( )
540+ . state
541+ . set ( State :: WaitingToReconnect ( delay ( duration, "retrying" ) ) ) ;
542+
543+ return Poll :: Ready ( Some ( Err ( error) ) ) ;
527544 }
528545 Err ( e) => {
529546 // This happens when the server is unreachable, e.g. connection refused.
530547 warn ! ( "request returned an error: {}" , e) ;
531548 if !* retry {
532- self . as_mut ( ) . project ( ) . state . set ( State :: New ) ;
549+ self . as_mut ( ) . project ( ) . state . set ( State :: StreamClosed ) ;
533550 return Poll :: Ready ( Some ( Err ( Error :: HttpStream ( Box :: new ( e) ) ) ) ) ;
534551 }
552+
535553 let duration = self
536554 . as_mut ( )
537555 . project ( )
538556 . retry_strategy
539557 . next_delay ( Instant :: now ( ) ) ;
558+
540559 self . as_mut ( )
541560 . project ( )
542561 . state
543- . set ( State :: WaitingToReconnect ( delay ( duration, "retrying" ) ) )
562+ . set ( State :: WaitingToReconnect ( delay ( duration, "retrying" ) ) ) ;
544563 }
545564 } ,
546565 StateProj :: FollowingRedirect ( maybe_header) => match uri_from_header ( maybe_header) {
@@ -665,4 +684,84 @@ mod tests {
665684
666685 assert_eq ! ( Some ( & expected) , actual) ;
667686 }
687+
688+ use std:: { pin:: pin, str:: FromStr , time:: Duration } ;
689+
690+ use futures:: TryStreamExt ;
691+ use hyper:: { client:: HttpConnector , Body , HeaderMap , Request , Uri } ;
692+ use hyper_timeout:: TimeoutConnector ;
693+ use tokio:: time:: timeout;
694+
695+ use crate :: {
696+ client:: { RequestProps , State } ,
697+ ReconnectOptionsBuilder , ReconnectingRequest ,
698+ } ;
699+
700+ const INVALID_URI : & ' static str = "http://mycrazyunexsistenturl.invaliddomainext" ;
701+
702+ #[ test_case( INVALID_URI , false , |state| matches!( state, State :: StreamClosed ) ) ]
703+ #[ test_case( INVALID_URI , true , |state| matches!( state, State :: WaitingToReconnect ( _) ) ) ]
704+ #[ tokio:: test]
705+ async fn initial_connection ( uri : & str , retry_initial : bool , expected : fn ( & State ) -> bool ) {
706+ let default_timeout = Some ( Duration :: from_secs ( 1 ) ) ;
707+ let conn = HttpConnector :: new ( ) ;
708+ let mut connector = TimeoutConnector :: new ( conn) ;
709+ connector. set_connect_timeout ( default_timeout) ;
710+ connector. set_read_timeout ( default_timeout) ;
711+ connector. set_write_timeout ( default_timeout) ;
712+
713+ let reconnect_opts = ReconnectOptionsBuilder :: new ( false )
714+ . backoff_factor ( 1 )
715+ . delay ( Duration :: from_secs ( 1 ) )
716+ . retry_initial ( retry_initial)
717+ . build ( ) ;
718+
719+ let http = hyper:: Client :: builder ( ) . build :: < _ , hyper:: Body > ( connector) ;
720+ let req_props = RequestProps {
721+ url : Uri :: from_str ( uri) . unwrap ( ) ,
722+ headers : HeaderMap :: new ( ) ,
723+ method : "GET" . to_string ( ) ,
724+ body : None ,
725+ reconnect_opts,
726+ max_redirects : 10 ,
727+ } ;
728+
729+ let mut reconnecting_request = ReconnectingRequest :: new ( http, req_props, None ) ;
730+
731+ // sets initial state
732+ let resp = reconnecting_request. http . request (
733+ Request :: builder ( )
734+ . method ( "GET" )
735+ . uri ( uri)
736+ . body ( Body :: empty ( ) )
737+ . unwrap ( ) ,
738+ ) ;
739+
740+ reconnecting_request. state = State :: Connecting {
741+ retry : reconnecting_request. props . reconnect_opts . retry_initial ,
742+ resp,
743+ } ;
744+
745+ let mut reconnecting_request = pin ! ( reconnecting_request) ;
746+
747+ timeout ( Duration :: from_millis ( 500 ) , reconnecting_request. try_next ( ) )
748+ . await
749+ . ok ( ) ;
750+
751+ assert ! ( expected( & reconnecting_request. state) ) ;
752+ }
753+
754+ #[ test_case( false , |state| matches!( state, State :: StreamClosed ) ) ]
755+ #[ test_case( true , |state| matches!( state, State :: WaitingToReconnect ( _) ) ) ]
756+ #[ tokio:: test]
757+ async fn initial_connection_mocked_server ( retry_initial : bool , expected : fn ( & State ) -> bool ) {
758+ let mut mock_server = mockito:: Server :: new_async ( ) . await ;
759+ let _mock = mock_server
760+ . mock ( "GET" , "/" )
761+ . with_status ( 404 )
762+ . create_async ( )
763+ . await ;
764+
765+ initial_connection ( & mock_server. url ( ) , retry_initial, expected) . await ;
766+ }
668767}
0 commit comments