Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 53 additions & 20 deletions src/default_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,33 @@ compile_error!(
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("{0} doesn't have an valid host")]
InvalidHost(Uri),
InvalidHost(Box<Uri>),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error(transparent)]
HyperError(#[from] hyper::Error),
#[error("Failed to connect to {0}, {1}")]
ConnectError(Uri, hyper::Error),
ConnectError(Box<Uri>, hyper::Error),

#[cfg(feature = "native-tls-client")]
#[error("Failed to connect with TLS to {0}, {1}")]
TlsConnectError(Uri, native_tls::Error),
TlsConnectError(Box<Uri>, native_tls::Error),
#[cfg(feature = "native-tls-client")]
#[error(transparent)]
NativeTlsError(#[from] tokio_native_tls::native_tls::Error),

#[cfg(feature = "rustls-client")]
#[error("Failed to connect with TLS to {0}, {1}")]
TlsConnectError(Uri, std::io::Error),
TlsConnectError(Box<Uri>, std::io::Error),

#[error("Failed to parse URI: {0}")]
UriParsingError(#[from] hyper::http::uri::InvalidUri),

#[error("Failed to parse URI parts: {0}")]
UriPartsError(#[from] hyper::http::uri::InvalidUriParts),

#[error("TLS connector initialization failed: {0}")]
TlsConnectorError(String),
}

/// Upgraded connections
Expand Down Expand Up @@ -72,21 +81,39 @@ impl Default for DefaultClient {
impl DefaultClient {
#[cfg(feature = "native-tls-client")]
pub fn new() -> Self {
let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().unwrap();
Self::try_new().unwrap_or_else(|err| {
panic!("Failed to create DefaultClient: {}", err);
})
}

#[cfg(feature = "native-tls-client")]
pub fn try_new() -> Result<Self, Error> {
let tls_connector_no_alpn = native_tls::TlsConnector::builder().build().map_err(|e| {
Error::TlsConnectorError(format!("Failed to build no-ALPN connector: {}", e))
})?;
let tls_connector_alpn_h2 = native_tls::TlsConnector::builder()
.request_alpns(&["h2", "http/1.1"])
.build()
.unwrap();
.map_err(|e| {
Error::TlsConnectorError(format!("Failed to build ALPN-H2 connector: {}", e))
})?;

Self {
Ok(Self {
tls_connector_no_alpn: tokio_native_tls::TlsConnector::from(tls_connector_no_alpn),
tls_connector_alpn_h2: tokio_native_tls::TlsConnector::from(tls_connector_alpn_h2),
with_upgrades: false,
}
})
}

#[cfg(feature = "rustls-client")]
pub fn new() -> Self {
Self::try_new().unwrap_or_else(|err| {
panic!("Failed to create DefaultClient: {}", err);
})
}

#[cfg(feature = "rustls-client")]
pub fn try_new() -> Result<Self, Error> {
use std::sync::Arc;

let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
Expand All @@ -100,15 +127,15 @@ impl DefaultClient {
.with_no_client_auth();
tls_connector_alpn_h2.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

Self {
Ok(Self {
tls_connector_no_alpn: tokio_rustls::TlsConnector::from(Arc::new(
tls_connector_no_alpn,
)),
tls_connector_alpn_h2: tokio_rustls::TlsConnector::from(Arc::new(
tls_connector_alpn_h2,
)),
with_upgrades: false,
}
})
}

/// Enable HTTP upgrades
Expand Down Expand Up @@ -204,7 +231,9 @@ impl DefaultClient {
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
let host = uri.host().ok_or_else(|| Error::InvalidHost(uri.clone()))?;
let host = uri
.host()
.ok_or_else(|| Error::InvalidHost(Box::new(uri.clone())))?;
let port =
uri.port_u16()
.unwrap_or(if uri.scheme() == Some(&hyper::http::uri::Scheme::HTTPS) {
Expand All @@ -223,18 +252,18 @@ impl DefaultClient {
.tls_connector(http_version)
.connect(host, tcp)
.await
.map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
.map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;
#[cfg(feature = "rustls-client")]
let tls = self
.tls_connector(http_version)
.connect(
host.to_string()
.try_into()
.map_err(|_| Error::InvalidHost(uri.clone()))?,
.map_err(|_| Error::InvalidHost(Box::new(uri.clone())))?,
tcp,
)
.await
.map_err(|err| Error::TlsConnectError(uri.clone(), err))?;
.map_err(|err| Error::TlsConnectError(Box::new(uri.clone()), err))?;

#[cfg(feature = "native-tls-client")]
let is_h2 = matches!(
Expand All @@ -251,7 +280,7 @@ impl DefaultClient {
let (sender, conn) = client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(TokioIo::new(tls))
.await
.map_err(|err| Error::ConnectError(uri.clone(), err))?;
.map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;

tokio::spawn(conn);

Expand All @@ -262,7 +291,7 @@ impl DefaultClient {
.title_case_headers(true)
.handshake(TokioIo::new(tls))
.await
.map_err(|err| Error::ConnectError(uri.clone(), err))?;
.map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;

tokio::spawn(conn.with_upgrades());

Expand All @@ -274,7 +303,7 @@ impl DefaultClient {
.title_case_headers(true)
.handshake(TokioIo::new(tcp))
.await
.map_err(|err| Error::ConnectError(uri.clone(), err))?;
.map_err(|err| Error::ConnectError(Box::new(uri.clone()), err))?;
tokio::spawn(conn.with_upgrades());
Ok(SendRequest::Http1(sender))
}
Expand Down Expand Up @@ -312,7 +341,10 @@ where
}
}
}
remove_authority(&mut req);
if let Err(err) = remove_authority(&mut req) {
tracing::error!("Failed to remove authority from URI: {}", err);
// Continue with the original request if URI modification fails
}
sender.send_request(req).await
}
SendRequest::Http2(sender) => {
Expand All @@ -336,9 +368,10 @@ impl<B> SendRequest<B> {
}
}

fn remove_authority<B>(req: &mut Request<B>) {
fn remove_authority<B>(req: &mut Request<B>) -> Result<(), hyper::http::uri::InvalidUriParts> {
let mut parts = req.uri().clone().into_parts();
parts.scheme = None;
parts.authority = None;
*req.uri_mut() = Uri::from_parts(parts).unwrap();
*req.uri_mut() = Uri::from_parts(parts)?;
Ok(())
}
33 changes: 23 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,19 @@ where
fn get_certified_key(&self, host: String) -> Option<CertifiedKeyDer> {
self.root_cert.as_ref().and_then(|root_cert| {
if let Some(cache) = self.cert_cache.as_ref() {
Some(cache.get_with(host.clone(), move || {
generate_cert(host, root_cert.borrow())
.map_err(|err| {
tracing::error!("Failed to generate certificate for host: {}", err);
})
.unwrap()
}))
// Try to get from cache, but handle generation errors gracefully
cache
.try_get_with(host.clone(), move || {
generate_cert(host, root_cert.borrow())
})
.map_err(|err| {
tracing::error!("Failed to generate certificate for host: {}", err);
})
.ok()
} else {
generate_cert(host, root_cert.borrow())
.map_err(|err| {
tracing::error!("Failed to generate certificate: {}", err);
tracing::error!("Failed to generate certificate for host: {}", err);
})
.ok()
}
Expand Down Expand Up @@ -300,7 +302,18 @@ fn inject_authority<B>(request_middleman: &mut Request<B>, authority: hyper::htt
let mut parts = request_middleman.uri().clone().into_parts();
parts.scheme = Some(hyper::http::uri::Scheme::HTTPS);
if parts.authority.is_none() {
parts.authority = Some(authority);
parts.authority = Some(authority.clone());
}

match hyper::http::uri::Uri::from_parts(parts) {
Ok(uri) => *request_middleman.uri_mut() = uri,
Err(err) => {
tracing::error!(
"Failed to inject authority '{}' into URI: {}",
authority,
err
);
// Keep the original URI if injection fails
}
}
*request_middleman.uri_mut() = hyper::http::uri::Uri::from_parts(parts).unwrap();
}
Loading