Skip to content

Commit 32ee1f4

Browse files
committed
Handle client auth with CLI
1 parent 097851e commit 32ee1f4

File tree

2 files changed

+87
-24
lines changed

2 files changed

+87
-24
lines changed

src/lib.rs

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod attestation;
22

33
pub use attestation::{AttestationPlatform, MockAttestation, NoAttestation};
4+
use tokio_rustls::rustls::server::WebPkiClientVerifier;
45

56
#[cfg(test)]
67
mod test_helpers;
@@ -18,6 +19,11 @@ use tokio_rustls::{
1819
/// The label used when exporting key material from a TLS session
1920
const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding";
2021

22+
pub struct TlsCertAndKey {
23+
pub cert_chain: Vec<CertificateDer<'static>>,
24+
pub key: PrivateKeyDer<'static>,
25+
}
26+
2127
struct Proxy<L, R>
2228
where
2329
L: AttestationPlatform,
@@ -48,20 +54,36 @@ where
4854

4955
impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
5056
pub async fn new(
51-
cert_chain: Vec<CertificateDer<'static>>,
52-
key: PrivateKeyDer<'static>,
57+
cert_and_key: TlsCertAndKey,
5358
local: impl ToSocketAddrs,
5459
target: SocketAddr,
5560
local_attestation_platform: L,
5661
remote_attestation_platform: R,
62+
client_auth: bool,
5763
) -> Self {
58-
let server_config = ServerConfig::builder()
59-
.with_no_client_auth()
60-
.with_single_cert(cert_chain.clone(), key)
61-
.expect("Failed to create rustls server config");
64+
if remote_attestation_platform.is_cvm() && !client_auth {
65+
panic!("Client auth is required when the client is running in a CVM");
66+
}
67+
68+
let server_config = if client_auth {
69+
let root_store =
70+
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
71+
let verifier = WebPkiClientVerifier::builder(Arc::new(root_store))
72+
.build()
73+
.expect("invalid client verifier");
74+
ServerConfig::builder()
75+
.with_client_cert_verifier(verifier)
76+
.with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)
77+
.expect("Failed to create rustls server config")
78+
} else {
79+
ServerConfig::builder()
80+
.with_no_client_auth()
81+
.with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)
82+
.expect("Failed to create rustls server config")
83+
};
6284

6385
Self::new_with_tls_config(
64-
cert_chain,
86+
cert_and_key.cert_chain,
6587
server_config.into(),
6688
local,
6789
target,
@@ -187,17 +209,28 @@ where
187209

188210
impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
189211
pub async fn new(
212+
cert_and_key: Option<TlsCertAndKey>,
190213
address: impl ToSocketAddrs,
191214
server_address: SocketAddr,
192215
server_name: ServerName<'static>,
193216
local_attestation_platform: L,
194217
remote_attestation_platform: R,
195-
cert_chain: Option<Vec<CertificateDer<'static>>>,
196218
) -> Self {
197219
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
198-
let client_config = ClientConfig::builder()
199-
.with_root_certificates(root_store)
200-
.with_no_client_auth();
220+
221+
let client_config = if let Some(ref cert_and_key) = cert_and_key {
222+
ClientConfig::builder()
223+
.with_root_certificates(root_store)
224+
.with_client_auth_cert(
225+
cert_and_key.cert_chain.clone(),
226+
cert_and_key.key.clone_key(),
227+
)
228+
.unwrap()
229+
} else {
230+
ClientConfig::builder()
231+
.with_root_certificates(root_store)
232+
.with_no_client_auth()
233+
};
201234

202235
Self::new_with_tls_config(
203236
client_config.into(),
@@ -206,7 +239,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
206239
server_name,
207240
local_attestation_platform,
208241
remote_attestation_platform,
209-
cert_chain,
242+
cert_and_key.map(|c| c.cert_chain),
210243
)
211244
.await
212245
}

src/main.rs

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use clap::{Parser, Subcommand};
2-
use std::{fs::File, net::SocketAddr};
2+
use std::{fs::File, net::SocketAddr, path::PathBuf};
33
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
44

5-
use attested_tls_proxy::{MockAttestation, NoAttestation, ProxyClient, ProxyServer};
5+
use attested_tls_proxy::{MockAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey};
66

77
#[derive(Parser, Debug, Clone)]
88
#[clap(version, about, long_about = None)]
@@ -22,11 +22,26 @@ enum CliCommand {
2222
server_address: SocketAddr,
2323
#[arg(long)]
2424
server_name: String,
25+
/// The path to a PEM encoded private key for client authentication
26+
#[arg(long)]
27+
private_key: Option<PathBuf>,
28+
/// The path to a PEM encoded certificate chain for client authentication
29+
#[arg(long)]
30+
cert_chain: Option<PathBuf>,
2531
},
2632
/// Run a proxy server
2733
Server {
34+
/// Socket address of the target service to forward traffic to
2835
#[arg(short, long)]
29-
client_address: SocketAddr,
36+
target_address: SocketAddr,
37+
/// The path to a PEM encoded private key
38+
#[arg(long)]
39+
private_key: PathBuf,
40+
/// The path to a PEM encoded certificate chain
41+
#[arg(long)]
42+
cert_chain: PathBuf,
43+
#[arg(long)]
44+
client_auth: bool,
3045
},
3146
}
3247

@@ -38,34 +53,43 @@ async fn main() {
3853
CliCommand::Client {
3954
server_name,
4055
server_address,
56+
private_key,
57+
cert_chain,
4158
} => {
59+
let tls_cert_and_chain = private_key
60+
.map(|private_key| load_tls_cert_and_key(cert_chain.unwrap(), private_key));
61+
4262
let client = ProxyClient::new(
63+
tls_cert_and_chain,
4364
cli.address,
4465
server_address,
4566
server_name.try_into().unwrap(),
4667
NoAttestation,
4768
MockAttestation,
48-
None,
4969
)
5070
.await;
5171

5272
loop {
5373
client.accept().await.unwrap();
5474
}
5575
}
56-
CliCommand::Server { client_address } => {
57-
let cert_chain = load_certs_pem("certs.pem").unwrap();
58-
let key = load_private_key_pem("key.pem");
76+
CliCommand::Server {
77+
target_address,
78+
private_key,
79+
cert_chain,
80+
client_auth,
81+
} => {
82+
let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key);
5983
let local_attestation = MockAttestation;
6084
let remote_attestation = NoAttestation;
6185

6286
let server = ProxyServer::new(
63-
cert_chain,
64-
key,
87+
tls_cert_and_chain,
6588
cli.address,
66-
client_address,
89+
target_address,
6790
local_attestation,
6891
remote_attestation,
92+
client_auth,
6993
)
7094
.await;
7195

@@ -76,15 +100,21 @@ async fn main() {
76100
}
77101
}
78102

79-
pub fn load_certs_pem(path: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
103+
fn load_tls_cert_and_key(cert_chain: PathBuf, private_key: PathBuf) -> TlsCertAndKey {
104+
let key = load_private_key_pem(private_key);
105+
let cert_chain = load_certs_pem(cert_chain).unwrap();
106+
TlsCertAndKey { key, cert_chain }
107+
}
108+
109+
pub fn load_certs_pem(path: PathBuf) -> std::io::Result<Vec<CertificateDer<'static>>> {
80110
Ok(
81111
rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?))
82112
.map(|res| res.unwrap())
83113
.collect(),
84114
)
85115
}
86116

87-
pub fn load_private_key_pem(path: &str) -> PrivateKeyDer<'static> {
117+
pub fn load_private_key_pem(path: PathBuf) -> PrivateKeyDer<'static> {
88118
let mut reader = std::io::BufReader::new(File::open(path).unwrap());
89119

90120
// Tries to read the key as PKCS#8, PKCS#1, or SEC1

0 commit comments

Comments
 (0)