diff --git a/Cargo.lock b/Cargo.lock index 814b8d9..9a98a4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,12 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "anyhow" +version = "1.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" + [[package]] name = "asn1-rs" version = "0.7.1" @@ -110,6 +116,7 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" name = "attested-tls-proxy" version = "0.1.0" dependencies = [ + "anyhow", "axum", "clap", "rcgen", diff --git a/Cargo.toml b/Cargo.toml index 04ca13f..3c62e41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ thiserror = "2.0.17" clap = { version = "4.5.51", features = ["derive"] } webpki-roots = "1.0.4" rustls-pemfile = "2.2.0" +anyhow = "1.0.100" [dev-dependencies] rcgen = "0.14.5" diff --git a/src/attestation.rs b/src/attestation.rs index 16ed270..61d553c 100644 --- a/src/attestation.rs +++ b/src/attestation.rs @@ -5,14 +5,19 @@ use x509_parser::prelude::*; /// Represents a CVM technology with quote generation and verification pub trait AttestationPlatform: Clone + Send + 'static { + /// Whether this is CVM attestation. This should always return true except for the [NoAttestation] case. + /// + /// When false, allows TLS client to be configured without client authentication fn is_cvm(&self) -> bool; + /// Generate an attestation fn create_attestation( &self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32], ) -> Result, AttestationError>; + /// Verify the given attestation payload fn verify_attestation( &self, input: Vec, @@ -71,7 +76,7 @@ impl AttestationPlatform for NoAttestation { false } - /// Mocks creating an attestation + /// Create an empty attestation fn create_attestation( &self, _cert_chain: &[CertificateDer<'_>], @@ -80,7 +85,7 @@ impl AttestationPlatform for NoAttestation { Ok(Vec::new()) } - /// Mocks verifying an attestation + /// Ensure that an empty attestation is given fn verify_attestation( &self, input: Vec, diff --git a/src/lib.rs b/src/lib.rs index 82707e0..9e3a3e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,14 @@ mod attestation; +use attestation::AttestationError; pub use attestation::{AttestationPlatform, MockAttestation, NoAttestation}; -use tokio_rustls::rustls::server::WebPkiClientVerifier; +use thiserror::Error; +use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; #[cfg(test)] mod test_helpers; +use std::num::TryFromIntError; use std::{net::SocketAddr, sync::Arc}; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; @@ -60,26 +63,23 @@ impl ProxyServer { local_attestation_platform: L, remote_attestation_platform: R, client_auth: bool, - ) -> Self { + ) -> Result { if remote_attestation_platform.is_cvm() && !client_auth { - panic!("Client auth is required when the client is running in a CVM"); + return Err(ProxyError::NoClientAuth); } let server_config = if client_auth { let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)) - .build() - .expect("invalid client verifier"); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + ServerConfig::builder() .with_client_cert_verifier(verifier) - .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key) - .expect("Failed to create rustls server config") + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? } else { ServerConfig::builder() .with_no_client_auth() - .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key) - .expect("Failed to create rustls server config") + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? }; Self::new_with_tls_config( @@ -103,26 +103,27 @@ impl ProxyServer { target: SocketAddr, local_attestation_platform: L, remote_attestation_platform: R, - ) -> Self { + ) -> Result { let acceptor = tokio_rustls::TlsAcceptor::from(server_config); - let listener = TcpListener::bind(local).await.unwrap(); + let listener = TcpListener::bind(local).await?; let inner = Proxy { listener, local_attestation_platform, remote_attestation_platform, }; - Self { + + Ok(Self { acceptor, target, inner, cert_chain, - } + }) } /// Accept an incoming connection - pub async fn accept(&self) -> io::Result<()> { - let (inbound, _client_addr) = self.inner.listener.accept().await.unwrap(); + pub async fn accept(&self) -> Result<(), ProxyError> { + let (inbound, _client_addr) = self.inner.listener.accept().await?; let acceptor = self.acceptor.clone(); let target = self.target; @@ -130,58 +131,18 @@ impl ProxyServer { let local_attestation_platform = self.inner.local_attestation_platform.clone(); let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); tokio::spawn(async move { - let mut tls_stream = acceptor.accept(inbound).await.unwrap(); - let (_io, connection) = tls_stream.get_ref(); - - let mut exporter = [0u8; 32]; - connection - .export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - ) - .unwrap(); - - let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); - - let attestation = if local_attestation_platform.is_cvm() { - local_attestation_platform - .create_attestation(&cert_chain, exporter) - .unwrap() - } else { - Vec::new() - }; - - let attestation_length_prefix = length_prefix(&attestation); - - tls_stream - .write_all(&attestation_length_prefix) - .await - .unwrap(); - - tls_stream.write_all(&attestation).await.unwrap(); - - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await.unwrap(); - let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap(); - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await.unwrap(); - - if remote_attestation_platform.is_cvm() { - remote_attestation_platform - .verify_attestation(buf, &remote_cert_chain.unwrap(), exporter) - .unwrap(); + if let Err(err) = Self::handle_connection( + inbound, + acceptor, + target, + cert_chain, + local_attestation_platform, + remote_attestation_platform, + ) + .await + { + eprintln!("Failed to handle connection: {err}"); } - - let outbound = TcpStream::connect(target).await.unwrap(); - - let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream); - let (mut outbound_reader, mut outbound_writer) = outbound.into_split(); - - let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - tokio::try_join!(client_to_server, server_to_client).unwrap(); }); Ok(()) @@ -190,6 +151,64 @@ impl ProxyServer { pub fn local_addr(&self) -> std::io::Result { self.inner.listener.local_addr() } + + async fn handle_connection( + inbound: TcpStream, + acceptor: TlsAcceptor, + target: SocketAddr, + cert_chain: Vec>, + local_attestation_platform: L, + remote_attestation_platform: R, + ) -> Result<(), ProxyError> { + let mut tls_stream = acceptor.accept(inbound).await?; + let (_io, connection) = tls_stream.get_ref(); + + let mut exporter = [0u8; 32]; + connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); + + let attestation = if local_attestation_platform.is_cvm() { + local_attestation_platform.create_attestation(&cert_chain, exporter)? + } else { + Vec::new() + }; + + let attestation_length_prefix = length_prefix(&attestation); + + tls_stream.write_all(&attestation_length_prefix).await?; + + tls_stream.write_all(&attestation).await?; + + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + if remote_attestation_platform.is_cvm() { + remote_attestation_platform.verify_attestation( + buf, + &remote_cert_chain.ok_or(ProxyError::NoClientAuth)?, + exporter, + )?; + } + + let outbound = TcpStream::connect(target).await?; + + let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream); + let (mut outbound_reader, mut outbound_writer) = outbound.into_split(); + + let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); + let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); + tokio::try_join!(client_to_server, server_to_client)?; + Ok(()) + } } pub struct ProxyClient @@ -215,9 +234,9 @@ impl ProxyClient { server_name: ServerName<'static>, local_attestation_platform: L, remote_attestation_platform: R, - ) -> Self { + ) -> Result { if local_attestation_platform.is_cvm() && cert_and_key.is_none() { - panic!("Client auth is required when the client is running in a CVM"); + return Err(ProxyError::NoClientAuth); } let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); @@ -228,8 +247,7 @@ impl ProxyClient { .with_client_auth_cert( cert_and_key.cert_chain.clone(), cert_and_key.key.clone_key(), - ) - .unwrap() + )? } else { ClientConfig::builder() .with_root_certificates(root_store) @@ -248,6 +266,9 @@ impl ProxyClient { .await } + /// Create a new proxy with given TLS configuration + /// + /// This is private as it allows dangerous configuration but is used in tests async fn new_with_tls_config( client_config: Arc, local: impl ToSocketAddrs, @@ -256,8 +277,8 @@ impl ProxyClient { local_attestation_platform: L, remote_attestation_platform: R, cert_chain: Option>>, - ) -> Self { - let listener = TcpListener::bind(local).await.unwrap(); + ) -> Result { + let listener = TcpListener::bind(local).await?; let connector = TlsConnector::from(client_config.clone()); let inner = Proxy { @@ -266,17 +287,18 @@ impl ProxyClient { remote_attestation_platform, }; - Self { + Ok(Self { inner, connector, target, target_name, cert_chain, - } + }) } + /// Accept an incoming connection and handle it pub async fn accept(&self) -> io::Result<()> { - let (inbound, _client_addr) = self.inner.listener.accept().await.unwrap(); + let (inbound, _client_addr) = self.inner.listener.accept().await?; let connector = self.connector.clone(); let target_name = self.target_name.clone(); @@ -286,68 +308,110 @@ impl ProxyClient { let cert_chain = self.cert_chain.clone(); tokio::spawn(async move { - let out = TcpStream::connect(target).await.unwrap(); - let mut tls_stream = connector.connect(target_name, out).await.unwrap(); + if let Err(err) = Self::handle_connection( + inbound, + connector, + target, + target_name, + cert_chain, + local_attestation_platform, + remote_attestation_platform, + ) + .await + { + eprintln!("Failed to handle connection: {err}"); + } + }); - let (_io, server_connection) = tls_stream.get_ref(); + Ok(()) + } - let mut exporter = [0u8; 32]; - server_connection - .export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - ) - .unwrap(); + /// Helper to return the local socket address from the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.inner.listener.local_addr() + } - let remote_cert_chain = server_connection.peer_certificates().unwrap().to_owned(); + /// Handle an incoming connection + async fn handle_connection( + inbound: TcpStream, + connector: TlsConnector, + target: SocketAddr, + target_name: ServerName<'static>, + cert_chain: Option>>, + local_attestation_platform: L, + remote_attestation_platform: R, + ) -> Result<(), ProxyError> { + let out = TcpStream::connect(target).await?; + let mut tls_stream = connector.connect(target_name, out).await?; - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await.unwrap(); - let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap(); + let (_io, server_connection) = tls_stream.get_ref(); - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await.unwrap(); + let mut exporter = [0u8; 32]; + server_connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; - if remote_attestation_platform.is_cvm() { - remote_attestation_platform - .verify_attestation(buf, &remote_cert_chain, exporter) - .unwrap(); - } + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)? + .to_owned(); - let attestation = if local_attestation_platform.is_cvm() { - local_attestation_platform - .create_attestation(&cert_chain.unwrap(), exporter) - .unwrap() - } else { - Vec::new() - }; + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - let attestation_length_prefix = length_prefix(&attestation); + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; - tls_stream - .write_all(&attestation_length_prefix) - .await - .unwrap(); + if remote_attestation_platform.is_cvm() { + remote_attestation_platform.verify_attestation(buf, &remote_cert_chain, exporter)?; + } - tls_stream.write_all(&attestation).await.unwrap(); + let attestation = if local_attestation_platform.is_cvm() { + local_attestation_platform + .create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)? + } else { + Vec::new() + }; - let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); - let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); + let attestation_length_prefix = length_prefix(&attestation); - let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); - let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); - tokio::try_join!(client_to_server, server_to_client).unwrap(); - }); + tls_stream.write_all(&attestation_length_prefix).await?; + tls_stream.write_all(&attestation).await?; + + let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); + let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); + + let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer); + let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer); + tokio::try_join!(client_to_server, server_to_client)?; Ok(()) } +} - pub fn local_addr(&self) -> std::io::Result { - self.inner.listener.local_addr() - } +/// An error when running a proxy client or server +#[derive(Error, Debug)] +pub enum ProxyError { + #[error("Client auth is required when the client is running in a CVM")] + NoClientAuth, + #[error("Failed to get server ceritifcate")] + NoCertificate, + #[error("TLS: {0}")] + Rustls(#[from] tokio_rustls::rustls::Error), + #[error("Verifier builder: {0}")] + VerifierBuilder(#[from] VerifierBuilderError), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Attestation: {0}")] + Attestation(#[from] AttestationError), + #[error("Integer conversion: {0}")] + IntConversion(#[from] TryFromIntError), } +/// Given a byte array, encode its length as a 4 byte big endian u32 fn length_prefix(input: &[u8]) -> [u8; 4] { let len = input.len() as u32; len.to_be_bytes() @@ -377,7 +441,9 @@ mod tests { MockAttestation, NoAttestation, ) - .await; + .await + .unwrap(); + let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { @@ -393,7 +459,8 @@ mod tests { MockAttestation, None, ) - .await; + .await + .unwrap(); let proxy_client_addr = proxy_client.local_addr().unwrap(); @@ -439,7 +506,9 @@ mod tests { MockAttestation, MockAttestation, ) - .await; + .await + .unwrap(); + let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { @@ -455,7 +524,8 @@ mod tests { MockAttestation, Some(client_cert_chain), ) - .await; + .await + .unwrap(); let proxy_client_addr = proxy_client.local_addr().unwrap(); @@ -491,7 +561,9 @@ mod tests { local_attestation_platform, NoAttestation, ) - .await; + .await + .unwrap(); + let proxy_server_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { @@ -507,7 +579,9 @@ mod tests { MockAttestation, None, ) - .await; + .await + .unwrap(); + let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { diff --git a/src/main.rs b/src/main.rs index 29211eb..be3cd75 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use anyhow::{anyhow, ensure}; use clap::{Parser, Subcommand}; use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; @@ -18,8 +19,10 @@ struct Cli { enum CliCommand { /// Run a proxy client Client { + /// The socket address of the proxy server #[arg(short, long)] server_address: SocketAddr, + /// The domain name of the proxy server #[arg(long)] server_name: String, /// The path to a PEM encoded private key for client authentication @@ -40,13 +43,15 @@ enum CliCommand { /// The path to a PEM encoded certificate chain #[arg(long)] cert_chain: PathBuf, + /// Whether to use client authentication. If the client is running in a CVM this must be + /// enabled. #[arg(long)] client_auth: bool, }, } #[tokio::main] -async fn main() { +async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); match cli.command { @@ -56,21 +61,33 @@ async fn main() { private_key, cert_chain, } => { - let tls_cert_and_chain = private_key - .map(|private_key| load_tls_cert_and_key(cert_chain.unwrap(), private_key)); + let tls_cert_and_chain = if let Some(private_key) = private_key { + Some(load_tls_cert_and_key( + cert_chain.ok_or(anyhow!("Private key given but no certificate chain"))?, + private_key, + )?) + } else { + ensure!( + cert_chain.is_none(), + "Certificate chain given but no private key" + ); + None + }; let client = ProxyClient::new( tls_cert_and_chain, cli.address, server_address, - server_name.try_into().unwrap(), + server_name.try_into()?, NoAttestation, MockAttestation, ) - .await; + .await?; loop { - client.accept().await.unwrap(); + if let Err(err) = client.accept().await { + eprintln!("Failed to handle connection: {err}"); + } } } CliCommand::Server { @@ -79,7 +96,7 @@ async fn main() { cert_chain, client_auth, } => { - let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key); + let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key)?; let local_attestation = MockAttestation; let remote_attestation = NoAttestation; @@ -91,37 +108,39 @@ async fn main() { remote_attestation, client_auth, ) - .await; + .await?; loop { - server.accept().await.unwrap(); + if let Err(err) = server.accept().await { + eprintln!("Failed to handle connection: {err}"); + } } } } } -fn load_tls_cert_and_key(cert_chain: PathBuf, private_key: PathBuf) -> TlsCertAndKey { - let key = load_private_key_pem(private_key); - let cert_chain = load_certs_pem(cert_chain).unwrap(); - TlsCertAndKey { key, cert_chain } +/// Load TLS details from storage +fn load_tls_cert_and_key( + cert_chain: PathBuf, + private_key: PathBuf, +) -> anyhow::Result { + let key = load_private_key_pem(private_key)?; + let cert_chain = load_certs_pem(cert_chain)?; + Ok(TlsCertAndKey { key, cert_chain }) } pub fn load_certs_pem(path: PathBuf) -> std::io::Result>> { - Ok( - rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?)) - .map(|res| res.unwrap()) - .collect(), - ) + rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?)) + .collect::, _>>() } -pub fn load_private_key_pem(path: PathBuf) -> PrivateKeyDer<'static> { - let mut reader = std::io::BufReader::new(File::open(path).unwrap()); +pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result> { + let mut reader = std::io::BufReader::new(File::open(path)?); // Tries to read the key as PKCS#8, PKCS#1, or SEC1 let pks8_key = rustls_pemfile::pkcs8_private_keys(&mut reader) .next() - .unwrap() - .unwrap(); + .ok_or(anyhow!("No PKS8 Key"))??; - PrivateKeyDer::Pkcs8(pks8_key) + Ok(PrivateKeyDer::Pkcs8(pks8_key)) }