@@ -37,16 +37,17 @@ fn get_ip_from_unique_header(headers: &HeaderMap, header_name: &str) -> Result<I
3737 let mut values = headers. get_all ( header_name) . iter ( ) ;
3838
3939 let first_value = values. next ( ) . ok_or ( IpError :: NotPresent ( header_name. to_string ( ) ) ) ?;
40+
41+ if values. next ( ) . is_some ( ) {
42+ return Err ( IpError :: NotUnique ( header_name. to_string ( ) ) ) ;
43+ }
44+
4045 let ip = first_value
4146 . to_str ( )
4247 . map_err ( |_| IpError :: HasInvalidCharacters ) ?
4348 . parse :: < IpAddr > ( )
4449 . map_err ( |_| IpError :: InvalidValue ) ?;
4550
46- if values. next ( ) . is_some ( ) {
47- return Err ( IpError :: NotUnique ( header_name. to_string ( ) ) ) ;
48- }
49-
5051 Ok ( ip)
5152}
5253
@@ -55,23 +56,187 @@ fn get_ip_from_rightmost_value(
5556 header_name : & str ,
5657 trusted_count : usize ,
5758) -> Result < IpAddr , IpError > {
58- let last_value = headers
59+ let joined_values = headers
5960 . get_all ( header_name)
6061 . iter ( )
61- . next_back ( )
62- . ok_or ( IpError :: NotPresent ( header_name. to_string ( ) ) ) ?
63- . to_str ( )
64- . map_err ( |_| IpError :: HasInvalidCharacters ) ?;
62+ . map ( |x| x. to_str ( ) . map_err ( |_| IpError :: HasInvalidCharacters ) )
63+ . collect :: < Result < Vec < & str > , IpError > > ( ) ?
64+ . join ( "," ) ;
65+
66+ if joined_values. is_empty ( ) {
67+ return Err ( IpError :: NotPresent ( header_name. to_string ( ) ) )
68+ }
6569
6670 // Selecting the first untrusted IP from the right according to:
6771 // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/X-Forwarded-For#selecting_an_ip_address
68- last_value
72+ joined_values
6973 . rsplit ( "," )
7074 . nth ( trusted_count - 1 )
7175 . ok_or ( IpError :: NotEnoughValues {
72- found : last_value . split ( "," ) . count ( ) ,
76+ found : joined_values . split ( "," ) . count ( ) ,
7377 required : trusted_count,
7478 } ) ?
79+ . trim ( )
7580 . parse :: < IpAddr > ( )
7681 . map_err ( |_| IpError :: InvalidValue )
7782}
83+
84+ #[ cfg( test) ]
85+ mod tests {
86+ use std:: net:: Ipv4Addr ;
87+
88+ use super :: * ;
89+
90+ #[ test]
91+ fn test_unique_header_pass ( ) {
92+ let header_name = "X-Real-IP" ;
93+ let real_ip = IpAddr :: V4 ( Ipv4Addr :: new ( 1 , 1 , 1 , 1 ) ) ;
94+
95+ let mut headers = HeaderMap :: new ( ) ;
96+ headers. insert ( header_name, real_ip. to_string ( ) . parse ( ) . unwrap ( ) ) ;
97+
98+ let ip = get_ip_from_unique_header ( & headers, header_name) . unwrap ( ) ;
99+ assert_eq ! ( ip, real_ip) ;
100+ }
101+
102+ #[ test]
103+ fn test_unique_header_duplicated ( ) {
104+ let header_name = "X-Real-IP" ;
105+ let real_ip = IpAddr :: V4 ( Ipv4Addr :: new ( 1 , 1 , 1 , 1 ) ) ;
106+ let fake_ip = IpAddr :: V4 ( Ipv4Addr :: new ( 2 , 2 , 2 , 2 ) ) ;
107+
108+ let mut headers = HeaderMap :: new ( ) ;
109+ headers. insert ( header_name, real_ip. to_string ( ) . parse ( ) . unwrap ( ) ) ;
110+ headers. append ( header_name, fake_ip. to_string ( ) . parse ( ) . unwrap ( ) ) ;
111+
112+ let err = get_ip_from_unique_header ( & headers, header_name)
113+ . expect_err ( "Not unique header should fail" ) ;
114+ assert ! ( matches!( err, IpError :: NotUnique ( _) ) ) ;
115+ }
116+ #[ test]
117+ fn test_unique_header_not_present ( ) {
118+ let header_name = "X-Real-IP" ;
119+ let headers = HeaderMap :: new ( ) ;
120+
121+ let err = get_ip_from_unique_header ( & headers, header_name)
122+ . expect_err ( "Missing header should fail" ) ;
123+ assert ! ( matches!( err, IpError :: NotPresent ( _) ) ) ;
124+ }
125+
126+ #[ test]
127+ fn test_unique_header_invalid_value ( ) {
128+ let header_name = "X-Real-IP" ;
129+ let mut headers = HeaderMap :: new ( ) ;
130+ headers. insert ( header_name, "invalid-ip" . parse ( ) . unwrap ( ) ) ;
131+
132+ let err =
133+ get_ip_from_unique_header ( & headers, header_name) . expect_err ( "Invalid IP should fail" ) ;
134+ assert ! ( matches!( err, IpError :: InvalidValue ) ) ;
135+ }
136+
137+ #[ test]
138+ fn test_unique_header_empty_value ( ) {
139+ let header_name = "X-Real-IP" ;
140+ let mut headers = HeaderMap :: new ( ) ;
141+ headers. insert ( header_name, "" . parse ( ) . unwrap ( ) ) ;
142+
143+ let err =
144+ get_ip_from_unique_header ( & headers, header_name) . expect_err ( "Invalid IP should fail" ) ;
145+ assert ! ( matches!( err, IpError :: InvalidValue ) ) ;
146+ }
147+
148+ #[ test]
149+ fn test_rightmost_header_comma_separated ( ) {
150+ let header_name = "X-Forwarded-For" ;
151+ let ip1 = IpAddr :: V4 ( Ipv4Addr :: new ( 1 , 1 , 1 , 1 ) ) ;
152+ let ip2 = IpAddr :: V4 ( Ipv4Addr :: new ( 2 , 2 , 2 , 2 ) ) ;
153+ let ip3 = IpAddr :: V4 ( Ipv4Addr :: new ( 3 , 3 , 3 , 3 ) ) ;
154+
155+ let mut headers = HeaderMap :: new ( ) ;
156+ headers. insert ( header_name, format ! ( "{},{},{}" , ip1, ip2, ip3) . parse ( ) . unwrap ( ) ) ;
157+
158+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 1 ) . unwrap ( ) ;
159+ assert_eq ! ( ip, ip3) ;
160+
161+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 2 ) . unwrap ( ) ;
162+ assert_eq ! ( ip, ip2) ;
163+
164+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 3 ) . unwrap ( ) ;
165+ assert_eq ! ( ip, ip1) ;
166+
167+ let err = get_ip_from_rightmost_value ( & headers, header_name, 4 )
168+ . expect_err ( "Not enough values should fail" ) ;
169+ assert ! ( matches!( err, IpError :: NotEnoughValues { .. } ) ) ;
170+ }
171+
172+ #[ test]
173+ fn test_rightmost_header_comma_space_separated ( ) {
174+ let header_name = "X-Forwarded-For" ;
175+ let ip1 = IpAddr :: V4 ( Ipv4Addr :: new ( 1 , 1 , 1 , 1 ) ) ;
176+ let ip2 = IpAddr :: V4 ( Ipv4Addr :: new ( 2 , 2 , 2 , 2 ) ) ;
177+ let ip3 = IpAddr :: V4 ( Ipv4Addr :: new ( 3 , 3 , 3 , 3 ) ) ;
178+
179+ let mut headers = HeaderMap :: new ( ) ;
180+ headers. insert ( header_name, format ! ( "{}, {}, {}" , ip1, ip2, ip3) . parse ( ) . unwrap ( ) ) ;
181+
182+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 1 ) . unwrap ( ) ;
183+ assert_eq ! ( ip, ip3) ;
184+
185+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 2 ) . unwrap ( ) ;
186+ assert_eq ! ( ip, ip2) ;
187+
188+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 3 ) . unwrap ( ) ;
189+ assert_eq ! ( ip, ip1) ;
190+
191+ let err = get_ip_from_rightmost_value ( & headers, header_name, 4 )
192+ . expect_err ( "Not enough values should fail" ) ;
193+ assert ! ( matches!( err, IpError :: NotEnoughValues { .. } ) ) ;
194+ }
195+
196+ #[ test]
197+ fn test_rightmost_header_duplicated ( ) {
198+ // If the header appears multiple times, they should be joined together
199+ // as if they were a single value.
200+ let header_name = "X-Forwarded-For" ;
201+ let ip1 = IpAddr :: V4 ( Ipv4Addr :: new ( 1 , 1 , 1 , 1 ) ) ;
202+ let ip2 = IpAddr :: V4 ( Ipv4Addr :: new ( 2 , 2 , 2 , 2 ) ) ;
203+ let ip3 = IpAddr :: V4 ( Ipv4Addr :: new ( 3 , 3 , 3 , 3 ) ) ;
204+ let ip4 = IpAddr :: V4 ( Ipv4Addr :: new ( 4 , 4 , 4 , 4 ) ) ;
205+ let ip5 = IpAddr :: V4 ( Ipv4Addr :: new ( 5 , 5 , 5 , 5 ) ) ;
206+
207+ let mut headers = HeaderMap :: new ( ) ;
208+ headers. insert ( header_name, format ! ( "{},{},{}" , ip1, ip2, ip3) . parse ( ) . unwrap ( ) ) ;
209+ headers. append ( header_name, format ! ( "{},{}" , ip4, ip5) . parse ( ) . unwrap ( ) ) ;
210+
211+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 1 ) . unwrap ( ) ;
212+ assert_eq ! ( ip, ip5) ;
213+
214+ let ip = get_ip_from_rightmost_value ( & headers, header_name, 5 ) . unwrap ( ) ;
215+ assert_eq ! ( ip, ip1) ;
216+
217+ let err = get_ip_from_rightmost_value ( & headers, header_name, 6 )
218+ . expect_err ( "Not enough values should fail" ) ;
219+ assert ! ( matches!( err, IpError :: NotEnoughValues { .. } ) ) ;
220+ }
221+
222+ #[ test]
223+ fn test_rightmost_header_not_present ( ) {
224+ let header_name = "X-Forwarded-For" ;
225+ let headers = HeaderMap :: new ( ) ;
226+
227+ let err = get_ip_from_rightmost_value ( & headers, header_name, 1 )
228+ . expect_err ( "Missing header should fail" ) ;
229+ assert ! ( matches!( err, IpError :: NotPresent ( _) ) ) ;
230+ }
231+
232+ #[ test]
233+ fn test_rightmost_header_invalid_value ( ) {
234+ let header_name = "X-Forwarded-For" ;
235+ let mut headers = HeaderMap :: new ( ) ;
236+ headers. insert ( header_name, "invalid-ip" . parse ( ) . unwrap ( ) ) ;
237+
238+ let err = get_ip_from_rightmost_value ( & headers, header_name, 1 )
239+ . expect_err ( "Invalid IP should fail" ) ;
240+ assert ! ( matches!( err, IpError :: InvalidValue ) ) ;
241+ }
242+ }
0 commit comments