Skip to content

Commit a10b702

Browse files
authored
Merge pull request #115 from hatoo/conn-pool
Conn pool
2 parents e4a1e86 + 6334119 commit a10b702

File tree

3 files changed

+235
-43
lines changed

3 files changed

+235
-43
lines changed

src/default_client.rs

Lines changed: 207 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]
22

3-
use bytes::Bytes;
4-
use http_body_util::Empty;
3+
use bytes::{Buf, Bytes};
4+
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
55
use hyper::{
66
Request, Response, StatusCode, Uri, Version,
77
body::{Body, Incoming},
88
client, header,
99
};
1010
use hyper_util::rt::{TokioExecutor, TokioIo};
11-
use std::task::{Context, Poll};
11+
use std::{
12+
collections::HashMap,
13+
future::poll_fn,
14+
sync::Arc,
15+
task::{Context, Poll},
16+
};
17+
use tokio::sync::Mutex;
1218
use tokio::{net::TcpStream, task::JoinHandle};
1319

1420
#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
@@ -55,6 +61,96 @@ pub struct Upgraded {
5561
/// A socket to Server
5662
pub server: TokioIo<hyper::upgrade::Upgraded>,
5763
}
64+
65+
type DynError = Box<dyn std::error::Error + Send + Sync>;
66+
type PooledBody = BoxBody<Bytes, DynError>;
67+
type Http1Sender = hyper::client::conn::http1::SendRequest<PooledBody>;
68+
type Http2Sender = hyper::client::conn::http2::SendRequest<PooledBody>;
69+
70+
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
71+
enum ConnectionProtocol {
72+
Http1,
73+
Http2,
74+
}
75+
76+
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
77+
struct ConnectionKey {
78+
host: String,
79+
port: u16,
80+
is_tls: bool,
81+
protocol: ConnectionProtocol,
82+
}
83+
84+
impl ConnectionKey {
85+
fn new(host: String, port: u16, is_tls: bool, protocol: ConnectionProtocol) -> Self {
86+
Self {
87+
host,
88+
port,
89+
is_tls,
90+
protocol,
91+
}
92+
}
93+
94+
fn from_uri(uri: &Uri, protocol: ConnectionProtocol) -> Result<Self, Error> {
95+
let (host, port, is_tls) = host_port(uri)?;
96+
Ok(ConnectionKey::new(host, port, is_tls, protocol))
97+
}
98+
}
99+
100+
#[derive(Clone, Default)]
101+
struct ConnectionPool {
102+
http1: Arc<Mutex<HashMap<ConnectionKey, Vec<Http1Sender>>>>,
103+
http2: Arc<Mutex<HashMap<ConnectionKey, Http2Sender>>>,
104+
}
105+
106+
impl ConnectionPool {
107+
async fn take_http1(&self, key: &ConnectionKey) -> Option<Http1Sender> {
108+
let mut guard = self.http1.lock().await;
109+
let entry = guard.get_mut(key)?;
110+
while let Some(mut conn) = entry.pop() {
111+
if sender_alive_http1(&mut conn).await {
112+
return Some(conn);
113+
}
114+
}
115+
if entry.is_empty() {
116+
guard.remove(key);
117+
}
118+
None
119+
}
120+
121+
async fn put_http1(&self, key: ConnectionKey, sender: Http1Sender) {
122+
let mut guard = self.http1.lock().await;
123+
guard.entry(key).or_default().push(sender);
124+
}
125+
126+
async fn get_http2(&self, key: &ConnectionKey) -> Option<Http2Sender> {
127+
let mut guard = self.http2.lock().await;
128+
let mut sender = guard.get(key).cloned()?;
129+
130+
let alive = sender_alive_http2(&mut sender).await;
131+
132+
if alive {
133+
Some(sender)
134+
} else {
135+
guard.remove(key);
136+
None
137+
}
138+
}
139+
140+
async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Http2Sender) {
141+
let mut guard = self.http2.lock().await;
142+
guard.entry(key).or_insert(sender);
143+
}
144+
}
145+
146+
async fn sender_alive_http1(sender: &mut Http1Sender) -> bool {
147+
poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
148+
}
149+
150+
async fn sender_alive_http2(sender: &mut Http2Sender) -> bool {
151+
poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
152+
}
153+
58154
#[derive(Clone)]
59155
/// Default HTTP client for this crate
60156
pub struct DefaultClient {
@@ -71,6 +167,8 @@ pub struct DefaultClient {
71167
/// If true, send_request will returns an Upgraded struct when the response is an upgrade
72168
/// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
73169
pub with_upgrades: bool,
170+
171+
pool: ConnectionPool,
74172
}
75173
impl Default for DefaultClient {
76174
fn default() -> Self {
@@ -102,6 +200,7 @@ impl DefaultClient {
102200
tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
103201
tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
104202
with_upgrades: false,
203+
pool: ConnectionPool::default(),
105204
})
106205
}
107206

@@ -135,6 +234,7 @@ impl DefaultClient {
135234
tls_connector_alpn_h2,
136235
)),
137236
with_upgrades: false,
237+
pool: ConnectionPool::default(),
138238
})
139239
}
140240

@@ -175,17 +275,54 @@ impl DefaultClient {
175275
Error,
176276
>
177277
where
178-
B: Body + Unpin + Send + 'static,
179-
B::Data: Send,
180-
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
278+
B: Body<Data = Bytes> + Send + Sync + 'static,
279+
B::Data: Send + Buf,
280+
B::Error: Into<DynError>,
181281
{
182-
let mut send_request = self.connect(req.uri(), req.version()).await?;
282+
let target_uri = req.uri().clone();
283+
let mut send_request = if req.version() == Version::HTTP_2 {
284+
match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2) {
285+
Ok(pool_key) => {
286+
if let Some(conn) = self.pool.get_http2(&pool_key).await {
287+
SendRequest::Http2(conn)
288+
} else {
289+
self.connect(req.uri(), req.version(), Some(pool_key))
290+
.await?
291+
}
292+
}
293+
Err(err) => {
294+
tracing::warn!(
295+
"ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool",
296+
err
297+
);
298+
self.connect(req.uri(), req.version(), None).await?
299+
}
300+
}
301+
} else {
302+
match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) {
303+
Ok(pool_key) => {
304+
if let Some(conn) = self.pool.take_http1(&pool_key).await {
305+
SendRequest::Http1(conn)
306+
} else {
307+
self.connect(req.uri(), req.version(), Some(pool_key))
308+
.await?
309+
}
310+
}
311+
Err(err) => {
312+
tracing::warn!(
313+
"ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool",
314+
err
315+
);
316+
self.connect(req.uri(), req.version(), None).await?
317+
}
318+
}
319+
};
183320

184321
let (req_parts, req_body) = req.into_parts();
185322

186-
let res = send_request
187-
.send_request(Request::from_parts(req_parts.clone(), req_body))
188-
.await?;
323+
let boxed_req = Request::from_parts(req_parts.clone(), to_boxed_body(req_body));
324+
325+
let res = send_request.send_request(boxed_req).await?;
189326

190327
if res.status() == StatusCode::SWITCHING_PROTOCOLS {
191328
let (res_parts, res_body) = res.into_parts();
@@ -221,36 +358,41 @@ impl DefaultClient {
221358

222359
Ok((Response::from_parts(res_parts, res_body), upgrade))
223360
} else {
361+
match send_request {
362+
SendRequest::Http1(sender) => {
363+
if let Ok(pool_key) =
364+
ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)
365+
{
366+
self.pool.put_http1(pool_key, sender).await;
367+
} else {
368+
// If we couldn't build a pool key, skip pooling.
369+
}
370+
}
371+
SendRequest::Http2(_) => {
372+
// For HTTP/2 the pool retains a shared sender; no action needed.
373+
}
374+
}
224375
Ok((res, None))
225376
}
226377
}
227378

228-
async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
229-
where
230-
B: Body + Unpin + Send + 'static,
231-
B::Data: Send,
232-
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
233-
{
234-
let host = uri
235-
.host()
236-
.ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?;
237-
let port =
238-
uri.port_u16()
239-
.unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
240-
443
241-
} else {
242-
80
243-
});
379+
async fn connect(
380+
&self,
381+
uri: &Uri,
382+
http_version: Version,
383+
key: Option<ConnectionKey>,
384+
) -> Result<SendRequest, Error> {
385+
let (host, port, is_tls) = host_port(uri)?;
244386

245-
let tcp = TcpStream::connect((host, port)).await?;
387+
let tcp = TcpStream::connect((host.as_str(), port)).await?;
246388
// This is actually needed to some servers
247389
let _ = tcp.set_nodelay(true);
248390

249-
if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
391+
if is_tls {
250392
#[cfg(feature = "native-tls-client")]
251393
let tls = self
252394
.tls_connector(http_version)
253-
.connect(host, tcp)
395+
.connect(&host, tcp)
254396
.await
255397
.map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
256398
#[cfg(feature = "rustls-client")]
@@ -284,6 +426,14 @@ impl DefaultClient {
284426

285427
tokio::spawn(conn);
286428

429+
if let Some(ref k) = key
430+
&& matches!(k.protocol, ConnectionProtocol::Http2)
431+
{
432+
self.pool
433+
.insert_http2_if_absent(k.clone(), sender.clone())
434+
.await;
435+
}
436+
287437
Ok(SendRequest::Http2(sender))
288438
} else {
289439
let (sender, conn) = client::conn::http1::Builder::new()
@@ -310,18 +460,15 @@ impl DefaultClient {
310460
}
311461
}
312462

313-
enum SendRequest<B> {
314-
Http1(hyper::client::conn::http1::SendRequest<B>),
315-
Http2(hyper::client::conn::http2::SendRequest<B>),
463+
enum SendRequest {
464+
Http1(Http1Sender),
465+
Http2(Http2Sender),
316466
}
317467

318-
impl<B> SendRequest<B>
319-
where
320-
B: Body + 'static,
321-
{
468+
impl SendRequest {
322469
async fn send_request(
323470
&mut self,
324-
mut req: Request<B>,
471+
mut req: Request<PooledBody>,
325472
) -> Result<Response<Incoming>, hyper::Error> {
326473
match self {
327474
SendRequest::Http1(sender) => {
@@ -357,13 +504,13 @@ where
357504
}
358505
}
359506

360-
impl<B> SendRequest<B> {
507+
impl SendRequest {
361508
#[allow(dead_code)]
362509
// TODO: connection pooling
363510
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
364511
match self {
365512
SendRequest::Http1(sender) => sender.poll_ready(cx),
366-
SendRequest::Http2(sender) => sender.poll_ready(cx),
513+
SendRequest::Http2(_sender) => Poll::Ready(Ok(())),
367514
}
368515
}
369516
}
@@ -375,3 +522,24 @@ fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::Inv
375522
*req.uri_mut() = Uri::from_parts(parts)?;
376523
Ok(())
377524
}
525+
526+
fn to_boxed_body<B>(body: B) -> PooledBody
527+
where
528+
B: Body<Data = Bytes> + Send + Sync + 'static,
529+
B::Data: Send + Buf,
530+
B::Error: Into<DynError>,
531+
{
532+
body.map_err(|err| err.into()).boxed()
533+
}
534+
535+
fn host_port(uri: &Uri) -> Result<(String, u16, bool), Error> {
536+
let host = uri
537+
.host()
538+
.ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?
539+
.to_string();
540+
let is_tls = uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS);
541+
let port = uri.port_u16().unwrap_or(if is_tls { 443 } else { 80 });
542+
Ok((host, port, is_tls))
543+
}
544+
545+
impl DefaultClient {}

src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ where
134134
<S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
135135
S::Future: Send,
136136
{
137-
service_fn(move |req| {
137+
service_fn(move |mut req| {
138138
let proxy = proxy.clone();
139139
let mut service = service.clone();
140140

@@ -151,6 +151,7 @@ where
151151
};
152152

153153
tokio::spawn(async move {
154+
let remote_addr: Option<RemoteAddr> = req.extensions_mut().remove();
154155
let client = match hyper::upgrade::on(req).await {
155156
Ok(client) => client,
156157
Err(err) => {
@@ -194,6 +195,9 @@ where
194195
let mut service = service.clone();
195196

196197
async move {
198+
if let Some(remote_addr) = remote_addr {
199+
req.extensions_mut().insert(remote_addr);
200+
}
197201
inject_authority(&mut req, connect_authority.clone());
198202
service.call(req).await
199203
}

0 commit comments

Comments
 (0)