diff --git a/src/default_client.rs b/src/default_client.rs index 0255ac2..bbe2995 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -1,14 +1,20 @@ #![cfg(any(feature = "native-tls-client", feature = "rustls-client"))] -use bytes::Bytes; -use http_body_util::Empty; +use bytes::{Buf, Bytes}; +use http_body_util::{BodyExt, Empty, combinators::BoxBody}; use hyper::{ Request, Response, StatusCode, Uri, Version, body::{Body, Incoming}, client, header, }; use hyper_util::rt::{TokioExecutor, TokioIo}; -use std::task::{Context, Poll}; +use std::{ + collections::HashMap, + future::poll_fn, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::sync::Mutex; use tokio::{net::TcpStream, task::JoinHandle}; #[cfg(all(feature = "native-tls-client", feature = "rustls-client"))] @@ -55,6 +61,96 @@ pub struct Upgraded { /// A socket to Server pub server: TokioIo, } + +type DynError = Box; +type PooledBody = BoxBody; +type Http1Sender = hyper::client::conn::http1::SendRequest; +type Http2Sender = hyper::client::conn::http2::SendRequest; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +enum ConnectionProtocol { + Http1, + Http2, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct ConnectionKey { + host: String, + port: u16, + is_tls: bool, + protocol: ConnectionProtocol, +} + +impl ConnectionKey { + fn new(host: String, port: u16, is_tls: bool, protocol: ConnectionProtocol) -> Self { + Self { + host, + port, + is_tls, + protocol, + } + } + + fn from_uri(uri: &Uri, protocol: ConnectionProtocol) -> Result { + let (host, port, is_tls) = host_port(uri)?; + Ok(ConnectionKey::new(host, port, is_tls, protocol)) + } +} + +#[derive(Clone, Default)] +struct ConnectionPool { + http1: Arc>>>, + http2: Arc>>, +} + +impl ConnectionPool { + async fn take_http1(&self, key: &ConnectionKey) -> Option { + let mut guard = self.http1.lock().await; + let entry = guard.get_mut(key)?; + while let Some(mut conn) = entry.pop() { + if sender_alive_http1(&mut conn).await { + return Some(conn); + } + } + if entry.is_empty() { + guard.remove(key); + } + None + } + + async fn put_http1(&self, key: ConnectionKey, sender: Http1Sender) { + let mut guard = self.http1.lock().await; + guard.entry(key).or_default().push(sender); + } + + async fn get_http2(&self, key: &ConnectionKey) -> Option { + let mut guard = self.http2.lock().await; + let mut sender = guard.get(key).cloned()?; + + let alive = sender_alive_http2(&mut sender).await; + + if alive { + Some(sender) + } else { + guard.remove(key); + None + } + } + + async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Http2Sender) { + let mut guard = self.http2.lock().await; + guard.entry(key).or_insert(sender); + } +} + +async fn sender_alive_http1(sender: &mut Http1Sender) -> bool { + poll_fn(|cx| sender.poll_ready(cx)).await.is_ok() +} + +async fn sender_alive_http2(sender: &mut Http2Sender) -> bool { + poll_fn(|cx| sender.poll_ready(cx)).await.is_ok() +} + #[derive(Clone)] /// Default HTTP client for this crate pub struct DefaultClient { @@ -71,6 +167,8 @@ pub struct DefaultClient { /// If true, send_request will returns an Upgraded struct when the response is an upgrade /// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade pub with_upgrades: bool, + + pool: ConnectionPool, } impl Default for DefaultClient { fn default() -> Self { @@ -102,6 +200,7 @@ impl DefaultClient { tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn), tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2), with_upgrades: false, + pool: ConnectionPool::default(), }) } @@ -135,6 +234,7 @@ impl DefaultClient { tls_connector_alpn_h2, )), with_upgrades: false, + pool: ConnectionPool::default(), }) } @@ -175,17 +275,54 @@ impl DefaultClient { Error, > where - B: Body + Unpin + Send + 'static, - B::Data: Send, - B::Error: Into>, + B: Body + Send + Sync + 'static, + B::Data: Send + Buf, + B::Error: Into, { - let mut send_request = self.connect(req.uri(), req.version()).await?; + let target_uri = req.uri().clone(); + let mut send_request = if req.version() == Version::HTTP_2 { + match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2) { + Ok(pool_key) => { + if let Some(conn) = self.pool.get_http2(&pool_key).await { + SendRequest::Http2(conn) + } else { + self.connect(req.uri(), req.version(), Some(pool_key)) + .await? + } + } + Err(err) => { + tracing::warn!( + "ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool", + err + ); + self.connect(req.uri(), req.version(), None).await? + } + } + } else { + match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) { + Ok(pool_key) => { + if let Some(conn) = self.pool.take_http1(&pool_key).await { + SendRequest::Http1(conn) + } else { + self.connect(req.uri(), req.version(), Some(pool_key)) + .await? + } + } + Err(err) => { + tracing::warn!( + "ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool", + err + ); + self.connect(req.uri(), req.version(), None).await? + } + } + }; let (req_parts, req_body) = req.into_parts(); - let res = send_request - .send_request(Request::from_parts(req_parts.clone(), req_body)) - .await?; + let boxed_req = Request::from_parts(req_parts.clone(), to_boxed_body(req_body)); + + let res = send_request.send_request(boxed_req).await?; if res.status() == StatusCode::SWITCHING_PROTOCOLS { let (res_parts, res_body) = res.into_parts(); @@ -221,36 +358,41 @@ impl DefaultClient { Ok((Response::from_parts(res_parts, res_body), upgrade)) } else { + match send_request { + SendRequest::Http1(sender) => { + if let Ok(pool_key) = + ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) + { + self.pool.put_http1(pool_key, sender).await; + } else { + // If we couldn't build a pool key, skip pooling. + } + } + SendRequest::Http2(_) => { + // For HTTP/2 the pool retains a shared sender; no action needed. + } + } Ok((res, None)) } } - async fn connect(&self, uri: &Uri, http_version: Version) -> Result, Error> - where - B: Body + Unpin + Send + 'static, - B::Data: Send, - B::Error: Into>, - { - let host = uri - .host() - .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?; - let port = - uri.port_u16() - .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) { - 443 - } else { - 80 - }); + async fn connect( + &self, + uri: &Uri, + http_version: Version, + key: Option, + ) -> Result { + let (host, port, is_tls) = host_port(uri)?; - let tcp = TcpStream::connect((host, port)).await?; + let tcp = TcpStream::connect((host.as_str(), port)).await?; // This is actually needed to some servers let _ = tcp.set_nodelay(true); - if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) { + if is_tls { #[cfg(feature = "native-tls-client")] let tls = self .tls_connector(http_version) - .connect(host, tcp) + .connect(&host, tcp) .await .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?; #[cfg(feature = "rustls-client")] @@ -284,6 +426,14 @@ impl DefaultClient { tokio::spawn(conn); + if let Some(ref k) = key + && matches!(k.protocol, ConnectionProtocol::Http2) + { + self.pool + .insert_http2_if_absent(k.clone(), sender.clone()) + .await; + } + Ok(SendRequest::Http2(sender)) } else { let (sender, conn) = client::conn::http1::Builder::new() @@ -310,18 +460,15 @@ impl DefaultClient { } } -enum SendRequest { - Http1(hyper::client::conn::http1::SendRequest), - Http2(hyper::client::conn::http2::SendRequest), +enum SendRequest { + Http1(Http1Sender), + Http2(Http2Sender), } -impl SendRequest -where - B: Body + 'static, -{ +impl SendRequest { async fn send_request( &mut self, - mut req: Request, + mut req: Request, ) -> Result, hyper::Error> { match self { SendRequest::Http1(sender) => { @@ -357,13 +504,13 @@ where } } -impl SendRequest { +impl SendRequest { #[allow(dead_code)] // TODO: connection pooling fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match self { SendRequest::Http1(sender) => sender.poll_ready(cx), - SendRequest::Http2(sender) => sender.poll_ready(cx), + SendRequest::Http2(_sender) => Poll::Ready(Ok(())), } } } @@ -375,3 +522,24 @@ fn remove_authority(req: &mut Request) -> Result<(), hyper::http::uri::Inv *req.uri_mut() = Uri::from_parts(parts)?; Ok(()) } + +fn to_boxed_body(body: B) -> PooledBody +where + B: Body + Send + Sync + 'static, + B::Data: Send + Buf, + B::Error: Into, +{ + body.map_err(|err| err.into()).boxed() +} + +fn host_port(uri: &Uri) -> Result<(String, u16, bool), Error> { + let host = uri + .host() + .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))? + .to_string(); + let is_tls = uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS); + let port = uri.port_u16().unwrap_or(if is_tls { 443 } else { 80 }); + Ok((host, port, is_tls)) +} + +impl DefaultClient {} diff --git a/src/lib.rs b/src/lib.rs index 33c1538..211a877 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ where ::Error: Into>, S::Future: Send, { - service_fn(move |req| { + service_fn(move |mut req| { let proxy = proxy.clone(); let mut service = service.clone(); @@ -151,6 +151,7 @@ where }; tokio::spawn(async move { + let remote_addr: Option = req.extensions_mut().remove(); let client = match hyper::upgrade::on(req).await { Ok(client) => client, Err(err) => { @@ -194,6 +195,9 @@ where let mut service = service.clone(); async move { + if let Some(remote_addr) = remote_addr { + req.extensions_mut().insert(remote_addr); + } inject_authority(&mut req, connect_authority.clone()); service.call(req).await } diff --git a/tests/test.rs b/tests/test.rs index bf35b5a..e708e8e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -11,7 +11,7 @@ use axum::{ }; use bytes::Bytes; use futures::stream; -use http_mitm_proxy::{DefaultClient, MitmProxy}; +use http_mitm_proxy::{DefaultClient, MitmProxy, RemoteAddr}; use hyper::{ Uri, body::{Body, Incoming}, @@ -137,7 +137,13 @@ async fn test_simple_http() { app, service_fn(move |req| { let proxy_client = proxy_client.clone(); - async move { proxy_client.send_request(req).await.map(|t| t.0) } + async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); + proxy_client.send_request(req).await.map(|t| t.0) + } }), ) .await; @@ -173,6 +179,10 @@ async fn test_modify_http() { service_fn(move |mut req| { let proxy_client = proxy_client.clone(); async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); req.headers_mut() .insert("X-test", "modified".parse().unwrap()); proxy_client.send_request(req).await.map(|t| t.0) @@ -209,7 +219,13 @@ async fn test_sse_http() { app, service_fn(move |req| { let proxy_client = proxy_client.clone(); - async move { proxy_client.send_request(req).await.map(|t| t.0) } + async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); + proxy_client.send_request(req).await.map(|t| t.0) + } }), ) .await; @@ -258,6 +274,10 @@ async fn test_simple_https() { service_fn(move |mut req| { let proxy_client = proxy_client.clone(); async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); let mut parts = req.uri().clone().into_parts(); parts.scheme = Some(hyper::http::uri::Scheme::HTTP);