Skip to content

Commit 71f4c9a

Browse files
Fishrock123zkat
andcommitted
feat: h1 connection pooling
This is imported from the orogene project, with massive thanks to Kat Marchán (@zkat) for letting us re-use this work. Specifically: https://github.com/orogene/orogene/tree/82b5c1e6773ceb3ee5a06c8273d7262d6073ce43/crates/oro-client/src/http_client Co-Authored-By: Kat Marchán <[email protected]>
1 parent 06994bb commit 71f4c9a

File tree

4 files changed

+255
-32
lines changed

4 files changed

+255
-32
lines changed

Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ rustdoc-args = ["--cfg", "feature=\"docs\""]
2222
[features]
2323
default = ["h1_client"]
2424
docs = ["h1_client", "curl_client", "wasm_client", "hyper_client"]
25-
h1_client = ["async-h1", "async-std", "async-native-tls"]
26-
h1_client_rustls = ["async-h1", "async-std", "async-tls"]
25+
h1_client = ["async-h1", "async-std", "async-native-tls", "deadpool", "futures"]
26+
h1_client_rustls = ["async-h1", "async-std", "async-tls", "deadpool", "futures"]
2727
native_client = ["curl_client", "wasm_client"]
2828
curl_client = ["isahc", "async-std"]
2929
wasm_client = ["js-sys", "web-sys", "wasm-bindgen", "wasm-bindgen-futures", "futures"]
@@ -38,6 +38,8 @@ log = "0.4.7"
3838
async-h1 = { version = "2.0.0", optional = true }
3939
async-std = { version = "1.6.0", default-features = false, optional = true }
4040
async-native-tls = { version = "0.3.1", optional = true }
41+
deadpool = { version = "0.6.0", optional = true }
42+
futures = { version = "0.3.8", optional = true }
4143

4244
# h1_client_rustls
4345
async-tls = { version = "0.10.0", optional = true }

src/h1.rs renamed to src/h1/mod.rs

Lines changed: 90 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,45 @@
1-
//! http-client implementation for async-h1.
1+
//! http-client implementation for async-h1, with connecton pooling ("Keep-Alive").
22
3-
use super::{async_trait, Error, HttpClient, Request, Response};
3+
use std::collections::HashMap;
4+
use std::net::SocketAddr;
5+
use std::{fmt::Debug, sync::Arc};
46

57
use async_h1::client;
8+
use async_std::net::TcpStream;
9+
use async_std::sync::Mutex;
10+
use deadpool::managed::Pool;
611
use http_types::StatusCode;
712

8-
/// Async-h1 based HTTP Client.
9-
#[derive(Debug)]
13+
#[cfg(not(feature = "h1_client_rustls"))]
14+
use async_native_tls::TlsStream;
15+
#[cfg(feature = "h1_client_rustls")]
16+
use async_tls::client::TlsStream;
17+
18+
use super::{async_trait, Error, HttpClient, Request, Response};
19+
20+
mod tcp;
21+
mod tls;
22+
23+
use tcp::{TcpConnWrapper, TcpConnection};
24+
use tls::{TlsConnWrapper, TlsConnection};
25+
26+
// TODO: Move this to a parameter. This current number is based on a few
27+
// random benchmarks and see whatever gave decent perf vs resource use.
28+
static MAX_CONCURRENT_CONNECTIONS: usize = 50;
29+
30+
type HttpPool = HashMap<SocketAddr, Pool<TcpStream, std::io::Error>>;
31+
type HttpsPool = HashMap<SocketAddr, Pool<TlsStream<TcpStream>, Error>>;
32+
33+
/// Async-h1 based HTTP Client, with connecton pooling ("Keep-Alive").
1034
pub struct H1Client {
11-
_priv: (),
35+
http_pool: Arc<Mutex<HttpPool>>,
36+
https_pool: Arc<Mutex<HttpsPool>>,
37+
}
38+
39+
impl Debug for H1Client {
40+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41+
f.write_str("H1Client")
42+
}
1243
}
1344

1445
impl Default for H1Client {
@@ -20,13 +51,20 @@ impl Default for H1Client {
2051
impl H1Client {
2152
/// Create a new instance.
2253
pub fn new() -> Self {
23-
Self { _priv: () }
54+
Self {
55+
http_pool: Arc::new(Mutex::new(HashMap::new())),
56+
https_pool: Arc::new(Mutex::new(HashMap::new())),
57+
}
2458
}
2559
}
2660

2761
#[async_trait]
2862
impl HttpClient for H1Client {
2963
async fn send(&self, mut req: Request) -> Result<Response, Error> {
64+
let http_pool = self.http_pool.clone();
65+
let https_pool = self.https_pool.clone();
66+
req.insert_header("Connection", "keep-alive");
67+
3068
// Insert host
3169
let host = req
3270
.url()
@@ -57,40 +95,62 @@ impl HttpClient for H1Client {
5795

5896
match scheme {
5997
"http" => {
60-
let stream = async_std::net::TcpStream::connect(addr).await?;
98+
let mut hash = http_pool.lock().await;
99+
let pool = if let Some(pool) = hash.get(&addr) {
100+
pool
101+
} else {
102+
let manager = TcpConnection::new(addr);
103+
let pool =
104+
Pool::<TcpStream, std::io::Error>::new(manager, MAX_CONCURRENT_CONNECTIONS);
105+
hash.insert(addr, pool);
106+
hash.get(&addr).unwrap()
107+
};
108+
let pool = pool.clone();
109+
std::mem::drop(hash);
110+
let stream = pool.get().await?;
61111
req.set_peer_addr(stream.peer_addr().ok());
62112
req.set_local_addr(stream.local_addr().ok());
63-
client::connect(stream, req).await
113+
client::connect(TcpConnWrapper::new(stream), req).await
114+
115+
// let stream = async_std::net::TcpStream::connect(addr).await?;
116+
// req.set_peer_addr(stream.peer_addr().ok());
117+
// req.set_local_addr(stream.local_addr().ok());
118+
// client::connect(stream, req).await
64119
}
65120
"https" => {
66-
let raw_stream = async_std::net::TcpStream::connect(addr).await?;
67-
req.set_peer_addr(raw_stream.peer_addr().ok());
68-
req.set_local_addr(raw_stream.local_addr().ok());
69-
let tls_stream = add_tls(host, raw_stream).await?;
70-
client::connect(tls_stream, req).await
121+
let mut hash = https_pool.lock().await;
122+
let pool = if let Some(pool) = hash.get(&addr) {
123+
pool
124+
} else {
125+
let manager = TlsConnection::new(host.clone(), addr);
126+
let pool = Pool::<TlsStream<TcpStream>, Error>::new(
127+
manager,
128+
MAX_CONCURRENT_CONNECTIONS,
129+
);
130+
hash.insert(addr, pool);
131+
hash.get(&addr).unwrap()
132+
};
133+
let pool = pool.clone();
134+
std::mem::drop(hash);
135+
let stream = pool.get().await.unwrap(); // TODO: remove unwrap
136+
req.set_peer_addr(stream.get_ref().peer_addr().ok());
137+
req.set_local_addr(stream.get_ref().local_addr().ok());
138+
139+
client::connect(TlsConnWrapper::new(stream), req).await
140+
141+
// let raw_stream = async_std::net::TcpStream::connect(addr).await?;
142+
// req.set_peer_addr(raw_stream.peer_addr().ok());
143+
// req.set_local_addr(raw_stream.local_addr().ok());
144+
145+
// let stream = async_native_tls::connect(host, raw_stream).await?;
146+
147+
// client::connect(stream, req).await
71148
}
72149
_ => unreachable!(),
73150
}
74151
}
75152
}
76153

77-
#[cfg(not(feature = "h1_client_rustls"))]
78-
async fn add_tls(
79-
host: String,
80-
stream: async_std::net::TcpStream,
81-
) -> Result<async_native_tls::TlsStream<async_std::net::TcpStream>, async_native_tls::Error> {
82-
async_native_tls::connect(host, stream).await
83-
}
84-
85-
#[cfg(feature = "h1_client_rustls")]
86-
async fn add_tls(
87-
host: String,
88-
stream: async_std::net::TcpStream,
89-
) -> std::io::Result<async_tls::client::TlsStream<async_std::net::TcpStream>> {
90-
let connector = async_tls::TlsConnector::default();
91-
connector.connect(host, stream).await
92-
}
93-
94154
#[cfg(test)]
95155
mod tests {
96156
use super::*;

src/h1/tcp.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use std::fmt::Debug;
2+
use std::net::SocketAddr;
3+
use std::pin::Pin;
4+
5+
use async_std::net::TcpStream;
6+
use async_trait::async_trait;
7+
use deadpool::managed::{Manager, Object, RecycleResult};
8+
use futures::io::{AsyncRead, AsyncWrite};
9+
use futures::task::{Context, Poll};
10+
11+
#[derive(Clone, Debug)]
12+
pub(crate) struct TcpConnection {
13+
addr: SocketAddr,
14+
}
15+
impl TcpConnection {
16+
pub(crate) fn new(addr: SocketAddr) -> Self {
17+
Self { addr }
18+
}
19+
}
20+
21+
pub(crate) struct TcpConnWrapper {
22+
conn: Object<TcpStream, std::io::Error>,
23+
}
24+
impl TcpConnWrapper {
25+
pub(crate) fn new(conn: Object<TcpStream, std::io::Error>) -> Self {
26+
Self { conn }
27+
}
28+
}
29+
30+
impl AsyncRead for TcpConnWrapper {
31+
fn poll_read(
32+
mut self: Pin<&mut Self>,
33+
cx: &mut Context<'_>,
34+
buf: &mut [u8],
35+
) -> Poll<Result<usize, std::io::Error>> {
36+
Pin::new(&mut *self.conn).poll_read(cx, buf)
37+
}
38+
}
39+
40+
impl AsyncWrite for TcpConnWrapper {
41+
fn poll_write(
42+
mut self: Pin<&mut Self>,
43+
cx: &mut Context<'_>,
44+
buf: &[u8],
45+
) -> Poll<std::io::Result<usize>> {
46+
let amt = futures::ready!(Pin::new(&mut *self.conn).poll_write(cx, buf))?;
47+
Poll::Ready(Ok(amt))
48+
}
49+
50+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
51+
Pin::new(&mut *self.conn).poll_flush(cx)
52+
}
53+
54+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
55+
Pin::new(&mut *self.conn).poll_close(cx)
56+
}
57+
}
58+
59+
#[async_trait]
60+
impl Manager<TcpStream, std::io::Error> for TcpConnection {
61+
async fn create(&self) -> Result<TcpStream, std::io::Error> {
62+
Ok(TcpStream::connect(self.addr).await?)
63+
}
64+
65+
async fn recycle(&self, _conn: &mut TcpStream) -> RecycleResult<std::io::Error> {
66+
Ok(())
67+
}
68+
}

src/h1/tls.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use std::fmt::Debug;
2+
use std::net::SocketAddr;
3+
use std::pin::Pin;
4+
5+
use async_std::net::TcpStream;
6+
use async_trait::async_trait;
7+
use deadpool::managed::{Manager, Object, RecycleResult};
8+
use futures::io::{AsyncRead, AsyncWrite};
9+
use futures::task::{Context, Poll};
10+
11+
#[cfg(not(feature = "h1_client_rustls"))]
12+
use async_native_tls::TlsStream;
13+
#[cfg(feature = "h1_client_rustls")]
14+
use async_tls::client::TlsStream;
15+
16+
use crate::Error;
17+
18+
#[derive(Clone, Debug)]
19+
pub(crate) struct TlsConnection {
20+
host: String,
21+
addr: SocketAddr,
22+
}
23+
impl TlsConnection {
24+
pub(crate) fn new(host: String, addr: SocketAddr) -> Self {
25+
Self { host, addr }
26+
}
27+
}
28+
29+
pub(crate) struct TlsConnWrapper {
30+
conn: Object<TlsStream<TcpStream>, Error>,
31+
}
32+
impl TlsConnWrapper {
33+
pub(crate) fn new(conn: Object<TlsStream<TcpStream>, Error>) -> Self {
34+
Self { conn }
35+
}
36+
}
37+
38+
impl AsyncRead for TlsConnWrapper {
39+
fn poll_read(
40+
mut self: Pin<&mut Self>,
41+
cx: &mut Context<'_>,
42+
buf: &mut [u8],
43+
) -> Poll<Result<usize, std::io::Error>> {
44+
Pin::new(&mut *self.conn).poll_read(cx, buf)
45+
}
46+
}
47+
48+
impl AsyncWrite for TlsConnWrapper {
49+
fn poll_write(
50+
mut self: Pin<&mut Self>,
51+
cx: &mut Context<'_>,
52+
buf: &[u8],
53+
) -> Poll<std::io::Result<usize>> {
54+
let amt = futures::ready!(Pin::new(&mut *self.conn).poll_write(cx, buf))?;
55+
Poll::Ready(Ok(amt))
56+
}
57+
58+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
59+
Pin::new(&mut *self.conn).poll_flush(cx)
60+
}
61+
62+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
63+
Pin::new(&mut *self.conn).poll_close(cx)
64+
}
65+
}
66+
67+
#[async_trait]
68+
impl Manager<TlsStream<TcpStream>, Error> for TlsConnection {
69+
async fn create(&self) -> Result<TlsStream<TcpStream>, Error> {
70+
log::trace!("Creating new socket to {:?}", self.addr);
71+
let raw_stream = async_std::net::TcpStream::connect(self.addr).await?;
72+
let tls_stream = add_tls(&self.host, raw_stream).await?;
73+
Ok(tls_stream)
74+
}
75+
76+
async fn recycle(&self, _conn: &mut TlsStream<TcpStream>) -> RecycleResult<Error> {
77+
Ok(())
78+
}
79+
}
80+
81+
#[cfg(not(feature = "h1_client_rustls"))]
82+
async fn add_tls(
83+
host: &str,
84+
stream: TcpStream,
85+
) -> Result<async_native_tls::TlsStream<TcpStream>, async_native_tls::Error> {
86+
async_native_tls::connect(host, stream).await
87+
}
88+
89+
#[cfg(feature = "h1_client_rustls")]
90+
async fn add_tls(host: &str, stream: TcpStream) -> Result<TlsStream<TcpStream>, std::io::Error> {
91+
let connector = async_tls::TlsConnector::default();
92+
connector.connect(host, stream).await
93+
}

0 commit comments

Comments
 (0)