From 78c49ba5559d63f2d4c344d444ba7a609134c5c0 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 24 Jan 2026 16:19:57 +0900 Subject: [PATCH 1/6] conn pol --- src/default_client.rs | 281 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 240 insertions(+), 41 deletions(-) diff --git a/src/default_client.rs b/src/default_client.rs index 0255ac2..9072010 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -1,14 +1,20 @@ #![cfg(any(feature = "native-tls-client", feature = "rustls-client"))] -use bytes::Bytes; -use http_body_util::Empty; +use bytes::{Buf, Bytes}; +use http_body_util::{BodyExt, Empty, combinators::BoxBody}; use hyper::{ Request, Response, StatusCode, Uri, Version, body::{Body, Incoming}, client, header, }; use hyper_util::rt::{TokioExecutor, TokioIo}; -use std::task::{Context, Poll}; +use std::{ + collections::HashMap, + future::poll_fn, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::sync::Mutex; use tokio::{net::TcpStream, task::JoinHandle}; #[cfg(all(feature = "native-tls-client", feature = "rustls-client"))] @@ -55,6 +61,141 @@ pub struct Upgraded { /// A socket to Server pub server: TokioIo, } + +type DynError = Box; +type PooledBody = BoxBody; +type Http1Sender = hyper::client::conn::http1::SendRequest; +type Http2Sender = hyper::client::conn::http2::SendRequest; + +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +enum ConnectionProtocol { + Http1, + Http2, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct ConnectionKey { + host: String, + port: u16, + is_tls: bool, + protocol: ConnectionProtocol, +} + +impl ConnectionKey { + fn new(host: String, port: u16, is_tls: bool, protocol: ConnectionProtocol) -> Self { + Self { + host, + port, + is_tls, + protocol, + } + } + + fn from_uri(uri: &Uri, protocol: ConnectionProtocol) -> Result { + let (host, port, is_tls) = host_port(uri)?; + Ok(ConnectionKey::new(host, port, is_tls, protocol)) + } +} + +#[derive(Clone, Default)] +struct ConnectionPool { + http1: Arc>>>, + http2: Arc>>>>, +} + +impl ConnectionPool { + async fn take_http1(&self, key: &ConnectionKey) -> Option { + let mut guard = self.http1.lock().await; + let entry = guard.get_mut(key)?; + while let Some(mut conn) = entry.pop() { + if sender_alive_http1(&mut conn).await { + return Some(conn); + } + } + if entry.is_empty() { + guard.remove(key); + } + None + } + + async fn put_http1(&self, key: ConnectionKey, sender: Http1Sender) { + let mut guard = self.http1.lock().await; + guard.entry(key).or_default().push(sender); + } + + async fn get_http2(&self, key: &ConnectionKey) -> Option>> { + let maybe = { + let guard = self.http2.lock().await; + guard.get(key).cloned() + }; + + if let Some(sender) = maybe { + let alive = { + let mut guard = sender.lock().await; + sender_alive_http2(&mut guard).await + }; + + if alive { + Some(sender) + } else { + let mut map = self.http2.lock().await; + map.remove(key); + None + } + } else { + None + } + } + + async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Arc>) { + let mut guard = self.http2.lock().await; + guard.entry(key).or_insert(sender); + } + + async fn has_http1(&self, key: &ConnectionKey) -> bool { + let mut guard = self.http1.lock().await; + if let Some(vec) = guard.get_mut(key) { + while let Some(mut conn) = vec.pop() { + if sender_alive_http1(&mut conn).await { + vec.push(conn); + return true; + } + } + guard.remove(key); + } + false + } + + async fn has_http2(&self, key: &ConnectionKey) -> bool { + let maybe = { + let guard = self.http2.lock().await; + guard.get(key).cloned() + }; + + if let Some(sender) = maybe { + let alive = { + let mut guard = sender.lock().await; + sender_alive_http2(&mut guard).await + }; + if !alive { + let mut map = self.http2.lock().await; + map.remove(key); + } + alive + } else { + false + } + } +} + +async fn sender_alive_http1(sender: &mut Http1Sender) -> bool { + poll_fn(|cx| sender.poll_ready(cx)).await.is_ok() +} + +async fn sender_alive_http2(sender: &mut Http2Sender) -> bool { + poll_fn(|cx| sender.poll_ready(cx)).await.is_ok() +} + #[derive(Clone)] /// Default HTTP client for this crate pub struct DefaultClient { @@ -71,6 +212,8 @@ pub struct DefaultClient { /// If true, send_request will returns an Upgraded struct when the response is an upgrade /// If false, send_request never returns an Upgraded struct and just copy bidirectional when the response is an upgrade pub with_upgrades: bool, + + pool: ConnectionPool, } impl Default for DefaultClient { fn default() -> Self { @@ -102,6 +245,7 @@ impl DefaultClient { tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn), tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2), with_upgrades: false, + pool: ConnectionPool::default(), }) } @@ -135,6 +279,7 @@ impl DefaultClient { tls_connector_alpn_h2, )), with_upgrades: false, + pool: ConnectionPool::default(), }) } @@ -145,6 +290,21 @@ impl DefaultClient { self } + /// Check if a cached connection exists for the given URI and HTTP version. + pub async fn has_cached_connection(&self, uri: &Uri, version: Version) -> Result { + let protocol = if version == Version::HTTP_2 { + ConnectionProtocol::Http2 + } else { + ConnectionProtocol::Http1 + }; + let key = ConnectionKey::from_uri(uri, protocol)?; + let available = match protocol { + ConnectionProtocol::Http1 => self.pool.has_http1(&key).await, + ConnectionProtocol::Http2 => self.pool.has_http2(&key).await, + }; + Ok(available) + } + #[cfg(feature = "native-tls-client")] fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector { match http_version { @@ -175,17 +335,32 @@ impl DefaultClient { Error, > where - B: Body + Unpin + Send + 'static, - B::Data: Send, - B::Error: Into>, + B: Body + Send + Sync + 'static, + B::Data: Send + Buf, + B::Error: Into, { - let mut send_request = self.connect(req.uri(), req.version()).await?; + let target_uri = req.uri().clone(); + let mut send_request = if req.version() == Version::HTTP_2 { + let pool_key = ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2)?; + if let Some(conn) = self.pool.get_http2(&pool_key).await { + SendRequest::Http2(conn) + } else { + self.connect(req.uri(), req.version(), pool_key).await? + } + } else { + let pool_key = ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)?; + if let Some(conn) = self.pool.take_http1(&pool_key).await { + SendRequest::Http1(conn) + } else { + self.connect(req.uri(), req.version(), pool_key).await? + } + }; let (req_parts, req_body) = req.into_parts(); - let res = send_request - .send_request(Request::from_parts(req_parts.clone(), req_body)) - .await?; + let boxed_req = Request::from_parts(req_parts.clone(), to_boxed_body(req_body)); + + let res = send_request.send_request(boxed_req).await?; if res.status() == StatusCode::SWITCHING_PROTOCOLS { let (res_parts, res_body) = res.into_parts(); @@ -221,36 +396,36 @@ impl DefaultClient { Ok((Response::from_parts(res_parts, res_body), upgrade)) } else { + match send_request { + SendRequest::Http1(sender) => { + let pool_key = ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)?; + self.pool.put_http1(pool_key, sender).await; + } + SendRequest::Http2(_) => { + // For HTTP/2 the pool retains a shared sender; no action needed. + } + } Ok((res, None)) } } - async fn connect(&self, uri: &Uri, http_version: Version) -> Result, Error> - where - B: Body + Unpin + Send + 'static, - B::Data: Send, - B::Error: Into>, - { - let host = uri - .host() - .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?; - let port = - uri.port_u16() - .unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) { - 443 - } else { - 80 - }); + async fn connect( + &self, + uri: &Uri, + http_version: Version, + key: ConnectionKey, + ) -> Result { + let (host, port, is_tls) = host_port(uri)?; - let tcp = TcpStream::connect((host, port)).await?; + let tcp = TcpStream::connect((host.as_str(), port)).await?; // This is actually needed to some servers let _ = tcp.set_nodelay(true); - if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) { + if is_tls { #[cfg(feature = "native-tls-client")] let tls = self .tls_connector(http_version) - .connect(host, tcp) + .connect(&host, tcp) .await .map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?; #[cfg(feature = "rustls-client")] @@ -284,7 +459,12 @@ impl DefaultClient { tokio::spawn(conn); - Ok(SendRequest::Http2(sender)) + let shared = Arc::new(Mutex::new(sender)); + if matches!(key.protocol, ConnectionProtocol::Http2) { + self.pool.insert_http2_if_absent(key, shared.clone()).await; + } + + Ok(SendRequest::Http2(shared)) } else { let (sender, conn) = client::conn::http1::Builder::new() .preserve_header_case(true) @@ -310,18 +490,15 @@ impl DefaultClient { } } -enum SendRequest { - Http1(hyper::client::conn::http1::SendRequest), - Http2(hyper::client::conn::http2::SendRequest), +enum SendRequest { + Http1(Http1Sender), + Http2(Arc>), } -impl SendRequest -where - B: Body + 'static, -{ +impl SendRequest { async fn send_request( &mut self, - mut req: Request, + mut req: Request, ) -> Result, hyper::Error> { match self { SendRequest::Http1(sender) => { @@ -351,19 +528,20 @@ where if req.version() != hyper::Version::HTTP_2 { req.headers_mut().remove(header::HOST); } - sender.send_request(req).await + let mut guard = sender.lock().await; + guard.send_request(req).await } } } } -impl SendRequest { +impl SendRequest { #[allow(dead_code)] // TODO: connection pooling fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { match self { SendRequest::Http1(sender) => sender.poll_ready(cx), - SendRequest::Http2(sender) => sender.poll_ready(cx), + SendRequest::Http2(_sender) => Poll::Ready(Ok(())), } } } @@ -375,3 +553,24 @@ fn remove_authority(req: &mut Request) -> Result<(), hyper::http::uri::Inv *req.uri_mut() = Uri::from_parts(parts)?; Ok(()) } + +fn to_boxed_body(body: B) -> PooledBody +where + B: Body + Send + Sync + 'static, + B::Data: Send + Buf, + B::Error: Into, +{ + body.map_err(|err| err.into()).boxed() +} + +fn host_port(uri: &Uri) -> Result<(String, u16, bool), Error> { + let host = uri + .host() + .ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))? + .to_string(); + let is_tls = uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS); + let port = uri.port_u16().unwrap_or(if is_tls { 443 } else { 80 }); + Ok((host, port, is_tls)) +} + +impl DefaultClient {} From 7c8628e27690309521565b315a63225012c44faf Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 24 Jan 2026 16:27:44 +0900 Subject: [PATCH 2/6] refactor --- src/default_client.rs | 88 +++++++------------------------------------ 1 file changed, 13 insertions(+), 75 deletions(-) diff --git a/src/default_client.rs b/src/default_client.rs index 9072010..43c728a 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -100,7 +100,7 @@ impl ConnectionKey { #[derive(Clone, Default)] struct ConnectionPool { http1: Arc>>>, - http2: Arc>>>>, + http2: Arc>>, } impl ConnectionPool { @@ -123,69 +123,24 @@ impl ConnectionPool { guard.entry(key).or_default().push(sender); } - async fn get_http2(&self, key: &ConnectionKey) -> Option>> { - let maybe = { - let guard = self.http2.lock().await; - guard.get(key).cloned() - }; + async fn get_http2(&self, key: &ConnectionKey) -> Option { + let mut guard = self.http2.lock().await; + let mut sender = guard.get(key).cloned()?; - if let Some(sender) = maybe { - let alive = { - let mut guard = sender.lock().await; - sender_alive_http2(&mut guard).await - }; + let alive = sender_alive_http2(&mut sender).await; - if alive { - Some(sender) - } else { - let mut map = self.http2.lock().await; - map.remove(key); - None - } + if alive { + Some(sender) } else { + guard.remove(key); None } } - async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Arc>) { + async fn insert_http2_if_absent(&self, key: ConnectionKey, sender: Http2Sender) { let mut guard = self.http2.lock().await; guard.entry(key).or_insert(sender); } - - async fn has_http1(&self, key: &ConnectionKey) -> bool { - let mut guard = self.http1.lock().await; - if let Some(vec) = guard.get_mut(key) { - while let Some(mut conn) = vec.pop() { - if sender_alive_http1(&mut conn).await { - vec.push(conn); - return true; - } - } - guard.remove(key); - } - false - } - - async fn has_http2(&self, key: &ConnectionKey) -> bool { - let maybe = { - let guard = self.http2.lock().await; - guard.get(key).cloned() - }; - - if let Some(sender) = maybe { - let alive = { - let mut guard = sender.lock().await; - sender_alive_http2(&mut guard).await - }; - if !alive { - let mut map = self.http2.lock().await; - map.remove(key); - } - alive - } else { - false - } - } } async fn sender_alive_http1(sender: &mut Http1Sender) -> bool { @@ -290,21 +245,6 @@ impl DefaultClient { self } - /// Check if a cached connection exists for the given URI and HTTP version. - pub async fn has_cached_connection(&self, uri: &Uri, version: Version) -> Result { - let protocol = if version == Version::HTTP_2 { - ConnectionProtocol::Http2 - } else { - ConnectionProtocol::Http1 - }; - let key = ConnectionKey::from_uri(uri, protocol)?; - let available = match protocol { - ConnectionProtocol::Http1 => self.pool.has_http1(&key).await, - ConnectionProtocol::Http2 => self.pool.has_http2(&key).await, - }; - Ok(available) - } - #[cfg(feature = "native-tls-client")] fn tls_connector(&self, http_version: Version) -> &tokio_native_tls::TlsConnector { match http_version { @@ -459,12 +399,11 @@ impl DefaultClient { tokio::spawn(conn); - let shared = Arc::new(Mutex::new(sender)); if matches!(key.protocol, ConnectionProtocol::Http2) { - self.pool.insert_http2_if_absent(key, shared.clone()).await; + self.pool.insert_http2_if_absent(key, sender.clone()).await; } - Ok(SendRequest::Http2(shared)) + Ok(SendRequest::Http2(sender)) } else { let (sender, conn) = client::conn::http1::Builder::new() .preserve_header_case(true) @@ -492,7 +431,7 @@ impl DefaultClient { enum SendRequest { Http1(Http1Sender), - Http2(Arc>), + Http2(Http2Sender), } impl SendRequest { @@ -528,8 +467,7 @@ impl SendRequest { if req.version() != hyper::Version::HTTP_2 { req.headers_mut().remove(header::HOST); } - let mut guard = sender.lock().await; - guard.send_request(req).await + sender.send_request(req).await } } } From 3665f7d538d5deffb657e28dade54c98a3f6eaa5 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 24 Jan 2026 16:40:16 +0900 Subject: [PATCH 3/6] fix --- src/lib.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib.rs b/src/lib.rs index 33c1538..211a877 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -134,7 +134,7 @@ where ::Error: Into>, S::Future: Send, { - service_fn(move |req| { + service_fn(move |mut req| { let proxy = proxy.clone(); let mut service = service.clone(); @@ -151,6 +151,7 @@ where }; tokio::spawn(async move { + let remote_addr: Option = req.extensions_mut().remove(); let client = match hyper::upgrade::on(req).await { Ok(client) => client, Err(err) => { @@ -194,6 +195,9 @@ where let mut service = service.clone(); async move { + if let Some(remote_addr) = remote_addr { + req.extensions_mut().insert(remote_addr); + } inject_authority(&mut req, connect_authority.clone()); service.call(req).await } From 9be4e4368651e37499b6f32c767d5c20813c6af7 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 24 Jan 2026 16:44:45 +0900 Subject: [PATCH 4/6] update --- src/default_client.rs | 61 ++++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/src/default_client.rs b/src/default_client.rs index 43c728a..40b357f 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -281,18 +281,40 @@ impl DefaultClient { { let target_uri = req.uri().clone(); let mut send_request = if req.version() == Version::HTTP_2 { - let pool_key = ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2)?; - if let Some(conn) = self.pool.get_http2(&pool_key).await { - SendRequest::Http2(conn) - } else { - self.connect(req.uri(), req.version(), pool_key).await? + match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http2) { + Ok(pool_key) => { + if let Some(conn) = self.pool.get_http2(&pool_key).await { + SendRequest::Http2(conn) + } else { + self.connect(req.uri(), req.version(), Some(pool_key)) + .await? + } + } + Err(err) => { + tracing::warn!( + "ConnectionKey::from_uri failed for HTTP/2 ({}): continuing without pool", + err + ); + self.connect(req.uri(), req.version(), None).await? + } } } else { - let pool_key = ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)?; - if let Some(conn) = self.pool.take_http1(&pool_key).await { - SendRequest::Http1(conn) - } else { - self.connect(req.uri(), req.version(), pool_key).await? + match ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) { + Ok(pool_key) => { + if let Some(conn) = self.pool.take_http1(&pool_key).await { + SendRequest::Http1(conn) + } else { + self.connect(req.uri(), req.version(), Some(pool_key)) + .await? + } + } + Err(err) => { + tracing::warn!( + "ConnectionKey::from_uri failed for HTTP/1 ({}): continuing without pool", + err + ); + self.connect(req.uri(), req.version(), None).await? + } } }; @@ -338,8 +360,13 @@ impl DefaultClient { } else { match send_request { SendRequest::Http1(sender) => { - let pool_key = ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1)?; - self.pool.put_http1(pool_key, sender).await; + if let Ok(pool_key) = + ConnectionKey::from_uri(&target_uri, ConnectionProtocol::Http1) + { + self.pool.put_http1(pool_key, sender).await; + } else { + // If we couldn't build a pool key, skip pooling. + } } SendRequest::Http2(_) => { // For HTTP/2 the pool retains a shared sender; no action needed. @@ -353,7 +380,7 @@ impl DefaultClient { &self, uri: &Uri, http_version: Version, - key: ConnectionKey, + key: Option, ) -> Result { let (host, port, is_tls) = host_port(uri)?; @@ -399,8 +426,12 @@ impl DefaultClient { tokio::spawn(conn); - if matches!(key.protocol, ConnectionProtocol::Http2) { - self.pool.insert_http2_if_absent(key, sender.clone()).await; + if let Some(ref k) = key { + if matches!(k.protocol, ConnectionProtocol::Http2) { + self.pool + .insert_http2_if_absent(k.clone(), sender.clone()) + .await; + } } Ok(SendRequest::Http2(sender)) From 324bab239a23def2b3d2f5b0fb0528742a2739cb Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 24 Jan 2026 16:53:42 +0900 Subject: [PATCH 5/6] tests --- tests/test.rs | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/tests/test.rs b/tests/test.rs index bf35b5a..e708e8e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -11,7 +11,7 @@ use axum::{ }; use bytes::Bytes; use futures::stream; -use http_mitm_proxy::{DefaultClient, MitmProxy}; +use http_mitm_proxy::{DefaultClient, MitmProxy, RemoteAddr}; use hyper::{ Uri, body::{Body, Incoming}, @@ -137,7 +137,13 @@ async fn test_simple_http() { app, service_fn(move |req| { let proxy_client = proxy_client.clone(); - async move { proxy_client.send_request(req).await.map(|t| t.0) } + async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); + proxy_client.send_request(req).await.map(|t| t.0) + } }), ) .await; @@ -173,6 +179,10 @@ async fn test_modify_http() { service_fn(move |mut req| { let proxy_client = proxy_client.clone(); async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); req.headers_mut() .insert("X-test", "modified".parse().unwrap()); proxy_client.send_request(req).await.map(|t| t.0) @@ -209,7 +219,13 @@ async fn test_sse_http() { app, service_fn(move |req| { let proxy_client = proxy_client.clone(); - async move { proxy_client.send_request(req).await.map(|t| t.0) } + async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); + proxy_client.send_request(req).await.map(|t| t.0) + } }), ) .await; @@ -258,6 +274,10 @@ async fn test_simple_https() { service_fn(move |mut req| { let proxy_client = proxy_client.clone(); async move { + assert!( + req.extensions().get::().is_some(), + "RemoteAddr missing" + ); let mut parts = req.uri().clone().into_parts(); parts.scheme = Some(hyper::http::uri::Scheme::HTTP); From 6334119289587144449807390d123c75b18c8e1b Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 24 Jan 2026 16:56:47 +0900 Subject: [PATCH 6/6] clippy --- src/default_client.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/default_client.rs b/src/default_client.rs index 40b357f..bbe2995 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -426,12 +426,12 @@ impl DefaultClient { tokio::spawn(conn); - if let Some(ref k) = key { - if matches!(k.protocol, ConnectionProtocol::Http2) { - self.pool - .insert_http2_if_absent(k.clone(), sender.clone()) - .await; - } + if let Some(ref k) = key + && matches!(k.protocol, ConnectionProtocol::Http2) + { + self.pool + .insert_http2_if_absent(k.clone(), sender.clone()) + .await; } Ok(SendRequest::Http2(sender))