diff --git a/Cargo.lock b/Cargo.lock index 9a98a4c..d1da886 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,6 +119,7 @@ dependencies = [ "anyhow", "axum", "clap", + "pem-rfc7468", "rcgen", "reqwest", "rustls-pemfile", @@ -217,6 +218,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "bindgen" version = "0.72.1" @@ -987,6 +994,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" diff --git a/Cargo.toml b/Cargo.toml index 3c62e41..a594bef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ clap = { version = "4.5.51", features = ["derive"] } webpki-roots = "1.0.4" rustls-pemfile = "2.2.0" anyhow = "1.0.100" +pem-rfc7468 = { version = "0.7.0", features = ["std"] } [dev-dependencies] rcgen = "0.14.5" diff --git a/README.md b/README.md index 28fd702..4aff9d1 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,10 @@ This is a work-in-progress crate designed to be an alternative to [`cvm-reverse-proxy`](https://github.com/flashbots/cvm-reverse-proxy). -It offers two components: -- a proxy server, which accepts TLS connections from a proxy client, sends an attestation and then forwards traffic to a target CVM service. -- a proxy client, which accepts connections from elsewhere, connects to and verifies the attestation from the proxy server, and then forwards traffic to it over TLS. +It has three commands: +- `server` - run a proxy server, which accepts TLS connections from a proxy client, sends an attestation and then forwards traffic to a target CVM service. +- `client` - run a proxy client, which accepts connections from elsewhere, connects to and verifies the attestation from the proxy server, and then forwards traffic to it over TLS. +- `get-tls-cert` - connects to a proxy-server, verify the attestation, and if successful write the server's PEM-encoded TLS certificate chain to standard out. This can be used to make subsequent connections to services using this certificate over regular TLS. Unlike `cvm-reverse-proxy`, this uses post-handshake remote-attested TLS, meaning regular CA-signed TLS certificates can be used. diff --git a/src/lib.rs b/src/lib.rs index 9e3a3e2..78659da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -392,6 +392,64 @@ impl ProxyClient { } } +/// Just get the attested remote certificate, with no client authentication +pub async fn get_tls_cert( + server_address: SocketAddr, + server_name: ServerName<'static>, + remote_attestation_platform: R, +) -> Result>, ProxyError> { + let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + get_tls_cert_with_config( + server_address, + server_name, + remote_attestation_platform, + client_config.into(), + ) + .await +} + +async fn get_tls_cert_with_config( + server_address: SocketAddr, + server_name: ServerName<'static>, + remote_attestation_platform: R, + client_config: Arc, +) -> Result>, ProxyError> { + let connector = TlsConnector::from(client_config); + + let out = TcpStream::connect(server_address).await?; + let mut tls_stream = connector.connect(server_name, out).await?; + + let (_io, server_connection) = tls_stream.get_ref(); + + let mut exporter = [0u8; 32]; + server_connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(ProxyError::NoCertificate)? + .to_owned(); + + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + if remote_attestation_platform.is_cvm() { + remote_attestation_platform.verify_attestation(buf, &remote_cert_chain, exporter)?; + } + + Ok(remote_cert_chain) +} + /// An error when running a proxy client or server #[derive(Error, Debug)] pub enum ProxyError { @@ -595,4 +653,43 @@ mod tests { assert_eq!(buf[..], b"some data"[..]); } + + #[tokio::test] + async fn test_get_tls_cert() { + let target_addr = example_service().await; + let target_name = "name".to_string(); + + let (cert_chain, private_key) = generate_certificate_chain(target_name.clone()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let local_attestation_platform = MockAttestation; + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain.clone(), + server_config, + "127.0.0.1:0", + target_addr, + local_attestation_platform, + NoAttestation, + ) + .await + .unwrap(); + + let proxy_server_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let retrieved_chain = get_tls_cert_with_config( + proxy_server_addr, + target_name.try_into().unwrap(), + MockAttestation, + client_config, + ) + .await + .unwrap(); + + assert_eq!(retrieved_chain, cert_chain); + } } diff --git a/src/main.rs b/src/main.rs index be3cd75..2148e77 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,22 +3,24 @@ use clap::{Parser, Subcommand}; use std::{fs::File, net::SocketAddr, path::PathBuf}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer}; -use attested_tls_proxy::{MockAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey}; +use attested_tls_proxy::{ + get_tls_cert, MockAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey, +}; #[derive(Parser, Debug, Clone)] #[clap(version, about, long_about = None)] struct Cli { #[clap(subcommand)] command: CliCommand, - /// Socket address to listen on - #[arg(short, long)] - address: SocketAddr, } #[derive(Subcommand, Debug, Clone)] enum CliCommand { /// Run a proxy client Client { + /// Socket address to listen on + #[arg(short, long)] + address: SocketAddr, /// The socket address of the proxy server #[arg(short, long)] server_address: SocketAddr, @@ -34,6 +36,9 @@ enum CliCommand { }, /// Run a proxy server Server { + /// Socket address to listen on + #[arg(short, long)] + address: SocketAddr, /// Socket address of the target service to forward traffic to #[arg(short, long)] target_address: SocketAddr, @@ -48,6 +53,15 @@ enum CliCommand { #[arg(long)] client_auth: bool, }, + /// Retrieve the attested TLS certificate from a proxy server + GetTlsCert { + /// The socket address of the proxy server + #[arg(short, long)] + server_address: SocketAddr, + /// The domain name of the proxy server + #[arg(long)] + server_name: String, + }, } #[tokio::main] @@ -56,6 +70,7 @@ async fn main() -> anyhow::Result<()> { match cli.command { CliCommand::Client { + address, server_name, server_address, private_key, @@ -76,7 +91,7 @@ async fn main() -> anyhow::Result<()> { let client = ProxyClient::new( tls_cert_and_chain, - cli.address, + address, server_address, server_name.try_into()?, NoAttestation, @@ -91,6 +106,7 @@ async fn main() -> anyhow::Result<()> { } } CliCommand::Server { + address, target_address, private_key, cert_chain, @@ -102,7 +118,7 @@ async fn main() -> anyhow::Result<()> { let server = ProxyServer::new( tls_cert_and_chain, - cli.address, + address, target_address, local_attestation, remote_attestation, @@ -116,7 +132,17 @@ async fn main() -> anyhow::Result<()> { } } } + CliCommand::GetTlsCert { + server_address, + server_name, + } => { + let cert_chain = + get_tls_cert(server_address, server_name.try_into()?, MockAttestation).await?; + println!("{}", certs_to_pem_string(&cert_chain)?); + } } + + Ok(()) } /// Load TLS details from storage @@ -129,12 +155,12 @@ fn load_tls_cert_and_key( Ok(TlsCertAndKey { key, cert_chain }) } -pub fn load_certs_pem(path: PathBuf) -> std::io::Result>> { +fn load_certs_pem(path: PathBuf) -> std::io::Result>> { rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?)) .collect::, _>>() } -pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result> { +fn load_private_key_pem(path: PathBuf) -> anyhow::Result> { let mut reader = std::io::BufReader::new(File::open(path)?); // Tries to read the key as PKCS#8, PKCS#1, or SEC1 @@ -144,3 +170,15 @@ pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result]) -> Result { + let mut out = String::new(); + for cert in certs { + let block = + pem_rfc7468::encode_string("CERTIFICATE", pem_rfc7468::LineEnding::LF, cert.as_ref())?; + out.push_str(&block); + out.push('\n'); + } + Ok(out) +}