Skip to content

Commit 4a2c5bb

Browse files
committed
Attestation trait
1 parent 800da58 commit 4a2c5bb

File tree

3 files changed

+94
-58
lines changed

3 files changed

+94
-58
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
name = "attested-tls-proxy"
33
version = "0.1.0"
44
edition = "2024"
5+
license = "MIT OR Apache-2.0"
56

67
[dependencies]
7-
axum = "0.8.6"
88
tokio = { version = "1.48.0", features = ["full"]}
99
tokio-rustls = "0.26.4"
1010
sha2 = "0.10.9"
@@ -16,4 +16,5 @@ rustls-pemfile = "2.2.0"
1616

1717
[dev-dependencies]
1818
rcgen = "0.14.5"
19+
axum = "0.8.6"
1920
reqwest = { version = "0.12.23", default-features = false, features = ["rustls-tls-webpki-roots-no-provider"] }

src/lib.rs

Lines changed: 87 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use std::{net::SocketAddr, sync::Arc};
33
use thiserror::Error;
44
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
55
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
6-
use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName};
6+
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
7+
use tokio_rustls::rustls::RootCertStore;
78
use tokio_rustls::{
89
rustls::{ClientConfig, ServerConfig},
910
TlsAcceptor, TlsConnector,
@@ -13,6 +14,7 @@ use x509_parser::prelude::*;
1314
/// The label used when exporting key material from a TLS session
1415
const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding";
1516

17+
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
1618
pub struct ProxyServer {
1719
/// The certificate chain
1820
cert_chain: Vec<CertificateDer<'static>>,
@@ -22,10 +24,28 @@ pub struct ProxyServer {
2224
listener: TcpListener,
2325
/// The address of the target service we are proxying to
2426
target: SocketAddr,
27+
attestation_platform: MockAttestation,
2528
}
2629

2730
impl ProxyServer {
2831
pub async fn new(
32+
cert_chain: Vec<CertificateDer<'static>>,
33+
key: PrivateKeyDer<'static>,
34+
local: impl ToSocketAddrs,
35+
target: SocketAddr,
36+
) -> Self {
37+
let server_config = ServerConfig::builder()
38+
.with_no_client_auth()
39+
.with_single_cert(cert_chain.clone(), key)
40+
.expect("Failed to create rustls server config");
41+
42+
let server =
43+
Self::new_with_tls_config(cert_chain, server_config.into(), local, target).await;
44+
45+
server
46+
}
47+
48+
pub async fn new_with_tls_config(
2949
cert_chain: Vec<CertificateDer<'static>>,
3050
server_config: Arc<ServerConfig>,
3151
local: impl ToSocketAddrs,
@@ -39,6 +59,7 @@ impl ProxyServer {
3959
acceptor,
4060
listener,
4161
target,
62+
attestation_platform: MockAttestation,
4263
}
4364
}
4465

@@ -49,6 +70,7 @@ impl ProxyServer {
4970
let acceptor = self.acceptor.clone();
5071
let target = self.target;
5172
let cert_chain = self.cert_chain.clone();
73+
let attestation_platform = self.attestation_platform.clone();
5274
tokio::spawn(async move {
5375
let mut tls_stream = acceptor.accept(inbound).await.unwrap();
5476
let (_io, server_connection) = tls_stream.get_ref();
@@ -63,7 +85,7 @@ impl ProxyServer {
6385
.unwrap();
6486

6587
tls_stream
66-
.write_all(&create_attestation(&cert_chain, exporter))
88+
.write_all(&attestation_platform.create_attestation(&cert_chain, exporter))
6789
.await
6890
.unwrap();
6991

@@ -88,10 +110,27 @@ pub struct ProxyClient {
88110
target: SocketAddr,
89111
/// The subject name of the proxy server
90112
target_name: ServerName<'static>,
113+
attestation_platform: MockAttestation,
91114
}
92115

93116
impl ProxyClient {
94117
pub async fn new(
118+
address: impl ToSocketAddrs,
119+
server_address: SocketAddr,
120+
server_name: ServerName<'static>,
121+
) -> Self {
122+
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
123+
let client_config = ClientConfig::builder()
124+
.with_root_certificates(root_store)
125+
.with_no_client_auth();
126+
127+
let client =
128+
Self::new_with_tls_config(client_config.into(), address, server_address, server_name)
129+
.await;
130+
client
131+
}
132+
133+
pub async fn new_with_tls_config(
95134
client_config: Arc<ClientConfig>,
96135
local: impl ToSocketAddrs,
97136
target: SocketAddr,
@@ -103,8 +142,9 @@ impl ProxyClient {
103142
Self {
104143
connector,
105144
listener,
106-
target,
145+
target: target.into(),
107146
target_name,
147+
attestation_platform: MockAttestation,
108148
}
109149
}
110150

@@ -114,6 +154,7 @@ impl ProxyClient {
114154
let connector = self.connector.clone();
115155
let target_name = self.target_name.clone();
116156
let target = self.target;
157+
let attestation_platform = self.attestation_platform.clone();
117158

118159
tokio::spawn(async move {
119160
let out = TcpStream::connect(target).await.unwrap();
@@ -134,7 +175,7 @@ impl ProxyClient {
134175
let mut buf = [0; 64];
135176
tls_stream.read_exact(&mut buf).await.unwrap();
136177

137-
if !verify_attestation(buf, &cert_chain, exporter) {
178+
if !attestation_platform.verify_attestation(buf, &cert_chain, exporter) {
138179
panic!("Cannot verify attestation");
139180
};
140181

@@ -150,27 +191,44 @@ impl ProxyClient {
150191
}
151192
}
152193

153-
/// Mocks creating an attestation
154-
fn create_attestation(cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec<u8> {
155-
let mut quote_input = [0u8; 64];
156-
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap();
157-
quote_input[..32].copy_from_slice(&pki_hash);
158-
quote_input[32..].copy_from_slice(&exporter);
159-
quote_input.to_vec()
194+
pub trait AttestationPlatform {
195+
fn create_attestation(&self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec<u8>;
196+
197+
fn verify_attestation(
198+
&self,
199+
input: [u8; 64],
200+
cert_chain: &[CertificateDer<'_>],
201+
exporter: [u8; 32],
202+
) -> bool;
160203
}
161204

162-
/// Mocks verifying an attestation
163-
fn verify_attestation(
164-
input: [u8; 64],
165-
cert_chain: &[CertificateDer<'_>],
166-
exporter: [u8; 32],
167-
) -> bool {
168-
let mut quote_input = [0u8; 64];
169-
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap();
170-
quote_input[..32].copy_from_slice(&pki_hash);
171-
quote_input[32..].copy_from_slice(&exporter);
172-
173-
input == quote_input
205+
#[derive(Clone)]
206+
struct MockAttestation;
207+
208+
impl AttestationPlatform for MockAttestation {
209+
/// Mocks creating an attestation
210+
fn create_attestation(&self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec<u8> {
211+
let mut quote_input = [0u8; 64];
212+
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap();
213+
quote_input[..32].copy_from_slice(&pki_hash);
214+
quote_input[32..].copy_from_slice(&exporter);
215+
quote_input.to_vec()
216+
}
217+
218+
/// Mocks verifying an attestation
219+
fn verify_attestation(
220+
&self,
221+
input: [u8; 64],
222+
cert_chain: &[CertificateDer<'_>],
223+
exporter: [u8; 32],
224+
) -> bool {
225+
let mut quote_input = [0u8; 64];
226+
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap();
227+
quote_input[..32].copy_from_slice(&pki_hash);
228+
quote_input[32..].copy_from_slice(&exporter);
229+
230+
input == quote_input
231+
}
174232
}
175233

176234
/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate
@@ -284,14 +342,15 @@ mod tests {
284342
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
285343

286344
let proxy_server =
287-
ProxyServer::new(cert_chain, server_config, "127.0.0.1:0", target_addr).await;
345+
ProxyServer::new_with_tls_config(cert_chain, server_config, "127.0.0.1:0", target_addr)
346+
.await;
288347
let proxy_addr = proxy_server.listener.local_addr().unwrap();
289348

290349
tokio::spawn(async move {
291350
proxy_server.accept().await.unwrap();
292351
});
293352

294-
let proxy_client = ProxyClient::new(
353+
let proxy_client = ProxyClient::new_with_tls_config(
295354
client_config,
296355
"127.0.0.1:0",
297356
proxy_addr,
@@ -323,14 +382,15 @@ mod tests {
323382
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
324383

325384
let proxy_server =
326-
ProxyServer::new(cert_chain, server_config, "127.0.0.1:0", target_addr).await;
385+
ProxyServer::new_with_tls_config(cert_chain, server_config, "127.0.0.1:0", target_addr)
386+
.await;
327387
let proxy_server_addr = proxy_server.listener.local_addr().unwrap();
328388

329389
tokio::spawn(async move {
330390
proxy_server.accept().await.unwrap();
331391
});
332392

333-
let proxy_client = ProxyClient::new(
393+
let proxy_client = ProxyClient::new_with_tls_config(
334394
client_config,
335395
"127.0.0.1:0",
336396
proxy_server_addr,

src/main.rs

Lines changed: 5 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
use clap::{Parser, Subcommand};
22
use std::{fs::File, net::SocketAddr};
3-
use tokio_rustls::rustls::{
4-
pki_types::{CertificateDer, PrivateKeyDer},
5-
ClientConfig, RootCertStore, ServerConfig,
6-
};
3+
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
74

85
use attested_tls_proxy::{ProxyClient, ProxyServer};
96

107
#[derive(Parser, Debug, Clone)]
118
#[clap(version, about, long_about = None)]
12-
#[clap(about = "Peer to peer filesharing")]
139
struct Cli {
1410
#[clap(subcommand)]
1511
command: CliCommand,
@@ -43,19 +39,9 @@ async fn main() {
4339
server_name,
4440
server_address,
4541
} => {
46-
let root_store =
47-
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
48-
let client_config = ClientConfig::builder()
49-
.with_root_certificates(root_store)
50-
.with_no_client_auth();
51-
52-
let client = ProxyClient::new(
53-
client_config.into(),
54-
cli.address,
55-
server_address,
56-
server_name.try_into().unwrap(),
57-
)
58-
.await;
42+
let client =
43+
ProxyClient::new(cli.address, server_address, server_name.try_into().unwrap())
44+
.await;
5945

6046
loop {
6147
client.accept().await.unwrap();
@@ -64,18 +50,7 @@ async fn main() {
6450
CliCommand::Server { client_address } => {
6551
let cert_chain = load_certs_pem("certs.pem").unwrap();
6652
let key = load_private_key_pem("key.pem");
67-
let server_config = ServerConfig::builder()
68-
.with_no_client_auth()
69-
.with_single_cert(cert_chain.clone(), key)
70-
.expect("Failed to create rustls server config");
71-
72-
let server = ProxyServer::new(
73-
cert_chain,
74-
server_config.into(),
75-
cli.address,
76-
client_address,
77-
)
78-
.await;
53+
let server = ProxyServer::new(cert_chain, key, cli.address, client_address).await;
7954

8055
loop {
8156
server.accept().await.unwrap();

0 commit comments

Comments
 (0)