diff --git a/src/lib.rs b/src/lib.rs index 78659da..eeea13c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -218,10 +218,8 @@ where { inner: Proxy, connector: TlsConnector, - /// The address of the proxy server - target: SocketAddr, - /// The subject name of the proxy server - target_name: ServerName<'static>, + /// The host and port of the proxy server + target: String, /// Certificate chain for client auth cert_chain: Option>>, } @@ -230,8 +228,7 @@ impl ProxyClient { pub async fn new( cert_and_key: Option, address: impl ToSocketAddrs, - server_address: SocketAddr, - server_name: ServerName<'static>, + server_name: String, local_attestation_platform: L, remote_attestation_platform: R, ) -> Result { @@ -257,7 +254,6 @@ impl ProxyClient { Self::new_with_tls_config( client_config.into(), address, - server_address, server_name, local_attestation_platform, remote_attestation_platform, @@ -272,8 +268,7 @@ impl ProxyClient { async fn new_with_tls_config( client_config: Arc, local: impl ToSocketAddrs, - target: SocketAddr, - target_name: ServerName<'static>, + target_name: String, local_attestation_platform: L, remote_attestation_platform: R, cert_chain: Option>>, @@ -290,8 +285,7 @@ impl ProxyClient { Ok(Self { inner, connector, - target, - target_name, + target: host_to_host_with_port(&target_name), cert_chain, }) } @@ -301,8 +295,7 @@ impl ProxyClient { let (inbound, _client_addr) = self.inner.listener.accept().await?; let connector = self.connector.clone(); - let target_name = self.target_name.clone(); - let target = self.target; + let target = self.target.clone(); let local_attestation_platform = self.inner.local_attestation_platform.clone(); let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); let cert_chain = self.cert_chain.clone(); @@ -312,7 +305,6 @@ impl ProxyClient { inbound, connector, target, - target_name, cert_chain, local_attestation_platform, remote_attestation_platform, @@ -335,14 +327,15 @@ impl ProxyClient { async fn handle_connection( inbound: TcpStream, connector: TlsConnector, - target: SocketAddr, - target_name: ServerName<'static>, + target: String, cert_chain: Option>>, local_attestation_platform: L, remote_attestation_platform: R, ) -> Result<(), ProxyError> { - let out = TcpStream::connect(target).await?; - let mut tls_stream = connector.connect(target_name, out).await?; + let out = TcpStream::connect(&target).await?; + let mut tls_stream = connector + .connect(server_name_from_host(&target)?, out) + .await?; let (_io, server_connection) = tls_stream.get_ref(); @@ -394,8 +387,7 @@ impl ProxyClient { /// Just get the attested remote certificate, with no client authentication pub async fn get_tls_cert( - server_address: SocketAddr, - server_name: ServerName<'static>, + server_name: String, remote_attestation_platform: R, ) -> Result>, ProxyError> { let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); @@ -403,7 +395,6 @@ pub async fn get_tls_cert( .with_root_certificates(root_store) .with_no_client_auth(); get_tls_cert_with_config( - server_address, server_name, remote_attestation_platform, client_config.into(), @@ -412,15 +403,16 @@ pub async fn get_tls_cert( } async fn get_tls_cert_with_config( - server_address: SocketAddr, - server_name: ServerName<'static>, + server_name: String, remote_attestation_platform: R, client_config: Arc, ) -> Result>, ProxyError> { let connector = TlsConnector::from(client_config); - let out = TcpStream::connect(server_address).await?; - let mut tls_stream = connector.connect(server_name, out).await?; + let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?; + let mut tls_stream = connector + .connect(server_name_from_host(&server_name)?, out) + .await?; let (_io, server_connection) = tls_stream.get_ref(); @@ -467,6 +459,8 @@ pub enum ProxyError { Attestation(#[from] AttestationError), #[error("Integer conversion: {0}")] IntConversion(#[from] TryFromIntError), + #[error("Bad host name: {0}")] + BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), } /// Given a byte array, encode its length as a 4 byte big endian u32 @@ -475,6 +469,27 @@ fn length_prefix(input: &[u8]) -> [u8; 4] { len.to_be_bytes() } +fn host_to_host_with_port(host: &str) -> String { + if host.contains(':') { + host.to_string() + } else { + format!("{host}:443") + } +} + +fn server_name_from_host( + host: &str, +) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { + // If host contains ':', try to split off the port. + let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); + + // If the host is an IPv6 literal in brackets like "[::1]:443", + // remove the brackets for SNI (SNI allows bare IPv6 too). + let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); + + ServerName::try_from(host_part.to_string()) +} + #[cfg(test)] mod tests { use super::*; @@ -486,9 +501,8 @@ mod tests { #[tokio::test] async fn http_proxy() { let target_addr = example_http_service().await; - let target_name = "name".to_string(); - let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let proxy_server = ProxyServer::new_with_tls_config( @@ -510,9 +524,8 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, - "127.0.0.1:0", - proxy_addr, - target_name.try_into().unwrap(), + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), NoAttestation, MockAttestation, None, @@ -539,12 +552,11 @@ mod tests { #[tokio::test] async fn http_proxy_mutual_attestation() { let target_addr = example_http_service().await; - let target_name = "name".to_string(); let (server_cert_chain, server_private_key) = - generate_certificate_chain(target_name.clone()); + generate_certificate_chain("127.0.0.1".parse().unwrap()); let (client_cert_chain, client_private_key) = - generate_certificate_chain(target_name.clone()); + generate_certificate_chain("127.0.0.1".parse().unwrap()); let ( (_client_tls_server_config, client_tls_client_config), @@ -576,8 +588,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_tls_client_config, "127.0.0.1:0", - proxy_addr, - target_name.try_into().unwrap(), + proxy_addr.to_string(), MockAttestation, MockAttestation, Some(client_cert_chain), @@ -604,9 +615,8 @@ mod tests { #[tokio::test] async fn raw_tcp_proxy() { let target_addr = example_service().await; - let target_name = "name".to_string(); - let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let local_attestation_platform = MockAttestation; @@ -631,8 +641,7 @@ mod tests { let proxy_client = ProxyClient::new_with_tls_config( client_config, "127.0.0.1:0", - proxy_server_addr, - target_name.try_into().unwrap(), + proxy_server_addr.to_string(), NoAttestation, MockAttestation, None, @@ -657,9 +666,8 @@ mod tests { #[tokio::test] async fn test_get_tls_cert() { let target_addr = example_service().await; - let target_name = "name".to_string(); - let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); let local_attestation_platform = MockAttestation; @@ -682,8 +690,7 @@ mod tests { }); let retrieved_chain = get_tls_cert_with_config( - proxy_server_addr, - target_name.try_into().unwrap(), + proxy_server_addr.to_string(), MockAttestation, client_config, ) diff --git a/src/main.rs b/src/main.rs index 2148e77..1b5f6d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,14 +19,10 @@ enum CliCommand { /// Run a proxy client Client { /// Socket address to listen on - #[arg(short, long)] + #[arg(short, long, default_value = "0.0.0.0:0")] address: SocketAddr, - /// The socket address of the proxy server - #[arg(short, long)] - server_address: SocketAddr, - /// The domain name of the proxy server - #[arg(long)] - server_name: String, + /// The hostname:port or ip:port of the proxy server (port defaults to 443) + server: String, /// The path to a PEM encoded private key for client authentication #[arg(long)] private_key: Option, @@ -37,10 +33,9 @@ enum CliCommand { /// Run a proxy server Server { /// Socket address to listen on - #[arg(short, long)] + #[arg(short, long, default_value = "0.0.0.0:0")] address: SocketAddr, /// Socket address of the target service to forward traffic to - #[arg(short, long)] target_address: SocketAddr, /// The path to a PEM encoded private key #[arg(long)] @@ -55,12 +50,8 @@ enum CliCommand { }, /// Retrieve the attested TLS certificate from a proxy server GetTlsCert { - /// The socket address of the proxy server - #[arg(short, long)] - server_address: SocketAddr, - /// The domain name of the proxy server - #[arg(long)] - server_name: String, + /// The hostname:port or ip:port of the proxy server (port defaults to 443) + server: String, }, } @@ -71,8 +62,7 @@ async fn main() -> anyhow::Result<()> { match cli.command { CliCommand::Client { address, - server_name, - server_address, + server, private_key, cert_chain, } => { @@ -92,8 +82,7 @@ async fn main() -> anyhow::Result<()> { let client = ProxyClient::new( tls_cert_and_chain, address, - server_address, - server_name.try_into()?, + server, NoAttestation, MockAttestation, ) @@ -132,12 +121,8 @@ async fn main() -> anyhow::Result<()> { } } } - CliCommand::GetTlsCert { - server_address, - server_name, - } => { - let cert_chain = - get_tls_cert(server_address, server_name.try_into()?, MockAttestation).await?; + CliCommand::GetTlsCert { server } => { + let cert_chain = get_tls_cert(server, MockAttestation).await?; println!("{}", certs_to_pem_string(&cert_chain)?); } } diff --git a/src/test_helpers.rs b/src/test_helpers.rs index a45c6d9..5d4a349 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -1,5 +1,7 @@ -use rcgen::generate_simple_self_signed; -use std::{net::SocketAddr, sync::Arc}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, +}; use tokio::io::AsyncWriteExt; use tokio::net::TcpListener; use tokio_rustls::rustls::{ @@ -10,16 +12,19 @@ use tokio_rustls::rustls::{ /// Helper to generate a self-signed certificate for testing pub fn generate_certificate_chain( - name: String, + ip: IpAddr, ) -> (Vec>, PrivateKeyDer<'static>) { - let subject_alt_names = vec![name]; - let cert_key = generate_simple_self_signed(subject_alt_names) - .expect("Failed to generate self-signed certificate"); - - let certs = vec![CertificateDer::from(cert_key.cert)]; - let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( - cert_key.signing_key.serialize_der(), - )); + let mut params = rcgen::CertificateParams::new(vec![]).unwrap(); + params.subject_alt_names.push(rcgen::SanType::IpAddress(ip)); + params + .distinguished_name + .push(rcgen::DnType::CommonName, ip.to_string()); + + let keypair = rcgen::KeyPair::generate().unwrap(); + let cert = params.self_signed(&keypair).unwrap(); + + let certs = vec![CertificateDer::from(cert)]; + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der())); (certs, key) }