diff --git a/src/attestation.rs b/src/attestation.rs new file mode 100644 index 0000000..16ed270 --- /dev/null +++ b/src/attestation.rs @@ -0,0 +1,125 @@ +use sha2::{Digest, Sha256}; +use thiserror::Error; +use tokio_rustls::rustls::pki_types::CertificateDer; +use x509_parser::prelude::*; + +/// Represents a CVM technology with quote generation and verification +pub trait AttestationPlatform: Clone + Send + 'static { + fn is_cvm(&self) -> bool; + + fn create_attestation( + &self, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result, AttestationError>; + + fn verify_attestation( + &self, + input: Vec, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result<(), AttestationError>; +} + +/// For testing +#[derive(Clone)] +pub struct MockAttestation; + +impl AttestationPlatform for MockAttestation { + fn is_cvm(&self) -> bool { + true + } + + /// Mocks creating an attestation + fn create_attestation( + &self, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result, AttestationError> { + let mut quote_input = [0u8; 64]; + let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; + quote_input[..32].copy_from_slice(&pki_hash); + quote_input[32..].copy_from_slice(&exporter); + Ok(quote_input.to_vec()) + } + + /// Mocks verifying an attestation + fn verify_attestation( + &self, + input: Vec, + cert_chain: &[CertificateDer<'_>], + exporter: [u8; 32], + ) -> Result<(), AttestationError> { + let mut quote_input = [0u8; 64]; + let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; + quote_input[..32].copy_from_slice(&pki_hash); + quote_input[32..].copy_from_slice(&exporter); + + if input != quote_input { + return Err(AttestationError::InputMismatch); + } + Ok(()) + } +} + +/// For no CVM platform (eg: for one-sided remote-attested TLS) +#[derive(Clone)] +pub struct NoAttestation; + +impl AttestationPlatform for NoAttestation { + fn is_cvm(&self) -> bool { + false + } + + /// Mocks creating an attestation + fn create_attestation( + &self, + _cert_chain: &[CertificateDer<'_>], + _exporter: [u8; 32], + ) -> Result, AttestationError> { + Ok(Vec::new()) + } + + /// Mocks verifying an attestation + fn verify_attestation( + &self, + input: Vec, + _cert_chain: &[CertificateDer<'_>], + _exporter: [u8; 32], + ) -> Result<(), AttestationError> { + if input.is_empty() { + Ok(()) + } else { + Err(AttestationError::AttestationGivenWhenNoneExpected) + } + } +} + +/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate +fn get_pki_hash_from_certificate_chain( + cert_chain: &[CertificateDer<'_>], +) -> Result<[u8; 32], AttestationError> { + let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; + let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; + let public_key = &cert.tbs_certificate.subject_pki; + let key_bytes = public_key.subject_public_key.as_ref(); + + let mut hasher = Sha256::new(); + hasher.update(key_bytes); + Ok(hasher.finalize().into()) +} + +/// An error when generating or verifying an attestation +#[derive(Error, Debug)] +pub enum AttestationError { + #[error("Certificate chain is empty")] + NoCertificate, + #[error("X509 parse: {0}")] + X509Parse(#[from] x509_parser::asn1_rs::Err), + #[error("X509: {0}")] + X509(#[from] x509_parser::error::X509Error), + #[error("Quote input is not as expected")] + InputMismatch, + #[error("Configuration mismatch - expected no remote attestation")] + AttestationGivenWhenNoneExpected, +} diff --git a/src/lib.rs b/src/lib.rs index 63c0d39..82707e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,12 @@ -use sha2::{Digest, Sha256}; +mod attestation; + +pub use attestation::{AttestationPlatform, MockAttestation, NoAttestation}; +use tokio_rustls::rustls::server::WebPkiClientVerifier; + +#[cfg(test)] +mod test_helpers; + use std::{net::SocketAddr, sync::Arc}; -use thiserror::Error; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; @@ -9,72 +15,126 @@ use tokio_rustls::{ rustls::{ClientConfig, ServerConfig}, TlsAcceptor, TlsConnector, }; -use x509_parser::prelude::*; /// The label used when exporting key material from a TLS session const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +pub struct TlsCertAndKey { + pub cert_chain: Vec>, + pub key: PrivateKeyDer<'static>, +} + +struct Proxy +where + L: AttestationPlatform, + R: AttestationPlatform, +{ + /// The underlying TCP listener + listener: TcpListener, + /// Type of CVM platform we run on (including none) + local_attestation_platform: L, + /// Type of CVM platform the remote party runs on (including none) + remote_attestation_platform: R, +} + /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address -pub struct ProxyServer { +pub struct ProxyServer +where + L: AttestationPlatform, + R: AttestationPlatform, +{ + inner: Proxy, /// The certificate chain cert_chain: Vec>, /// For accepting TLS connections acceptor: TlsAcceptor, - /// The underlying TCP listener - listener: TcpListener, /// The address of the target service we are proxying to target: SocketAddr, - attestation_platform: MockAttestation, } -impl ProxyServer { +impl ProxyServer { pub async fn new( - cert_chain: Vec>, - key: PrivateKeyDer<'static>, + cert_and_key: TlsCertAndKey, local: impl ToSocketAddrs, target: SocketAddr, + local_attestation_platform: L, + remote_attestation_platform: R, + client_auth: bool, ) -> Self { - let server_config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(cert_chain.clone(), key) - .expect("Failed to create rustls server config"); + if remote_attestation_platform.is_cvm() && !client_auth { + panic!("Client auth is required when the client is running in a CVM"); + } - Self::new_with_tls_config(cert_chain, server_config.into(), local, target).await + let server_config = if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)) + .build() + .expect("invalid client verifier"); + ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key) + .expect("Failed to create rustls server config") + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key) + .expect("Failed to create rustls server config") + }; + + Self::new_with_tls_config( + cert_and_key.cert_chain, + server_config.into(), + local, + target, + local_attestation_platform, + remote_attestation_platform, + ) + .await } /// Start with preconfigured TLS - pub async fn new_with_tls_config( + /// + /// This is not public as it allows dangerous configuration + async fn new_with_tls_config( cert_chain: Vec>, server_config: Arc, local: impl ToSocketAddrs, target: SocketAddr, + local_attestation_platform: L, + remote_attestation_platform: R, ) -> Self { let acceptor = tokio_rustls::TlsAcceptor::from(server_config); let listener = TcpListener::bind(local).await.unwrap(); + let inner = Proxy { + listener, + local_attestation_platform, + remote_attestation_platform, + }; Self { - cert_chain, acceptor, - listener, target, - attestation_platform: MockAttestation, + inner, + cert_chain, } } /// Accept an incoming connection pub async fn accept(&self) -> io::Result<()> { - let (inbound, _client_addr) = self.listener.accept().await.unwrap(); + let (inbound, _client_addr) = self.inner.listener.accept().await.unwrap(); let acceptor = self.acceptor.clone(); let target = self.target; let cert_chain = self.cert_chain.clone(); - let attestation_platform = self.attestation_platform.clone(); + let local_attestation_platform = self.inner.local_attestation_platform.clone(); + let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); tokio::spawn(async move { let mut tls_stream = acceptor.accept(inbound).await.unwrap(); - let (_io, server_connection) = tls_stream.get_ref(); + let (_io, connection) = tls_stream.get_ref(); let mut exporter = [0u8; 32]; - server_connection + connection .export_keying_material( &mut exporter, EXPORTER_LABEL, @@ -82,7 +142,16 @@ impl ProxyServer { ) .unwrap(); - let attestation = attestation_platform.create_attestation(&cert_chain, exporter); + let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); + + let attestation = if local_attestation_platform.is_cvm() { + local_attestation_platform + .create_attestation(&cert_chain, exporter) + .unwrap() + } else { + Vec::new() + }; + let attestation_length_prefix = length_prefix(&attestation); tls_stream @@ -90,10 +159,20 @@ impl ProxyServer { .await .unwrap(); - tls_stream - .write_all(&attestation_platform.create_attestation(&cert_chain, exporter)) - .await - .unwrap(); + tls_stream.write_all(&attestation).await.unwrap(); + + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await.unwrap(); + let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap(); + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await.unwrap(); + + if remote_attestation_platform.is_cvm() { + remote_attestation_platform + .verify_attestation(buf, &remote_cert_chain.unwrap(), exporter) + .unwrap(); + } let outbound = TcpStream::connect(target).await.unwrap(); @@ -107,57 +186,104 @@ impl ProxyServer { Ok(()) } + + pub fn local_addr(&self) -> std::io::Result { + self.inner.listener.local_addr() + } } -pub struct ProxyClient { +pub struct ProxyClient +where + L: AttestationPlatform, + R: AttestationPlatform, +{ + inner: Proxy, connector: TlsConnector, - listener: TcpListener, /// The address of the proxy server target: SocketAddr, /// The subject name of the proxy server target_name: ServerName<'static>, - attestation_platform: MockAttestation, + /// Certificate chain for client auth + cert_chain: Option>>, } -impl ProxyClient { +impl ProxyClient { pub async fn new( + cert_and_key: Option, address: impl ToSocketAddrs, server_address: SocketAddr, server_name: ServerName<'static>, + local_attestation_platform: L, + remote_attestation_platform: R, ) -> Self { + if local_attestation_platform.is_cvm() && cert_and_key.is_none() { + panic!("Client auth is required when the client is running in a CVM"); + } + let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let client_config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - Self::new_with_tls_config(client_config.into(), address, server_address, server_name).await + let client_config = if let Some(ref cert_and_key) = cert_and_key { + ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + ) + .unwrap() + } else { + ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + Self::new_with_tls_config( + client_config.into(), + address, + server_address, + server_name, + local_attestation_platform, + remote_attestation_platform, + cert_and_key.map(|c| c.cert_chain), + ) + .await } - pub async fn new_with_tls_config( + async fn new_with_tls_config( client_config: Arc, local: impl ToSocketAddrs, target: SocketAddr, target_name: ServerName<'static>, + local_attestation_platform: L, + remote_attestation_platform: R, + cert_chain: Option>>, ) -> Self { let listener = TcpListener::bind(local).await.unwrap(); let connector = TlsConnector::from(client_config.clone()); + let inner = Proxy { + listener, + local_attestation_platform, + remote_attestation_platform, + }; + Self { + inner, connector, - listener, target, target_name, - attestation_platform: MockAttestation, + cert_chain, } } pub async fn accept(&self) -> io::Result<()> { - let (inbound, _client_addr) = self.listener.accept().await.unwrap(); + let (inbound, _client_addr) = self.inner.listener.accept().await.unwrap(); let connector = self.connector.clone(); let target_name = self.target_name.clone(); let target = self.target; - let attestation_platform = self.attestation_platform.clone(); + let local_attestation_platform = self.inner.local_attestation_platform.clone(); + let remote_attestation_platform = self.inner.remote_attestation_platform.clone(); + let cert_chain = self.cert_chain.clone(); tokio::spawn(async move { let out = TcpStream::connect(target).await.unwrap(); @@ -174,7 +300,7 @@ impl ProxyClient { ) .unwrap(); - let cert_chain = server_connection.peer_certificates().unwrap().to_owned(); + let remote_cert_chain = server_connection.peer_certificates().unwrap().to_owned(); let mut length_bytes = [0; 4]; tls_stream.read_exact(&mut length_bytes).await.unwrap(); @@ -183,10 +309,29 @@ impl ProxyClient { let mut buf = vec![0; length]; tls_stream.read_exact(&mut buf).await.unwrap(); - if !attestation_platform.verify_attestation(buf, &cert_chain, exporter) { - panic!("Cannot verify attestation"); + if remote_attestation_platform.is_cvm() { + remote_attestation_platform + .verify_attestation(buf, &remote_cert_chain, exporter) + .unwrap(); + } + + let attestation = if local_attestation_platform.is_cvm() { + local_attestation_platform + .create_attestation(&cert_chain.unwrap(), exporter) + .unwrap() + } else { + Vec::new() }; + let attestation_length_prefix = length_prefix(&attestation); + + tls_stream + .write_all(&attestation_length_prefix) + .await + .unwrap(); + + tls_stream.write_all(&attestation).await.unwrap(); + let (mut inbound_reader, mut inbound_writer) = inbound.into_split(); let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream); @@ -197,45 +342,9 @@ impl ProxyClient { Ok(()) } -} - -pub trait AttestationPlatform { - fn create_attestation(&self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec; - - fn verify_attestation( - &self, - input: Vec, - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], - ) -> bool; -} - -#[derive(Clone)] -struct MockAttestation; - -impl AttestationPlatform for MockAttestation { - /// Mocks creating an attestation - fn create_attestation(&self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec { - let mut quote_input = [0u8; 64]; - let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap(); - quote_input[..32].copy_from_slice(&pki_hash); - quote_input[32..].copy_from_slice(&exporter); - quote_input.to_vec() - } - /// Mocks verifying an attestation - fn verify_attestation( - &self, - input: Vec, - cert_chain: &[CertificateDer<'_>], - exporter: [u8; 32], - ) -> bool { - let mut quote_input = [0u8; 64]; - let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap(); - quote_input[..32].copy_from_slice(&pki_hash); - quote_input[32..].copy_from_slice(&exporter); - - input == quote_input + pub fn local_addr(&self) -> std::io::Result { + self.inner.listener.local_addr() } } @@ -244,133 +353,111 @@ fn length_prefix(input: &[u8]) -> [u8; 4] { len.to_be_bytes() } -/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate -fn get_pki_hash_from_certificate_chain( - cert_chain: &[CertificateDer<'_>], -) -> Result<[u8; 32], AttestationError> { - let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; - let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; - let public_key = &cert.tbs_certificate.subject_pki; - let key_bytes = public_key.subject_public_key.as_ref(); - - let mut hasher = Sha256::new(); - hasher.update(key_bytes); - Ok(hasher.finalize().into()) -} - -/// An error when generating an attestation -#[derive(Error, Debug)] -pub enum AttestationError { - #[error("Certificate chain is empty")] - NoCertificate, - #[error("X509 parse: {0}")] - X509Parse(#[from] x509_parser::asn1_rs::Err), - #[error("X509: {0}")] - X509(#[from] x509_parser::error::X509Error), -} - #[cfg(test)] mod tests { use super::*; - use tokio::net::TcpListener; - - use rcgen::generate_simple_self_signed; - use std::{net::SocketAddr, sync::Arc}; - use tokio_rustls::rustls::{ - pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, - ClientConfig, RootCertStore, ServerConfig, + use test_helpers::{ + example_http_service, example_service, generate_certificate_chain, generate_tls_config, + generate_tls_config_with_client_auth, }; - /// Helper to generate a self-signed certificate for testing - pub fn generate_certificate_chain( - name: String, - ) -> (Vec>, PrivateKeyDer<'static>) { - let subject_alt_names = vec![name]; - let cert_key = generate_simple_self_signed(subject_alt_names) - .expect("Failed to generate self-signed certificate"); - - let certs = vec![CertificateDer::from(cert_key.cert)]; - let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( - cert_key.signing_key.serialize_der(), - )); - (certs, key) - } - - /// Helper to generate TLS configuration for testing - /// - /// For the server: A given self-signed certificate - /// For the client: A root certificate store with the server's certificate - pub fn generate_tls_config( - certificate_chain: Vec>, - key: PrivateKeyDer<'static>, - ) -> (Arc, Arc) { - let server_config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(certificate_chain.clone(), key) - .expect("Failed to create rustls server config"); - - let mut root_store = RootCertStore::empty(); - root_store.add(certificate_chain[0].clone()).unwrap(); - - let client_config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - - (Arc::new(server_config), Arc::new(client_config)) - } + #[tokio::test] + async fn http_proxy() { + let target_addr = example_http_service().await; + let target_name = "name".to_string(); - async fn example_http_service() -> SocketAddr { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let app = axum::Router::new().route("/", axum::routing::get(|| async { "foobar" })); + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr, + MockAttestation, + NoAttestation, + ) + .await; + let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { - axum::serve(listener, app).await.unwrap(); + proxy_server.accept().await.unwrap(); }); - addr - } + let proxy_client = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0", + proxy_addr, + target_name.try_into().unwrap(), + NoAttestation, + MockAttestation, + None, + ) + .await; - async fn example_service() -> SocketAddr { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { - loop { - let (mut inbound, _client_addr) = listener.accept().await.unwrap(); - inbound.write_all(b"some data").await.unwrap(); - } + proxy_client.accept().await.unwrap(); }); - addr + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap() + .text() + .await + .unwrap(); + + assert_eq!(res, "foobar"); } #[tokio::test] - async fn http_proxy() { + async fn http_proxy_mutual_attestation() { let target_addr = example_http_service().await; let target_name = "name".to_string(); - let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); - let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - - let proxy_server = - ProxyServer::new_with_tls_config(cert_chain, server_config, "127.0.0.1:0", target_addr) - .await; - let proxy_addr = proxy_server.listener.local_addr().unwrap(); + let (server_cert_chain, server_private_key) = + generate_certificate_chain(target_name.clone()); + let (client_cert_chain, client_private_key) = + generate_certificate_chain(target_name.clone()); + + let ( + (_client_tls_server_config, client_tls_client_config), + (server_tls_server_config, _server_tls_client_config), + ) = generate_tls_config_with_client_auth( + client_cert_chain.clone(), + client_private_key, + server_cert_chain.clone(), + server_private_key, + ); + + let proxy_server = ProxyServer::new_with_tls_config( + server_cert_chain, + server_tls_server_config, + "127.0.0.1:0", + target_addr, + MockAttestation, + MockAttestation, + ) + .await; + let proxy_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { proxy_server.accept().await.unwrap(); }); let proxy_client = ProxyClient::new_with_tls_config( - client_config, + client_tls_client_config, "127.0.0.1:0", proxy_addr, target_name.try_into().unwrap(), + MockAttestation, + MockAttestation, + Some(client_cert_chain), ) .await; - let proxy_client_addr = proxy_client.listener.local_addr().unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { proxy_client.accept().await.unwrap(); @@ -394,10 +481,18 @@ mod tests { let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); - let proxy_server = - ProxyServer::new_with_tls_config(cert_chain, server_config, "127.0.0.1:0", target_addr) - .await; - let proxy_server_addr = proxy_server.listener.local_addr().unwrap(); + let local_attestation_platform = MockAttestation; + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr, + local_attestation_platform, + NoAttestation, + ) + .await; + let proxy_server_addr = proxy_server.local_addr().unwrap(); tokio::spawn(async move { proxy_server.accept().await.unwrap(); @@ -408,9 +503,12 @@ mod tests { "127.0.0.1:0", proxy_server_addr, target_name.try_into().unwrap(), + NoAttestation, + MockAttestation, + None, ) .await; - let proxy_client_addr = proxy_client.listener.local_addr().unwrap(); + let proxy_client_addr = proxy_client.local_addr().unwrap(); tokio::spawn(async move { proxy_client.accept().await.unwrap(); diff --git a/src/main.rs b/src/main.rs index 9f7dcd0..29211eb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,8 @@ use clap::{Parser, Subcommand}; -use std::{fs::File, net::SocketAddr}; +use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use attested_tls_proxy::{ProxyClient, ProxyServer}; +use attested_tls_proxy::{MockAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey}; #[derive(Parser, Debug, Clone)] #[clap(version, about, long_about = None)] @@ -22,11 +22,26 @@ enum CliCommand { server_address: SocketAddr, #[arg(long)] server_name: String, + /// The path to a PEM encoded private key for client authentication + #[arg(long)] + private_key: Option, + /// The path to a PEM encoded certificate chain for client authentication + #[arg(long)] + cert_chain: Option, }, /// Run a proxy server Server { + /// Socket address of the target service to forward traffic to #[arg(short, long)] - client_address: SocketAddr, + target_address: SocketAddr, + /// The path to a PEM encoded private key + #[arg(long)] + private_key: PathBuf, + /// The path to a PEM encoded certificate chain + #[arg(long)] + cert_chain: PathBuf, + #[arg(long)] + client_auth: bool, }, } @@ -38,19 +53,45 @@ async fn main() { CliCommand::Client { server_name, server_address, + private_key, + cert_chain, } => { - let client = - ProxyClient::new(cli.address, server_address, server_name.try_into().unwrap()) - .await; + let tls_cert_and_chain = private_key + .map(|private_key| load_tls_cert_and_key(cert_chain.unwrap(), private_key)); + + let client = ProxyClient::new( + tls_cert_and_chain, + cli.address, + server_address, + server_name.try_into().unwrap(), + NoAttestation, + MockAttestation, + ) + .await; loop { client.accept().await.unwrap(); } } - CliCommand::Server { client_address } => { - let cert_chain = load_certs_pem("certs.pem").unwrap(); - let key = load_private_key_pem("key.pem"); - let server = ProxyServer::new(cert_chain, key, cli.address, client_address).await; + CliCommand::Server { + target_address, + private_key, + cert_chain, + client_auth, + } => { + let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key); + let local_attestation = MockAttestation; + let remote_attestation = NoAttestation; + + let server = ProxyServer::new( + tls_cert_and_chain, + cli.address, + target_address, + local_attestation, + remote_attestation, + client_auth, + ) + .await; loop { server.accept().await.unwrap(); @@ -59,7 +100,13 @@ async fn main() { } } -pub fn load_certs_pem(path: &str) -> std::io::Result>> { +fn load_tls_cert_and_key(cert_chain: PathBuf, private_key: PathBuf) -> TlsCertAndKey { + let key = load_private_key_pem(private_key); + let cert_chain = load_certs_pem(cert_chain).unwrap(); + TlsCertAndKey { key, cert_chain } +} + +pub fn load_certs_pem(path: PathBuf) -> std::io::Result>> { Ok( rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?)) .map(|res| res.unwrap()) @@ -67,7 +114,7 @@ pub fn load_certs_pem(path: &str) -> std::io::Result ) } -pub fn load_private_key_pem(path: &str) -> PrivateKeyDer<'static> { +pub fn load_private_key_pem(path: PathBuf) -> PrivateKeyDer<'static> { let mut reader = std::io::BufReader::new(File::open(path).unwrap()); // Tries to read the key as PKCS#8, PKCS#1, or SEC1 diff --git a/src/test_helpers.rs b/src/test_helpers.rs new file mode 100644 index 0000000..a45c6d9 --- /dev/null +++ b/src/test_helpers.rs @@ -0,0 +1,130 @@ +use rcgen::generate_simple_self_signed; +use std::{net::SocketAddr, sync::Arc}; +use tokio::io::AsyncWriteExt; +use tokio::net::TcpListener; +use tokio_rustls::rustls::{ + pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}, + server::{danger::ClientCertVerifier, WebPkiClientVerifier}, + ClientConfig, RootCertStore, ServerConfig, +}; + +/// Helper to generate a self-signed certificate for testing +pub fn generate_certificate_chain( + name: String, +) -> (Vec>, PrivateKeyDer<'static>) { + let subject_alt_names = vec![name]; + let cert_key = generate_simple_self_signed(subject_alt_names) + .expect("Failed to generate self-signed certificate"); + + let certs = vec![CertificateDer::from(cert_key.cert)]; + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from( + cert_key.signing_key.serialize_der(), + )); + (certs, key) +} + +/// Helper to generate TLS configuration for testing +/// +/// For the server: A given self-signed certificate +/// For the client: A root certificate store with the server's certificate +pub fn generate_tls_config( + certificate_chain: Vec>, + key: PrivateKeyDer<'static>, +) -> (Arc, Arc) { + let server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certificate_chain.clone(), key) + .expect("Failed to create rustls server config"); + + let mut root_store = RootCertStore::empty(); + root_store.add(certificate_chain[0].clone()).unwrap(); + + let client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + (Arc::new(server_config), Arc::new(client_config)) +} + +/// Helper to generate a mutual TLS configuration with client authentification for testing +pub fn generate_tls_config_with_client_auth( + alice_certificate_chain: Vec>, + alice_key: PrivateKeyDer<'static>, + bob_certificate_chain: Vec>, + bob_key: PrivateKeyDer<'static>, +) -> ( + (Arc, Arc), + (Arc, Arc), +) { + let (alice_client_verifier, alice_root_store) = + client_verifier_from_remote_cert(bob_certificate_chain[0].clone()); + + let alice_server_config = ServerConfig::builder() + .with_client_cert_verifier(alice_client_verifier) + .with_single_cert(alice_certificate_chain.clone(), alice_key.clone_key()) + .expect("Failed to create rustls server config"); + + let alice_client_config = ClientConfig::builder() + .with_root_certificates(alice_root_store) + .with_client_auth_cert(alice_certificate_chain.clone(), alice_key) + .unwrap(); + + let (bob_client_verifier, bob_root_store) = + client_verifier_from_remote_cert(alice_certificate_chain[0].clone()); + + let bob_server_config = ServerConfig::builder() + .with_client_cert_verifier(bob_client_verifier) + .with_single_cert(bob_certificate_chain.clone(), bob_key.clone_key()) + .expect("Failed to create rustls server config"); + + let bob_client_config = ClientConfig::builder() + .with_root_certificates(bob_root_store) + .with_client_auth_cert(bob_certificate_chain, bob_key) + .unwrap(); + + ( + (Arc::new(alice_server_config), Arc::new(alice_client_config)), + (Arc::new(bob_server_config), Arc::new(bob_client_config)), + ) +} + +fn client_verifier_from_remote_cert( + cert: CertificateDer<'static>, +) -> (Arc, RootCertStore) { + let mut root_store = RootCertStore::empty(); + root_store.add(cert).unwrap(); + + ( + WebPkiClientVerifier::builder(Arc::new(root_store.clone())) + .build() + .unwrap(), + root_store, + ) +} + +pub async fn example_http_service() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = axum::Router::new().route("/", axum::routing::get(|| async { "foobar" })); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + addr +} + +pub async fn example_service() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (mut inbound, _client_addr) = listener.accept().await.unwrap(); + inbound.write_all(b"some data").await.unwrap(); + } + }); + + addr +}