Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 50 additions & 43 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,8 @@ where
{
inner: Proxy<L, R>,
connector: TlsConnector,
/// The address of the proxy server
target: SocketAddr,
/// The subject name of the proxy server
target_name: ServerName<'static>,
/// The host and port of the proxy server
target: String,
/// Certificate chain for client auth
cert_chain: Option<Vec<CertificateDer<'static>>>,
}
Expand All @@ -230,8 +228,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
pub async fn new(
cert_and_key: Option<TlsCertAndKey>,
address: impl ToSocketAddrs,
server_address: SocketAddr,
server_name: ServerName<'static>,
server_name: String,
local_attestation_platform: L,
remote_attestation_platform: R,
) -> Result<Self, ProxyError> {
Expand All @@ -257,7 +254,6 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
Self::new_with_tls_config(
client_config.into(),
address,
server_address,
server_name,
local_attestation_platform,
remote_attestation_platform,
Expand All @@ -272,8 +268,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
async fn new_with_tls_config(
client_config: Arc<ClientConfig>,
local: impl ToSocketAddrs,
target: SocketAddr,
target_name: ServerName<'static>,
target_name: String,
local_attestation_platform: L,
remote_attestation_platform: R,
cert_chain: Option<Vec<CertificateDer<'static>>>,
Expand All @@ -290,8 +285,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
Ok(Self {
inner,
connector,
target,
target_name,
target: host_to_host_with_port(&target_name),
cert_chain,
})
}
Expand All @@ -301,8 +295,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
let (inbound, _client_addr) = self.inner.listener.accept().await?;

let connector = self.connector.clone();
let target_name = self.target_name.clone();
let target = self.target;
let target = self.target.clone();
let local_attestation_platform = self.inner.local_attestation_platform.clone();
let remote_attestation_platform = self.inner.remote_attestation_platform.clone();
let cert_chain = self.cert_chain.clone();
Expand All @@ -312,7 +305,6 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
inbound,
connector,
target,
target_name,
cert_chain,
local_attestation_platform,
remote_attestation_platform,
Expand All @@ -335,14 +327,15 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
async fn handle_connection(
inbound: TcpStream,
connector: TlsConnector,
target: SocketAddr,
target_name: ServerName<'static>,
target: String,
cert_chain: Option<Vec<CertificateDer<'static>>>,
local_attestation_platform: L,
remote_attestation_platform: R,
) -> Result<(), ProxyError> {
let out = TcpStream::connect(target).await?;
let mut tls_stream = connector.connect(target_name, out).await?;
let out = TcpStream::connect(&target).await?;
let mut tls_stream = connector
.connect(server_name_from_host(&target)?, out)
.await?;

let (_io, server_connection) = tls_stream.get_ref();

Expand Down Expand Up @@ -394,16 +387,14 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {

/// Just get the attested remote certificate, with no client authentication
pub async fn get_tls_cert<R: AttestationPlatform>(
server_address: SocketAddr,
server_name: ServerName<'static>,
server_name: String,
remote_attestation_platform: R,
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let client_config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
get_tls_cert_with_config(
server_address,
server_name,
remote_attestation_platform,
client_config.into(),
Expand All @@ -412,15 +403,16 @@ pub async fn get_tls_cert<R: AttestationPlatform>(
}

async fn get_tls_cert_with_config<R: AttestationPlatform>(
server_address: SocketAddr,
server_name: ServerName<'static>,
server_name: String,
remote_attestation_platform: R,
client_config: Arc<ClientConfig>,
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
let connector = TlsConnector::from(client_config);

let out = TcpStream::connect(server_address).await?;
let mut tls_stream = connector.connect(server_name, out).await?;
let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?;
let mut tls_stream = connector
.connect(server_name_from_host(&server_name)?, out)
.await?;

let (_io, server_connection) = tls_stream.get_ref();

Expand Down Expand Up @@ -467,6 +459,8 @@ pub enum ProxyError {
Attestation(#[from] AttestationError),
#[error("Integer conversion: {0}")]
IntConversion(#[from] TryFromIntError),
#[error("Bad host name: {0}")]
BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError),
}

/// Given a byte array, encode its length as a 4 byte big endian u32
Expand All @@ -475,6 +469,27 @@ fn length_prefix(input: &[u8]) -> [u8; 4] {
len.to_be_bytes()
}

fn host_to_host_with_port(host: &str) -> String {
if host.contains(':') {
host.to_string()
} else {
format!("{host}:443")
}
}

fn server_name_from_host(
host: &str,
) -> Result<ServerName<'static>, tokio_rustls::rustls::pki_types::InvalidDnsNameError> {
// If host contains ':', try to split off the port.
let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host);

// If the host is an IPv6 literal in brackets like "[::1]:443",
// remove the brackets for SNI (SNI allows bare IPv6 too).
let host_part = host_part.trim_matches(|c| c == '[' || c == ']');

ServerName::try_from(host_part.to_string())
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -486,9 +501,8 @@ mod tests {
#[tokio::test]
async fn http_proxy() {
let target_addr = example_http_service().await;
let target_name = "name".to_string();

let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);

let proxy_server = ProxyServer::new_with_tls_config(
Expand All @@ -510,9 +524,8 @@ mod tests {

let proxy_client = ProxyClient::new_with_tls_config(
client_config,
"127.0.0.1:0",
proxy_addr,
target_name.try_into().unwrap(),
"127.0.0.1:0".to_string(),
proxy_addr.to_string(),
NoAttestation,
MockAttestation,
None,
Expand All @@ -539,12 +552,11 @@ mod tests {
#[tokio::test]
async fn http_proxy_mutual_attestation() {
let target_addr = example_http_service().await;
let target_name = "name".to_string();

let (server_cert_chain, server_private_key) =
generate_certificate_chain(target_name.clone());
generate_certificate_chain("127.0.0.1".parse().unwrap());
let (client_cert_chain, client_private_key) =
generate_certificate_chain(target_name.clone());
generate_certificate_chain("127.0.0.1".parse().unwrap());

let (
(_client_tls_server_config, client_tls_client_config),
Expand Down Expand Up @@ -576,8 +588,7 @@ mod tests {
let proxy_client = ProxyClient::new_with_tls_config(
client_tls_client_config,
"127.0.0.1:0",
proxy_addr,
target_name.try_into().unwrap(),
proxy_addr.to_string(),
MockAttestation,
MockAttestation,
Some(client_cert_chain),
Expand All @@ -604,9 +615,8 @@ mod tests {
#[tokio::test]
async fn raw_tcp_proxy() {
let target_addr = example_service().await;
let target_name = "name".to_string();

let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);

let local_attestation_platform = MockAttestation;
Expand All @@ -631,8 +641,7 @@ mod tests {
let proxy_client = ProxyClient::new_with_tls_config(
client_config,
"127.0.0.1:0",
proxy_server_addr,
target_name.try_into().unwrap(),
proxy_server_addr.to_string(),
NoAttestation,
MockAttestation,
None,
Expand All @@ -657,9 +666,8 @@ mod tests {
#[tokio::test]
async fn test_get_tls_cert() {
let target_addr = example_service().await;
let target_name = "name".to_string();

let (cert_chain, private_key) = generate_certificate_chain(target_name.clone());
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);

let local_attestation_platform = MockAttestation;
Expand All @@ -682,8 +690,7 @@ mod tests {
});

let retrieved_chain = get_tls_cert_with_config(
proxy_server_addr,
target_name.try_into().unwrap(),
proxy_server_addr.to_string(),
MockAttestation,
client_config,
)
Expand Down
35 changes: 10 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,10 @@ enum CliCommand {
/// Run a proxy client
Client {
/// Socket address to listen on
#[arg(short, long)]
#[arg(short, long, default_value = "0.0.0.0:0")]
address: SocketAddr,
/// The socket address of the proxy server
#[arg(short, long)]
server_address: SocketAddr,
/// The domain name of the proxy server
#[arg(long)]
server_name: String,
/// The hostname:port or ip:port of the proxy server (port defaults to 443)
server: String,
/// The path to a PEM encoded private key for client authentication
#[arg(long)]
private_key: Option<PathBuf>,
Expand All @@ -37,10 +33,9 @@ enum CliCommand {
/// Run a proxy server
Server {
/// Socket address to listen on
#[arg(short, long)]
#[arg(short, long, default_value = "0.0.0.0:0")]
address: SocketAddr,
/// Socket address of the target service to forward traffic to
#[arg(short, long)]
target_address: SocketAddr,
/// The path to a PEM encoded private key
#[arg(long)]
Expand All @@ -55,12 +50,8 @@ enum CliCommand {
},
/// Retrieve the attested TLS certificate from a proxy server
GetTlsCert {
/// The socket address of the proxy server
#[arg(short, long)]
server_address: SocketAddr,
/// The domain name of the proxy server
#[arg(long)]
server_name: String,
/// The hostname:port or ip:port of the proxy server (port defaults to 443)
server: String,
},
}

Expand All @@ -71,8 +62,7 @@ async fn main() -> anyhow::Result<()> {
match cli.command {
CliCommand::Client {
address,
server_name,
server_address,
server,
private_key,
cert_chain,
} => {
Expand All @@ -92,8 +82,7 @@ async fn main() -> anyhow::Result<()> {
let client = ProxyClient::new(
tls_cert_and_chain,
address,
server_address,
server_name.try_into()?,
server,
NoAttestation,
MockAttestation,
)
Expand Down Expand Up @@ -132,12 +121,8 @@ async fn main() -> anyhow::Result<()> {
}
}
}
CliCommand::GetTlsCert {
server_address,
server_name,
} => {
let cert_chain =
get_tls_cert(server_address, server_name.try_into()?, MockAttestation).await?;
CliCommand::GetTlsCert { server } => {
let cert_chain = get_tls_cert(server, MockAttestation).await?;
println!("{}", certs_to_pem_string(&cert_chain)?);
}
}
Expand Down
27 changes: 16 additions & 11 deletions src/test_helpers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use rcgen::generate_simple_self_signed;
use std::{net::SocketAddr, sync::Arc};
use std::{
net::{IpAddr, SocketAddr},
sync::Arc,
};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio_rustls::rustls::{
Expand All @@ -10,16 +12,19 @@ use tokio_rustls::rustls::{

/// Helper to generate a self-signed certificate for testing
pub fn generate_certificate_chain(
name: String,
ip: IpAddr,
) -> (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>) {
let subject_alt_names = vec![name];
let cert_key = generate_simple_self_signed(subject_alt_names)
.expect("Failed to generate self-signed certificate");

let certs = vec![CertificateDer::from(cert_key.cert)];
let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(
cert_key.signing_key.serialize_der(),
));
let mut params = rcgen::CertificateParams::new(vec![]).unwrap();
params.subject_alt_names.push(rcgen::SanType::IpAddress(ip));
params
.distinguished_name
.push(rcgen::DnType::CommonName, ip.to_string());

let keypair = rcgen::KeyPair::generate().unwrap();
let cert = params.self_signed(&keypair).unwrap();

let certs = vec![CertificateDer::from(cert)];
let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(keypair.serialize_der()));
(certs, key)
}

Expand Down