diff --git a/src/default_client.rs b/src/default_client.rs index 0412935..3b202ae 100644 --- a/src/default_client.rs +++ b/src/default_client.rs @@ -227,7 +227,12 @@ impl DefaultClient { #[cfg(feature = "rustls-client")] let tls = self .tls_connector(http_version) - .connect(host.to_string().try_into().expect("Invalid host"), tcp) + .connect( + host.to_string() + .try_into() + .map_err(|_| Error::InvalidHost(uri.clone()))?, + tcp, + ) .await .map_err(|err| Error::TlsConnectError(uri.clone(), err))?; @@ -293,10 +298,18 @@ where SendRequest::Http1(sender) => { if req.version() == hyper::Version::HTTP_2 { if let Some(authority) = req.uri().authority().cloned() { - req.headers_mut().insert( - header::HOST, - authority.as_str().parse().expect("Invalid authority"), - ); + match authority.as_str().parse::() { + Ok(host_value) => { + req.headers_mut().insert(header::HOST, host_value); + } + Err(err) => { + tracing::warn!( + "Failed to parse authority '{}' as HOST header: {}", + authority, + err + ); + } + } } } remove_authority(&mut req); diff --git a/src/lib.rs b/src/lib.rs index f51f5db..e9e50b3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,8 +77,12 @@ where Ok(async move { loop { - let Ok((stream, _)) = listener.accept().await else { - continue; + let (stream, _) = match listener.accept().await { + Ok(conn) => conn, + Err(err) => { + tracing::warn!("Failed to accept connection: {}", err); + continue; + } }; let service = service.clone(); @@ -140,12 +144,16 @@ where }; tokio::spawn(async move { - let Ok(client) = hyper::upgrade::on(req).await else { - tracing::error!( - "Bad CONNECT request: {}, Reason: Invalid Upgrade", - connect_authority - ); - return; + let client = match hyper::upgrade::on(req).await { + Ok(client) => client, + Err(err) => { + tracing::error!( + "Failed to upgrade CONNECT request for {}: {}", + connect_authority, + err + ); + return; + } }; if let Some(server_config) = proxy.server_config(connect_authority.host().to_string(), true) @@ -196,17 +204,22 @@ where .await }; - if let Err(_err) = res { - // Suppress error because if we serving HTTPS proxy server and forward to HTTPS server, it will always error when closing connection. - // tracing::error!("Error in proxy: {}", err); + if let Err(err) = res { + tracing::debug!("Connection closed: {}", err); } } else { - let Ok(mut server) = - TcpStream::connect(connect_authority.as_str()).await - else { - tracing::error!("Failed to connect to {}", connect_authority); - return; - }; + let mut server = + match TcpStream::connect(connect_authority.as_str()).await { + Ok(server) => server, + Err(err) => { + tracing::error!( + "Failed to connect to {}: {}", + connect_authority, + err + ); + return; + } + }; let _ = tokio::io::copy_bidirectional( &mut TokioIo::new(client), &mut server, @@ -229,13 +242,21 @@ where } fn get_certified_key(&self, host: String) -> Option { - self.root_cert.as_ref().map(|root_cert| { + self.root_cert.as_ref().and_then(|root_cert| { if let Some(cache) = self.cert_cache.as_ref() { - cache.get_with(host.clone(), move || { + Some(cache.get_with(host.clone(), move || { generate_cert(host, root_cert.borrow()) - }) + .map_err(|err| { + tracing::error!("Failed to generate certificate for host: {}", err); + }) + .unwrap() + })) } else { generate_cert(host, root_cert.borrow()) + .map_err(|err| { + tracing::error!("Failed to generate certificate: {}", err); + }) + .ok() } }) } diff --git a/src/tls.rs b/src/tls.rs index 8583861..0b7242e 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -5,8 +5,11 @@ pub struct CertifiedKeyDer { pub key_der: Vec, } -pub fn generate_cert(host: String, root_cert: &rcgen::CertifiedKey) -> CertifiedKeyDer { - let mut cert_params = rcgen::CertificateParams::new(vec![host.clone()]).unwrap(); +pub fn generate_cert( + host: String, + root_cert: &rcgen::CertifiedKey, +) -> Result { + let mut cert_params = rcgen::CertificateParams::new(vec![host.clone()])?; cert_params .key_usages .push(rcgen::KeyUsagePurpose::DigitalSignature); @@ -22,14 +25,12 @@ pub fn generate_cert(host: String, root_cert: &rcgen::CertifiedKey) -> Certified dn }; - let key_pair = rcgen::KeyPair::generate().unwrap(); + let key_pair = rcgen::KeyPair::generate()?; - let cert = cert_params - .signed_by(&key_pair, &root_cert.cert, &root_cert.key_pair) - .unwrap(); + let cert = cert_params.signed_by(&key_pair, &root_cert.cert, &root_cert.key_pair)?; - CertifiedKeyDer { + Ok(CertifiedKeyDer { cert_der: cert.der().to_vec(), key_der: key_pair.serialize_der(), - } + }) }