@@ -93,25 +93,48 @@ impl Request {
93
93
}
94
94
95
95
/// Get the remote address for this request.
96
+ /// This is determined in the following priority:
97
+ /// 1. `Forwarded` header `for` key
98
+ /// 2. The first `X-Forwarded-For` header
99
+ /// 3. Peer address of the transport
96
100
pub fn remote ( & self ) -> Option < & str > {
97
- self . forwarded_for ( ) . or ( self . peer_addr ( ) )
101
+ self . forwarded_for ( ) . or_else ( || self . peer_addr ( ) )
102
+ }
103
+
104
+ /// Get the destination host for this request.
105
+ /// This is determined in the following priority:
106
+ /// 1. `Forwarded` header `host` key
107
+ /// 2. The first `X-Forwarded-Host` header
108
+ /// 3. `Host` header
109
+ /// 4. URL domain, if any
110
+ pub fn host ( & self ) -> Option < & str > {
111
+ self . forwarded_header_part ( "host" )
112
+ . or_else ( || {
113
+ self . header ( "X-Forwarded-Host" )
114
+ . and_then ( |h| h. as_str ( ) . split ( "," ) . next ( ) )
115
+ } )
116
+ . or_else ( || self . header ( & headers:: HOST ) . map ( |h| h. as_str ( ) ) )
117
+ . or_else ( || self . url ( ) . host_str ( ) )
98
118
}
99
119
100
- fn forwarded_for ( & self ) -> Option < & str > {
101
- if let Some ( header ) = self . header ( "Forwarded" ) {
120
+ fn forwarded_header_part ( & self , part : & str ) -> Option < & str > {
121
+ self . header ( "Forwarded" ) . and_then ( |header| {
102
122
header. as_str ( ) . split ( ";" ) . find_map ( |key_equals_value| {
103
123
let parts = key_equals_value. split ( "=" ) . collect :: < Vec < _ > > ( ) ;
104
- if parts. len ( ) == 2 && parts[ 0 ] . eq_ignore_ascii_case ( "for" ) {
124
+ if parts. len ( ) == 2 && parts[ 0 ] . eq_ignore_ascii_case ( part ) {
105
125
Some ( parts[ 1 ] )
106
126
} else {
107
127
None
108
128
}
109
129
} )
110
- } else if let Some ( header) = self . header ( "X-Forwarded-For" ) {
111
- header. as_str ( ) . split ( "," ) . next ( )
112
- } else {
113
- None
114
- }
130
+ } )
131
+ }
132
+
133
+ fn forwarded_for ( & self ) -> Option < & str > {
134
+ self . forwarded_header_part ( "for" ) . or_else ( || {
135
+ self . header ( "X-Forwarded-For" )
136
+ . and_then ( |header| header. as_str ( ) . split ( "," ) . next ( ) )
137
+ } )
115
138
}
116
139
117
140
/// Get the HTTP method
@@ -689,85 +712,143 @@ impl<'a> IntoIterator for &'a mut Request {
689
712
#[ cfg( test) ]
690
713
mod tests {
691
714
use super :: * ;
715
+ mod host {
716
+ use super :: * ;
717
+
718
+ #[ test]
719
+ fn when_forwarded_header_is_set ( ) {
720
+ let mut request = build_test_request ( ) ;
721
+ set_forwarded ( & mut request, "-" ) ;
722
+ set_x_forwarded_host ( & mut request, "this will not be used" ) ;
723
+ assert_eq ! ( request. forwarded_header_part( "host" ) , Some ( "host.com" ) ) ;
724
+ assert_eq ! ( request. host( ) , Some ( "host.com" ) ) ;
725
+ }
692
726
693
- fn build_test_request ( ) -> Request {
694
- let url = Url :: parse ( "http://irrelevant/" ) . unwrap ( ) ;
695
- Request :: new ( Method :: Get , url )
696
- }
727
+ # [ test ]
728
+ fn when_several_x_forwarded_hosts_exist ( ) {
729
+ let mut request = build_test_request ( ) ;
730
+ set_x_forwarded_host ( & mut request , "expected.host" ) ;
697
731
698
- fn set_x_forwarded_for ( request : & mut Request , client : & ' static str ) {
699
- request. insert_header (
700
- "x-forwarded-for" ,
701
- format ! ( "{},proxy.com,other-proxy.com" , client) ,
702
- ) ;
703
- }
732
+ assert_eq ! ( request. forwarded_header_part( "host" ) , None ) ;
733
+ assert_eq ! ( request. host( ) , Some ( "expected.host" ) ) ;
734
+ }
704
735
705
- fn set_forwarded ( request : & mut Request , client : & ' static str ) {
706
- request . insert_header (
707
- "Forwarded" ,
708
- format ! ( "by=something.com;for={}; host=host.com;proto=http ", client ) ,
709
- ) ;
710
- }
736
+ # [ test ]
737
+ fn when_only_one_x_forwarded_hosts_exist ( ) {
738
+ let mut request = build_test_request ( ) ;
739
+ request . insert_header ( "x-forwarded- host", "expected.host" ) ;
740
+ assert_eq ! ( request . host ( ) , Some ( "expected.host" ) ) ;
741
+ }
711
742
712
- #[ test]
713
- fn test_remote_and_forwarded_for_when_forwarded_is_properly_formatted ( ) {
714
- let mut request = build_test_request ( ) ;
715
- request. set_peer_addr ( Some ( "127.0.0.1:8000" ) ) ;
716
- set_forwarded ( & mut request, "127.0.0.1:8001" ) ;
743
+ #[ test]
744
+ fn when_host_header_is_set ( ) {
745
+ let mut request = build_test_request ( ) ;
746
+ request. insert_header ( "host" , "host.header" ) ;
747
+ assert_eq ! ( request. host( ) , Some ( "host.header" ) ) ;
748
+ }
717
749
718
- assert_eq ! ( request. forwarded_for( ) , Some ( "127.0.0.1:8001" ) ) ;
719
- assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8001" ) ) ;
750
+ #[ test]
751
+ fn when_there_are_no_headers ( ) {
752
+ let request = build_test_request ( ) ;
753
+ assert_eq ! ( request. host( ) , Some ( "async.rs" ) ) ;
754
+ }
755
+
756
+ #[ test]
757
+ fn when_url_has_no_domain ( ) {
758
+ let mut request = build_test_request ( ) ;
759
+ * request. url_mut ( ) = Url :: parse ( "x:" ) . unwrap ( ) ;
760
+ assert_eq ! ( request. host( ) , None ) ;
761
+ }
720
762
}
721
763
722
- #[ test]
723
- fn test_remote_and_forwarded_for_when_forwarded_is_improperly_formatted ( ) {
724
- let mut request = build_test_request ( ) ;
725
- request. set_peer_addr ( Some (
726
- "127.0.0.1:8000" . parse :: < std:: net:: SocketAddr > ( ) . unwrap ( ) ,
727
- ) ) ;
764
+ mod remote {
765
+ use super :: * ;
766
+ #[ test]
767
+ fn when_forwarded_is_properly_formatted ( ) {
768
+ let mut request = build_test_request ( ) ;
769
+ request. set_peer_addr ( Some ( "127.0.0.1:8000" ) ) ;
770
+ set_forwarded ( & mut request, "127.0.0.1:8001" ) ;
771
+
772
+ assert_eq ! ( request. forwarded_for( ) , Some ( "127.0.0.1:8001" ) ) ;
773
+ assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8001" ) ) ;
774
+ }
728
775
729
- request. insert_header ( "Forwarded" , "this is an improperly ;;; formatted header" ) ;
776
+ #[ test]
777
+ fn when_forwarded_is_improperly_formatted ( ) {
778
+ let mut request = build_test_request ( ) ;
779
+ request. set_peer_addr ( Some (
780
+ "127.0.0.1:8000" . parse :: < std:: net:: SocketAddr > ( ) . unwrap ( ) ,
781
+ ) ) ;
730
782
731
- assert_eq ! ( request. forwarded_for( ) , None ) ;
732
- assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8000" ) ) ;
733
- }
783
+ request. insert_header ( "Forwarded" , "this is an improperly ;;; formatted header" ) ;
734
784
735
- #[ test]
736
- fn test_remote_and_forwarded_for_when_x_forwarded_for_is_set ( ) {
737
- let mut request = build_test_request ( ) ;
738
- request. set_peer_addr ( Some (
739
- std:: path:: PathBuf :: from ( "/dev/random" ) . to_str ( ) . unwrap ( ) ,
740
- ) ) ;
741
- set_x_forwarded_for ( & mut request, "forwarded-host.com" ) ;
785
+ assert_eq ! ( request. forwarded_for( ) , None ) ;
786
+ assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8000" ) ) ;
787
+ }
742
788
743
- assert_eq ! ( request. forwarded_for( ) , Some ( "forwarded-host.com" ) ) ;
744
- assert_eq ! ( request. remote( ) , Some ( "forwarded-host.com" ) ) ;
745
- }
789
+ #[ test]
790
+ fn when_x_forwarded_for_is_set ( ) {
791
+ let mut request = build_test_request ( ) ;
792
+ request. set_peer_addr ( Some (
793
+ std:: path:: PathBuf :: from ( "/dev/random" ) . to_str ( ) . unwrap ( ) ,
794
+ ) ) ;
795
+ set_x_forwarded_for ( & mut request, "forwarded-host.com" ) ;
746
796
747
- #[ test]
748
- fn test_remote_and_forwarded_for_when_both_forwarding_headers_are_set ( ) {
749
- let mut request = build_test_request ( ) ;
750
- set_forwarded ( & mut request, "forwarded.com" ) ;
751
- set_x_forwarded_for ( & mut request, "forwarded-for-client.com" ) ;
752
- request. peer_addr = Some ( "127.0.0.1:8000" . into ( ) ) ;
797
+ assert_eq ! ( request. forwarded_for( ) , Some ( "forwarded-host.com" ) ) ;
798
+ assert_eq ! ( request. remote( ) , Some ( "forwarded-host.com" ) ) ;
799
+ }
800
+
801
+ #[ test]
802
+ fn when_both_forwarding_headers_are_set ( ) {
803
+ let mut request = build_test_request ( ) ;
804
+ set_forwarded ( & mut request, "forwarded.com" ) ;
805
+ set_x_forwarded_for ( & mut request, "forwarded-for-client.com" ) ;
806
+ request. peer_addr = Some ( "127.0.0.1:8000" . into ( ) ) ;
807
+
808
+ assert_eq ! ( request. forwarded_for( ) , Some ( "forwarded.com" . into( ) ) ) ;
809
+ assert_eq ! ( request. remote( ) , Some ( "forwarded.com" . into( ) ) ) ;
810
+ }
753
811
754
- assert_eq ! ( request. forwarded_for( ) , Some ( "forwarded.com" . into( ) ) ) ;
755
- assert_eq ! ( request. remote( ) , Some ( "forwarded.com" . into( ) ) ) ;
812
+ #[ test]
813
+ fn falling_back_to_peer_addr ( ) {
814
+ let mut request = build_test_request ( ) ;
815
+ request. peer_addr = Some ( "127.0.0.1:8000" . into ( ) ) ;
816
+
817
+ assert_eq ! ( request. forwarded_for( ) , None ) ;
818
+ assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8000" . into( ) ) ) ;
819
+ }
820
+
821
+ #[ test]
822
+ fn when_no_remote_available ( ) {
823
+ let request = build_test_request ( ) ;
824
+ assert_eq ! ( request. forwarded_for( ) , None ) ;
825
+ assert_eq ! ( request. remote( ) , None ) ;
826
+ }
756
827
}
757
828
758
- # [ test ]
759
- fn test_remote_falling_back_to_peer_addr ( ) {
760
- let mut request = build_test_request ( ) ;
761
- request . peer_addr = Some ( "127.0.0.1:8000" . into ( ) ) ;
829
+ fn build_test_request ( ) -> Request {
830
+ let url = Url :: parse ( "http://async.rs/" ) . unwrap ( ) ;
831
+ Request :: new ( Method :: Get , url )
832
+ }
762
833
763
- assert_eq ! ( request. forwarded_for( ) , None ) ;
764
- assert_eq ! ( request. remote( ) , Some ( "127.0.0.1:8000" . into( ) ) ) ;
834
+ fn set_x_forwarded_for ( request : & mut Request , client : & ' static str ) {
835
+ request. insert_header (
836
+ "x-forwarded-for" ,
837
+ format ! ( "{},proxy.com,other-proxy.com" , client) ,
838
+ ) ;
839
+ }
840
+
841
+ fn set_x_forwarded_host ( request : & mut Request , host : & ' static str ) {
842
+ request. insert_header (
843
+ "x-forwarded-host" ,
844
+ format ! ( "{},proxy.com,other-proxy.com" , host) ,
845
+ ) ;
765
846
}
766
847
767
- # [ test ]
768
- fn test_remote_and_forwarded_for_when_no_remote_available ( ) {
769
- let request = build_test_request ( ) ;
770
- assert_eq ! ( request . forwarded_for ( ) , None ) ;
771
- assert_eq ! ( request . remote ( ) , None ) ;
848
+ fn set_forwarded ( request : & mut Request , client : & ' static str ) {
849
+ request . insert_header (
850
+ "Forwarded" ,
851
+ format ! ( "by=something.com;for={};host=host.com;proto=http" , client ) ,
852
+ ) ;
772
853
}
773
854
}
0 commit comments