Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ clap = { version = "4.5.51", features = ["derive"] }
webpki-roots = "1.0.4"
rustls-pemfile = "2.2.0"
anyhow = "1.0.100"
pem-rfc7468 = { version = "0.7.0", features = ["std"] }

[dev-dependencies]
rcgen = "0.14.5"
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

This is a work-in-progress crate designed to be an alternative to [`cvm-reverse-proxy`](https://github.com/flashbots/cvm-reverse-proxy).

It offers two components:
- a proxy server, which accepts TLS connections from a proxy client, sends an attestation and then forwards traffic to a target CVM service.
- a proxy client, which accepts connections from elsewhere, connects to and verifies the attestation from the proxy server, and then forwards traffic to it over TLS.
It has three commands:
- `server` - run a proxy server, which accepts TLS connections from a proxy client, sends an attestation and then forwards traffic to a target CVM service.
- `client` - run a proxy client, which accepts connections from elsewhere, connects to and verifies the attestation from the proxy server, and then forwards traffic to it over TLS.
- `get-tls-cert` - connects to a proxy-server, verify the attestation, and if successful write the server's PEM-encoded TLS certificate chain to standard out. This can be used to make subsequent connections to services using this certificate over regular TLS.

Unlike `cvm-reverse-proxy`, this uses post-handshake remote-attested TLS, meaning regular CA-signed TLS certificates can be used.

Expand Down
97 changes: 97 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,64 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
}
}

/// Just get the attested remote certificate, with no client authentication
pub async fn get_tls_cert<R: AttestationPlatform>(
server_address: SocketAddr,
server_name: ServerName<'static>,
remote_attestation_platform: R,
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let client_config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
get_tls_cert_with_config(
server_address,
server_name,
remote_attestation_platform,
client_config.into(),
)
.await
}

async fn get_tls_cert_with_config<R: AttestationPlatform>(
server_address: SocketAddr,
server_name: ServerName<'static>,
remote_attestation_platform: R,
client_config: Arc<ClientConfig>,
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
let connector = TlsConnector::from(client_config);

let out = TcpStream::connect(server_address).await?;
let mut tls_stream = connector.connect(server_name, out).await?;

let (_io, server_connection) = tls_stream.get_ref();

let mut exporter = [0u8; 32];
server_connection.export_keying_material(
&mut exporter,
EXPORTER_LABEL,
None, // context
)?;

let remote_cert_chain = server_connection
.peer_certificates()
.ok_or(ProxyError::NoCertificate)?
.to_owned();

let mut length_bytes = [0; 4];
tls_stream.read_exact(&mut length_bytes).await?;
let length: usize = u32::from_be_bytes(length_bytes).try_into()?;

let mut buf = vec![0; length];
tls_stream.read_exact(&mut buf).await?;

if remote_attestation_platform.is_cvm() {
remote_attestation_platform.verify_attestation(buf, &remote_cert_chain, exporter)?;
}

Ok(remote_cert_chain)
}

/// An error when running a proxy client or server
#[derive(Error, Debug)]
pub enum ProxyError {
Expand Down Expand Up @@ -595,4 +653,43 @@ mod tests {

assert_eq!(buf[..], b"some data"[..]);
}

#[tokio::test]
async fn test_get_tls_cert() {
let target_addr = example_service().await;
let target_name = "name".to_string();

let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);

let local_attestation_platform = MockAttestation;

let proxy_server = ProxyServer::new_with_tls_config(
cert_chain.clone(),
server_config,
"127.0.0.1:0",
target_addr,
local_attestation_platform,
NoAttestation,
)
.await
.unwrap();

let proxy_server_addr = proxy_server.local_addr().unwrap();

tokio::spawn(async move {
proxy_server.accept().await.unwrap();
});

let retrieved_chain = get_tls_cert_with_config(
proxy_server_addr,
target_name.try_into().unwrap(),
MockAttestation,
client_config,
)
.await
.unwrap();

assert_eq!(retrieved_chain, cert_chain);
}
}
54 changes: 46 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@ use clap::{Parser, Subcommand};
use std::{fs::File, net::SocketAddr, path::PathBuf};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};

use attested_tls_proxy::{MockAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey};
use attested_tls_proxy::{
get_tls_cert, MockAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey,
};

#[derive(Parser, Debug, Clone)]
#[clap(version, about, long_about = None)]
struct Cli {
#[clap(subcommand)]
command: CliCommand,
/// Socket address to listen on
#[arg(short, long)]
address: SocketAddr,
}

#[derive(Subcommand, Debug, Clone)]
enum CliCommand {
/// Run a proxy client
Client {
/// Socket address to listen on
#[arg(short, long)]
address: SocketAddr,
/// The socket address of the proxy server
#[arg(short, long)]
server_address: SocketAddr,
Expand All @@ -34,6 +36,9 @@ enum CliCommand {
},
/// Run a proxy server
Server {
/// Socket address to listen on
#[arg(short, long)]
address: SocketAddr,
/// Socket address of the target service to forward traffic to
#[arg(short, long)]
target_address: SocketAddr,
Expand All @@ -48,6 +53,15 @@ enum CliCommand {
#[arg(long)]
client_auth: bool,
},
/// Retrieve the attested TLS certificate from a proxy server
GetTlsCert {
/// The socket address of the proxy server
#[arg(short, long)]
server_address: SocketAddr,
/// The domain name of the proxy server
#[arg(long)]
server_name: String,
},
}

#[tokio::main]
Expand All @@ -56,6 +70,7 @@ async fn main() -> anyhow::Result<()> {

match cli.command {
CliCommand::Client {
address,
server_name,
server_address,
private_key,
Expand All @@ -76,7 +91,7 @@ async fn main() -> anyhow::Result<()> {

let client = ProxyClient::new(
tls_cert_and_chain,
cli.address,
address,
server_address,
server_name.try_into()?,
NoAttestation,
Expand All @@ -91,6 +106,7 @@ async fn main() -> anyhow::Result<()> {
}
}
CliCommand::Server {
address,
target_address,
private_key,
cert_chain,
Expand All @@ -102,7 +118,7 @@ async fn main() -> anyhow::Result<()> {

let server = ProxyServer::new(
tls_cert_and_chain,
cli.address,
address,
target_address,
local_attestation,
remote_attestation,
Expand All @@ -116,7 +132,17 @@ async fn main() -> anyhow::Result<()> {
}
}
}
CliCommand::GetTlsCert {
server_address,
server_name,
} => {
let cert_chain =
get_tls_cert(server_address, server_name.try_into()?, MockAttestation).await?;
println!("{}", certs_to_pem_string(&cert_chain)?);
}
}

Ok(())
}

/// Load TLS details from storage
Expand All @@ -129,12 +155,12 @@ fn load_tls_cert_and_key(
Ok(TlsCertAndKey { key, cert_chain })
}

pub fn load_certs_pem(path: PathBuf) -> std::io::Result<Vec<CertificateDer<'static>>> {
fn load_certs_pem(path: PathBuf) -> std::io::Result<Vec<CertificateDer<'static>>> {
rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?))
.collect::<Result<Vec<_>, _>>()
}

pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result<PrivateKeyDer<'static>> {
fn load_private_key_pem(path: PathBuf) -> anyhow::Result<PrivateKeyDer<'static>> {
let mut reader = std::io::BufReader::new(File::open(path)?);

// Tries to read the key as PKCS#8, PKCS#1, or SEC1
Expand All @@ -144,3 +170,15 @@ pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result<PrivateKeyDer<'stat

Ok(PrivateKeyDer::Pkcs8(pks8_key))
}

/// Given a certificate chain, convert it to a PEM encoded string
fn certs_to_pem_string(certs: &[CertificateDer<'_>]) -> Result<String, pem_rfc7468::Error> {
let mut out = String::new();
for cert in certs {
let block =
pem_rfc7468::encode_string("CERTIFICATE", pem_rfc7468::LineEnding::LF, cert.as_ref())?;
out.push_str(&block);
out.push('\n');
}
Ok(out)
}