Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 207 additions & 39 deletions src/default_client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
#![cfg(any(feature = "native-tls-client", feature = "rustls-client"))]

use bytes::Bytes;
use http_body_util::Empty;
use bytes::{Buf, Bytes};
use http_body_util::{BodyExt, Empty, combinators::BoxBody};
use hyper::{
Request, Response, StatusCode, Uri, Version,
body::{Body, Incoming},
client, header,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use std::task::{Context, Poll};
use std::{
collections::HashMap,
future::poll_fn,
sync::Arc,
task::{Context, Poll},
};
use tokio::sync::Mutex;
use tokio::{net::TcpStream, task::JoinHandle};

#[cfg(all(feature = "native-tls-client", feature = "rustls-client"))]
Expand Down Expand Up @@ -55,6 +61,96 @@ pub struct Upgraded {
/// A socket to Server
pub server: TokioIo<hyper::upgrade::Upgraded>,
}

type DynError = Box<dyn std::error::Error + Send + Sync>;
type PooledBody = BoxBody<Bytes, DynError>;
type Http1Sender = hyper::client::conn::http1::SendRequest<PooledBody>;
type Http2Sender = hyper::client::conn::http2::SendRequest<PooledBody>;

#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
enum ConnectionProtocol {
Http1,
Http2,
}

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
struct ConnectionKey {
host: String,
port: u16,
is_tls: bool,
protocol: ConnectionProtocol,
}

impl ConnectionKey {
fn new(host: String, port: u16, is_tls: bool, protocol: ConnectionProtocol) -> Self {
Self {
host,
port,
is_tls,
protocol,
}
}

fn from_uri(uri: &Uri, protocol: ConnectionProtocol) -> Result<Self, Error> {
let (host, port, is_tls) = host_port(uri)?;
Ok(ConnectionKey::new(host, port, is_tls, protocol))
}
}

#[derive(Clone, Default)]
struct ConnectionPool {
http1: Arc<Mutex<HashMap<ConnectionKey, Vec<Http1Sender>>>>,
http2: Arc<Mutex<HashMap<ConnectionKey, Http2Sender>>>,
}

impl ConnectionPool {
async fn take_http1(&self, key: &ConnectionKey) -> Option<Http1Sender> {
let mut guard = self.http1.lock().await;
let entry = guard.get_mut(key)?;
while let Some(mut conn) = entry.pop() {
if sender_alive_http1(&mut conn).await {
return Some(conn);
}
}
if entry.is_empty() {
guard.remove(key);
}
None
}

async fn put_http1(&self, key: ConnectionKey, sender: Http1Sender) {
let mut guard = self.http1.lock().await;
guard.entry(key).or_default().push(sender);
}

async fn get_http2(&self, key: &ConnectionKey) -> Option<Http2Sender> {
let mut guard = self.http2.lock().await;
let mut sender = guard.get(key).cloned()?;

let alive = sender_alive_http2(&mut sender).await;

if alive {
Some(sender)
} else {
guard.remove(key);
None
}
}

async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Http2Sender) {
let mut guard = self.http2.lock().await;
guard.entry(key).or_insert(sender);
}
}

async fn sender_alive_http1(sender: &mut Http1Sender) -> bool {
poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
}

async fn sender_alive_http2(sender: &mut Http2Sender) -> bool {
poll_fn(|cx| sender.poll_ready(cx)).await.is_ok()
}

#[derive(Clone)]
/// Default HTTP client for this crate
pub struct DefaultClient {
Expand All @@ -71,6 +167,8 @@ pub struct DefaultClient {
/// If true, send_request will returns an Upgraded struct when the response is an upgrade
/// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade
pub with_upgrades: bool,

pool: ConnectionPool,
}
impl Default for DefaultClient {
fn default() -> Self {
Expand Down Expand Up @@ -102,6 +200,7 @@ impl DefaultClient {
tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
with_upgrades: false,
pool: ConnectionPool::default(),
})
}

Expand Down Expand Up @@ -135,6 +234,7 @@ impl DefaultClient {
tls_connector_alpn_h2,
)),
with_upgrades: false,
pool: ConnectionPool::default(),
})
}

Expand Down Expand Up @@ -175,17 +275,54 @@ impl DefaultClient {
Error,
>
where
B: Body + Unpin + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
B: Body<Data = Bytes> + Send + Sync + 'static,
B::Data: Send + Buf,
B::Error: Into<DynError>,
{
let mut send_request = self.connect(req.uri(), req.version()).await?;
let target_uri = req.uri().clone();
let mut send_request = if req.version() == Version::HTTP_2 {
match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2) {
Ok(pool_key) => {
if let Some(conn) = self.pool.get_http2(&pool_key).await {
SendRequest::Http2(conn)
} else {
self.connect(req.uri(), req.version(), Some(pool_key))
.await?
}
}
Err(err) => {
tracing::warn!(
"ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool",
err
);
self.connect(req.uri(), req.version(), None).await?
}
}
} else {
match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) {
Ok(pool_key) => {
if let Some(conn) = self.pool.take_http1(&pool_key).await {
SendRequest::Http1(conn)
} else {
self.connect(req.uri(), req.version(), Some(pool_key))
.await?
}
}
Err(err) => {
tracing::warn!(
"ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool",
err
);
self.connect(req.uri(), req.version(), None).await?
}
}
};

let (req_parts, req_body) = req.into_parts();

let res = send_request
.send_request(Request::from_parts(req_parts.clone(), req_body))
.await?;
let boxed_req = Request::from_parts(req_parts.clone(), to_boxed_body(req_body));

let res = send_request.send_request(boxed_req).await?;

if res.status() == StatusCode::SWITCHING_PROTOCOLS {
let (res_parts, res_body) = res.into_parts();
Expand Down Expand Up @@ -221,36 +358,41 @@ impl DefaultClient {

Ok((Response::from_parts(res_parts, res_body), upgrade))
} else {
match send_request {
SendRequest::Http1(sender) => {
if let Ok(pool_key) =
ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)
{
self.pool.put_http1(pool_key, sender).await;
} else {
// If we couldn't build a pool key, skip pooling.
}
}
SendRequest::Http2(_) => {
// For HTTP/2 the pool retains a shared sender; no action needed.
}
}
Ok((res, None))
}
}

async fn connect<B>(&self, uri: &Uri, http_version: Version) -> Result<SendRequest<B>, Error>
where
B: Body + Unpin + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let host = uri
.host()
.ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?;
let port =
uri.port_u16()
.unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
443
} else {
80
});
async fn connect(
&self,
uri: &Uri,
http_version: Version,
key: Option<ConnectionKey>,
) -> Result<SendRequest, Error> {
let (host, port, is_tls) = host_port(uri)?;

let tcp = TcpStream::connect((host, port)).await?;
let tcp = TcpStream::connect((host.as_str(), port)).await?;
// This is actually needed to some servers
let _ = tcp.set_nodelay(true);

if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
if is_tls {
#[cfg(feature = "native-tls-client")]
let tls = self
.tls_connector(http_version)
.connect(host, tcp)
.connect(&host, tcp)
.await
.map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
#[cfg(feature = "rustls-client")]
Expand Down Expand Up @@ -284,6 +426,14 @@ impl DefaultClient {

tokio::spawn(conn);

if let Some(ref k) = key
&& matches!(k.protocol, ConnectionProtocol::Http2)
{
self.pool
.insert_http2_if_absent(k.clone(), sender.clone())
.await;
}

Ok(SendRequest::Http2(sender))
} else {
let (sender, conn) = client::conn::http1::Builder::new()
Expand All @@ -310,18 +460,15 @@ impl DefaultClient {
}
}

enum SendRequest<B> {
Http1(hyper::client::conn::http1::SendRequest<B>),
Http2(hyper::client::conn::http2::SendRequest<B>),
enum SendRequest {
Http1(Http1Sender),
Http2(Http2Sender),
}

impl<B> SendRequest<B>
where
B: Body + 'static,
{
impl SendRequest {
async fn send_request(
&mut self,
mut req: Request<B>,
mut req: Request<PooledBody>,
) -> Result<Response<Incoming>, hyper::Error> {
match self {
SendRequest::Http1(sender) => {
Expand Down Expand Up @@ -357,13 +504,13 @@ where
}
}

impl<B> SendRequest<B> {
impl SendRequest {
#[allow(dead_code)]
// TODO: connection pooling
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), hyper::Error>> {
match self {
SendRequest::Http1(sender) => sender.poll_ready(cx),
SendRequest::Http2(sender) => sender.poll_ready(cx),
SendRequest::Http2(_sender) => Poll::Ready(Ok(())),
}
}
}
Expand All @@ -375,3 +522,24 @@ fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::Inv
*req.uri_mut() = Uri::from_parts(parts)?;
Ok(())
}

fn to_boxed_body<B>(body: B) -> PooledBody
where
B: Body<Data = Bytes> + Send + Sync + 'static,
B::Data: Send + Buf,
B::Error: Into<DynError>,
{
body.map_err(|err| err.into()).boxed()
}

fn host_port(uri: &Uri) -> Result<(String, u16, bool), Error> {
let host = uri
.host()
.ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?
.to_string();
let is_tls = uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS);
let port = uri.port_u16().unwrap_or(if is_tls { 443 } else { 80 });
Ok((host, port, is_tls))
}

impl DefaultClient {}
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ where
<S::ResBody as Body>::Error: Into<Box<dyn StdError + Send + Sync>>,
S::Future: Send,
{
service_fn(move |req| {
service_fn(move |mut req| {
let proxy = proxy.clone();
let mut service = service.clone();

Expand All @@ -151,6 +151,7 @@ where
};

tokio::spawn(async move {
let remote_addr: Option<RemoteAddr> = req.extensions_mut().remove();
let client = match hyper::upgrade::on(req).await {
Ok(client) => client,
Err(err) => {
Expand Down Expand Up @@ -194,6 +195,9 @@ where
let mut service = service.clone();

async move {
if let Some(remote_addr) = remote_addr {
req.extensions_mut().insert(remote_addr);
}
inject_authority(&mut req, connect_authority.clone());
service.call(req).await
}
Expand Down
Loading