Skip to content

Commit 8626786

Browse files
authored
Merge pull request #131 from jbr/host
Request::host
2 parents ba2ffcd + 398cb65 commit 8626786

File tree

1 file changed

+152
-71
lines changed

1 file changed

+152
-71
lines changed

src/request.rs

Lines changed: 152 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -93,25 +93,48 @@ impl Request {
9393
}
9494

9595
/// Get the remote address for this request.
96+
/// This is determined in the following priority:
97+
/// 1. `Forwarded` header `for` key
98+
/// 2. The first `X-Forwarded-For` header
99+
/// 3. Peer address of the transport
96100
pub fn remote(&self) -> Option<&str> {
97-
self.forwarded_for().or(self.peer_addr())
101+
self.forwarded_for().or_else(|| self.peer_addr())
102+
}
103+
104+
/// Get the destination host for this request.
105+
/// This is determined in the following priority:
106+
/// 1. `Forwarded` header `host` key
107+
/// 2. The first `X-Forwarded-Host` header
108+
/// 3. `Host` header
109+
/// 4. URL domain, if any
110+
pub fn host(&self) -> Option<&str> {
111+
self.forwarded_header_part("host")
112+
.or_else(|| {
113+
self.header("X-Forwarded-Host")
114+
.and_then(|h| h.as_str().split(",").next())
115+
})
116+
.or_else(|| self.header(&headers::HOST).map(|h| h.as_str()))
117+
.or_else(|| self.url().host_str())
98118
}
99119

100-
fn forwarded_for(&self) -> Option<&str> {
101-
if let Some(header) = self.header("Forwarded") {
120+
fn forwarded_header_part(&self, part: &str) -> Option<&str> {
121+
self.header("Forwarded").and_then(|header| {
102122
header.as_str().split(";").find_map(|key_equals_value| {
103123
let parts = key_equals_value.split("=").collect::<Vec<_>>();
104-
if parts.len() == 2 && parts[0].eq_ignore_ascii_case("for") {
124+
if parts.len() == 2 && parts[0].eq_ignore_ascii_case(part) {
105125
Some(parts[1])
106126
} else {
107127
None
108128
}
109129
})
110-
} else if let Some(header) = self.header("X-Forwarded-For") {
111-
header.as_str().split(",").next()
112-
} else {
113-
None
114-
}
130+
})
131+
}
132+
133+
fn forwarded_for(&self) -> Option<&str> {
134+
self.forwarded_header_part("for").or_else(|| {
135+
self.header("X-Forwarded-For")
136+
.and_then(|header| header.as_str().split(",").next())
137+
})
115138
}
116139

117140
/// Get the HTTP method
@@ -689,85 +712,143 @@ impl<'a> IntoIterator for &'a mut Request {
689712
#[cfg(test)]
690713
mod tests {
691714
use super::*;
715+
mod host {
716+
use super::*;
717+
718+
#[test]
719+
fn when_forwarded_header_is_set() {
720+
let mut request = build_test_request();
721+
set_forwarded(&mut request, "-");
722+
set_x_forwarded_host(&mut request, "this will not be used");
723+
assert_eq!(request.forwarded_header_part("host"), Some("host.com"));
724+
assert_eq!(request.host(), Some("host.com"));
725+
}
692726

693-
fn build_test_request() -> Request {
694-
let url = Url::parse("http://irrelevant/").unwrap();
695-
Request::new(Method::Get, url)
696-
}
727+
#[test]
728+
fn when_several_x_forwarded_hosts_exist() {
729+
let mut request = build_test_request();
730+
set_x_forwarded_host(&mut request, "expected.host");
697731

698-
fn set_x_forwarded_for(request: &mut Request, client: &'static str) {
699-
request.insert_header(
700-
"x-forwarded-for",
701-
format!("{},proxy.com,other-proxy.com", client),
702-
);
703-
}
732+
assert_eq!(request.forwarded_header_part("host"), None);
733+
assert_eq!(request.host(), Some("expected.host"));
734+
}
704735

705-
fn set_forwarded(request: &mut Request, client: &'static str) {
706-
request.insert_header(
707-
"Forwarded",
708-
format!("by=something.com;for={};host=host.com;proto=http", client),
709-
);
710-
}
736+
#[test]
737+
fn when_only_one_x_forwarded_hosts_exist() {
738+
let mut request = build_test_request();
739+
request.insert_header("x-forwarded-host", "expected.host");
740+
assert_eq!(request.host(), Some("expected.host"));
741+
}
711742

712-
#[test]
713-
fn test_remote_and_forwarded_for_when_forwarded_is_properly_formatted() {
714-
let mut request = build_test_request();
715-
request.set_peer_addr(Some("127.0.0.1:8000"));
716-
set_forwarded(&mut request, "127.0.0.1:8001");
743+
#[test]
744+
fn when_host_header_is_set() {
745+
let mut request = build_test_request();
746+
request.insert_header("host", "host.header");
747+
assert_eq!(request.host(), Some("host.header"));
748+
}
717749

718-
assert_eq!(request.forwarded_for(), Some("127.0.0.1:8001"));
719-
assert_eq!(request.remote(), Some("127.0.0.1:8001"));
750+
#[test]
751+
fn when_there_are_no_headers() {
752+
let request = build_test_request();
753+
assert_eq!(request.host(), Some("async.rs"));
754+
}
755+
756+
#[test]
757+
fn when_url_has_no_domain() {
758+
let mut request = build_test_request();
759+
*request.url_mut() = Url::parse("x:").unwrap();
760+
assert_eq!(request.host(), None);
761+
}
720762
}
721763

722-
#[test]
723-
fn test_remote_and_forwarded_for_when_forwarded_is_improperly_formatted() {
724-
let mut request = build_test_request();
725-
request.set_peer_addr(Some(
726-
"127.0.0.1:8000".parse::<std::net::SocketAddr>().unwrap(),
727-
));
764+
mod remote {
765+
use super::*;
766+
#[test]
767+
fn when_forwarded_is_properly_formatted() {
768+
let mut request = build_test_request();
769+
request.set_peer_addr(Some("127.0.0.1:8000"));
770+
set_forwarded(&mut request, "127.0.0.1:8001");
771+
772+
assert_eq!(request.forwarded_for(), Some("127.0.0.1:8001"));
773+
assert_eq!(request.remote(), Some("127.0.0.1:8001"));
774+
}
728775

729-
request.insert_header("Forwarded", "this is an improperly ;;; formatted header");
776+
#[test]
777+
fn when_forwarded_is_improperly_formatted() {
778+
let mut request = build_test_request();
779+
request.set_peer_addr(Some(
780+
"127.0.0.1:8000".parse::<std::net::SocketAddr>().unwrap(),
781+
));
730782

731-
assert_eq!(request.forwarded_for(), None);
732-
assert_eq!(request.remote(), Some("127.0.0.1:8000"));
733-
}
783+
request.insert_header("Forwarded", "this is an improperly ;;; formatted header");
734784

735-
#[test]
736-
fn test_remote_and_forwarded_for_when_x_forwarded_for_is_set() {
737-
let mut request = build_test_request();
738-
request.set_peer_addr(Some(
739-
std::path::PathBuf::from("/dev/random").to_str().unwrap(),
740-
));
741-
set_x_forwarded_for(&mut request, "forwarded-host.com");
785+
assert_eq!(request.forwarded_for(), None);
786+
assert_eq!(request.remote(), Some("127.0.0.1:8000"));
787+
}
742788

743-
assert_eq!(request.forwarded_for(), Some("forwarded-host.com"));
744-
assert_eq!(request.remote(), Some("forwarded-host.com"));
745-
}
789+
#[test]
790+
fn when_x_forwarded_for_is_set() {
791+
let mut request = build_test_request();
792+
request.set_peer_addr(Some(
793+
std::path::PathBuf::from("/dev/random").to_str().unwrap(),
794+
));
795+
set_x_forwarded_for(&mut request, "forwarded-host.com");
746796

747-
#[test]
748-
fn test_remote_and_forwarded_for_when_both_forwarding_headers_are_set() {
749-
let mut request = build_test_request();
750-
set_forwarded(&mut request, "forwarded.com");
751-
set_x_forwarded_for(&mut request, "forwarded-for-client.com");
752-
request.peer_addr = Some("127.0.0.1:8000".into());
797+
assert_eq!(request.forwarded_for(), Some("forwarded-host.com"));
798+
assert_eq!(request.remote(), Some("forwarded-host.com"));
799+
}
800+
801+
#[test]
802+
fn when_both_forwarding_headers_are_set() {
803+
let mut request = build_test_request();
804+
set_forwarded(&mut request, "forwarded.com");
805+
set_x_forwarded_for(&mut request, "forwarded-for-client.com");
806+
request.peer_addr = Some("127.0.0.1:8000".into());
807+
808+
assert_eq!(request.forwarded_for(), Some("forwarded.com".into()));
809+
assert_eq!(request.remote(), Some("forwarded.com".into()));
810+
}
753811

754-
assert_eq!(request.forwarded_for(), Some("forwarded.com".into()));
755-
assert_eq!(request.remote(), Some("forwarded.com".into()));
812+
#[test]
813+
fn falling_back_to_peer_addr() {
814+
let mut request = build_test_request();
815+
request.peer_addr = Some("127.0.0.1:8000".into());
816+
817+
assert_eq!(request.forwarded_for(), None);
818+
assert_eq!(request.remote(), Some("127.0.0.1:8000".into()));
819+
}
820+
821+
#[test]
822+
fn when_no_remote_available() {
823+
let request = build_test_request();
824+
assert_eq!(request.forwarded_for(), None);
825+
assert_eq!(request.remote(), None);
826+
}
756827
}
757828

758-
#[test]
759-
fn test_remote_falling_back_to_peer_addr() {
760-
let mut request = build_test_request();
761-
request.peer_addr = Some("127.0.0.1:8000".into());
829+
fn build_test_request() -> Request {
830+
let url = Url::parse("http://async.rs/").unwrap();
831+
Request::new(Method::Get, url)
832+
}
762833

763-
assert_eq!(request.forwarded_for(), None);
764-
assert_eq!(request.remote(), Some("127.0.0.1:8000".into()));
834+
fn set_x_forwarded_for(request: &mut Request, client: &'static str) {
835+
request.insert_header(
836+
"x-forwarded-for",
837+
format!("{},proxy.com,other-proxy.com", client),
838+
);
839+
}
840+
841+
fn set_x_forwarded_host(request: &mut Request, host: &'static str) {
842+
request.insert_header(
843+
"x-forwarded-host",
844+
format!("{},proxy.com,other-proxy.com", host),
845+
);
765846
}
766847

767-
#[test]
768-
fn test_remote_and_forwarded_for_when_no_remote_available() {
769-
let request = build_test_request();
770-
assert_eq!(request.forwarded_for(), None);
771-
assert_eq!(request.remote(), None);
848+
fn set_forwarded(request: &mut Request, client: &'static str) {
849+
request.insert_header(
850+
"Forwarded",
851+
format!("by=something.com;for={};host=host.com;proto=http", client),
852+
);
772853
}
773854
}

0 commit comments

Comments
 (0)