Skip to content

Commit e5bbc27

Browse files
committed
h1: Fix connection with multiple IPs for a hostname
When trying to connect to multiple IPs for a hostname (e.g. IPv4 and IPv6) we ought to try all prior returning error. Running a wget to the running mockito server has this output: ,---- | $ wget -O- http://localhost:1234/report | --2021-03-08 16:13:12-- http://localhost:1234/report | Resolving localhost (localhost)... ::1, 127.0.0.1 | Connecting to localhost (localhost)|::1|:1234... failed: Connection refused. | Connecting to localhost (localhost)|127.0.0.1|:1234... connected. | HTTP request sent, awaiting response... 200 OK `---- Fixes: #79. Signed-off-by: Otavio Salvador <[email protected]>
1 parent db0025d commit e5bbc27

File tree

2 files changed

+89
-61
lines changed

2 files changed

+89
-61
lines changed

src/h1/mod.rs

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -134,72 +134,84 @@ impl HttpClient for H1Client {
134134
));
135135
}
136136

137-
let addr = req
138-
.url()
139-
.socket_addrs(|| match req.url().scheme() {
140-
"http" => Some(80),
141-
#[cfg(any(feature = "native-tls", feature = "rustls"))]
142-
"https" => Some(443),
143-
_ => None,
144-
})?
145-
.into_iter()
146-
.next()
147-
.ok_or_else(|| Error::from_str(StatusCode::BadRequest, "missing valid address"))?;
137+
let addrs = req.url().socket_addrs(|| match req.url().scheme() {
138+
"http" => Some(80),
139+
#[cfg(any(feature = "native-tls", feature = "rustls"))]
140+
"https" => Some(443),
141+
_ => None,
142+
})?;
148143

149144
log::trace!("> Scheme: {}", scheme);
150145

151-
match scheme {
152-
"http" => {
153-
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
154-
pool_ref
155-
} else {
156-
let manager = TcpConnection::new(addr);
157-
let pool = Pool::<TcpStream, std::io::Error>::new(
158-
manager,
159-
self.max_concurrent_connections,
160-
);
161-
self.http_pools.insert(addr, pool);
162-
self.http_pools.get(&addr).unwrap()
163-
};
164-
165-
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
166-
let pool = pool_ref.clone();
167-
std::mem::drop(pool_ref);
168-
169-
let stream = pool.get().await?;
170-
req.set_peer_addr(stream.peer_addr().ok());
171-
req.set_local_addr(stream.local_addr().ok());
172-
client::connect(TcpConnWrapper::new(stream), req).await
173-
}
174-
#[cfg(any(feature = "native-tls", feature = "rustls"))]
175-
"https" => {
176-
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
177-
pool_ref
178-
} else {
179-
let manager = TlsConnection::new(host.clone(), addr);
180-
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
181-
manager,
182-
self.max_concurrent_connections,
183-
);
184-
self.https_pools.insert(addr, pool);
185-
self.https_pools.get(&addr).unwrap()
186-
};
187-
188-
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
189-
let pool = pool_ref.clone();
190-
std::mem::drop(pool_ref);
191-
192-
let stream = pool
193-
.get()
194-
.await
195-
.map_err(|e| Error::from_str(400, e.to_string()))?;
196-
req.set_peer_addr(stream.get_ref().peer_addr().ok());
197-
req.set_local_addr(stream.get_ref().local_addr().ok());
198-
199-
client::connect(TlsConnWrapper::new(stream), req).await
146+
let max_addrs_idx = addrs.len() - 1;
147+
for (idx, addr) in addrs.into_iter().enumerate() {
148+
let has_another_addr = idx != max_addrs_idx;
149+
150+
match scheme {
151+
"http" => {
152+
let pool_ref = if let Some(pool_ref) = self.http_pools.get(&addr) {
153+
pool_ref
154+
} else {
155+
let manager = TcpConnection::new(addr);
156+
let pool = Pool::<TcpStream, std::io::Error>::new(
157+
manager,
158+
self.max_concurrent_connections,
159+
);
160+
self.http_pools.insert(addr, pool);
161+
self.http_pools.get(&addr).unwrap()
162+
};
163+
164+
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
165+
let pool = pool_ref.clone();
166+
std::mem::drop(pool_ref);
167+
168+
let stream = match pool.get().await {
169+
Ok(s) => s,
170+
Err(_) if has_another_addr => continue,
171+
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
172+
};
173+
174+
req.set_peer_addr(stream.peer_addr().ok());
175+
req.set_local_addr(stream.local_addr().ok());
176+
return client::connect(TcpConnWrapper::new(stream), req).await;
177+
}
178+
#[cfg(any(feature = "native-tls", feature = "rustls"))]
179+
"https" => {
180+
let pool_ref = if let Some(pool_ref) = self.https_pools.get(&addr) {
181+
pool_ref
182+
} else {
183+
let manager = TlsConnection::new(host.clone(), addr);
184+
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
185+
manager,
186+
self.max_concurrent_connections,
187+
);
188+
self.https_pools.insert(addr, pool);
189+
self.https_pools.get(&addr).unwrap()
190+
};
191+
192+
// Deadlocks are prevented by cloning an inner pool Arc and dropping the original locking reference before we await.
193+
let pool = pool_ref.clone();
194+
std::mem::drop(pool_ref);
195+
196+
let stream = match pool.get().await {
197+
Ok(s) => s,
198+
Err(_) if has_another_addr => continue,
199+
Err(e) => return Err(Error::from_str(400, e.to_string()))?,
200+
};
201+
202+
req.set_peer_addr(stream.get_ref().peer_addr().ok());
203+
req.set_local_addr(stream.get_ref().local_addr().ok());
204+
205+
return client::connect(TlsConnWrapper::new(stream), req).await;
206+
}
207+
_ => unreachable!(),
200208
}
201-
_ => unreachable!(),
202209
}
210+
211+
Err(Error::from_str(
212+
StatusCode::BadRequest,
213+
"missing valid address",
214+
))
203215
}
204216
}
205217

tests/test.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,19 @@ async fn keep_alive() {
149149
client.send(req.clone()).await.unwrap();
150150
client.send(req.clone()).await.unwrap();
151151
}
152+
153+
#[atest]
154+
async fn fallback_to_ipv4() {
155+
let client = DefaultClient::new();
156+
let _mock_guard = mock("GET", "/")
157+
.with_status(200)
158+
.expect_at_least(2)
159+
.create();
160+
161+
// Kips the initial "http://127.0.0.1:" to get only the port number
162+
let mock_port = &mockito::server_url()[17..];
163+
164+
let url = &format!("http://localhost:{}", mock_port);
165+
let req = Request::new(http_types::Method::Get, Url::parse(url).unwrap());
166+
client.send(req.clone()).await.unwrap();
167+
}

0 commit comments

Comments
 (0)