@@ -3,7 +3,8 @@ use std::{net::SocketAddr, sync::Arc};
33use thiserror:: Error ;
44use tokio:: io:: { self , AsyncReadExt , AsyncWriteExt } ;
55use tokio:: net:: { TcpListener , TcpStream , ToSocketAddrs } ;
6- use tokio_rustls:: rustls:: pki_types:: { CertificateDer , ServerName } ;
6+ use tokio_rustls:: rustls:: pki_types:: { CertificateDer , PrivateKeyDer , ServerName } ;
7+ use tokio_rustls:: rustls:: RootCertStore ;
78use tokio_rustls:: {
89 rustls:: { ClientConfig , ServerConfig } ,
910 TlsAcceptor , TlsConnector ,
@@ -13,6 +14,7 @@ use x509_parser::prelude::*;
1314/// The label used when exporting key material from a TLS session
1415const EXPORTER_LABEL : & [ u8 ; 24 ] = b"EXPORTER-Channel-Binding" ;
1516
17+ /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
1618pub struct ProxyServer {
1719 /// The certificate chain
1820 cert_chain : Vec < CertificateDer < ' static > > ,
@@ -22,10 +24,28 @@ pub struct ProxyServer {
2224 listener : TcpListener ,
2325 /// The address of the target service we are proxying to
2426 target : SocketAddr ,
27+ attestation_platform : MockAttestation ,
2528}
2629
2730impl ProxyServer {
2831 pub async fn new (
32+ cert_chain : Vec < CertificateDer < ' static > > ,
33+ key : PrivateKeyDer < ' static > ,
34+ local : impl ToSocketAddrs ,
35+ target : SocketAddr ,
36+ ) -> Self {
37+ let server_config = ServerConfig :: builder ( )
38+ . with_no_client_auth ( )
39+ . with_single_cert ( cert_chain. clone ( ) , key)
40+ . expect ( "Failed to create rustls server config" ) ;
41+
42+ let server =
43+ Self :: new_with_tls_config ( cert_chain, server_config. into ( ) , local, target) . await ;
44+
45+ server
46+ }
47+
48+ pub async fn new_with_tls_config (
2949 cert_chain : Vec < CertificateDer < ' static > > ,
3050 server_config : Arc < ServerConfig > ,
3151 local : impl ToSocketAddrs ,
@@ -39,6 +59,7 @@ impl ProxyServer {
3959 acceptor,
4060 listener,
4161 target,
62+ attestation_platform : MockAttestation ,
4263 }
4364 }
4465
@@ -49,6 +70,7 @@ impl ProxyServer {
4970 let acceptor = self . acceptor . clone ( ) ;
5071 let target = self . target ;
5172 let cert_chain = self . cert_chain . clone ( ) ;
73+ let attestation_platform = self . attestation_platform . clone ( ) ;
5274 tokio:: spawn ( async move {
5375 let mut tls_stream = acceptor. accept ( inbound) . await . unwrap ( ) ;
5476 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
@@ -63,7 +85,7 @@ impl ProxyServer {
6385 . unwrap ( ) ;
6486
6587 tls_stream
66- . write_all ( & create_attestation ( & cert_chain, exporter) )
88+ . write_all ( & attestation_platform . create_attestation ( & cert_chain, exporter) )
6789 . await
6890 . unwrap ( ) ;
6991
@@ -88,10 +110,27 @@ pub struct ProxyClient {
88110 target : SocketAddr ,
89111 /// The subject name of the proxy server
90112 target_name : ServerName < ' static > ,
113+ attestation_platform : MockAttestation ,
91114}
92115
93116impl ProxyClient {
94117 pub async fn new (
118+ address : impl ToSocketAddrs ,
119+ server_address : SocketAddr ,
120+ server_name : ServerName < ' static > ,
121+ ) -> Self {
122+ let root_store = RootCertStore :: from_iter ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
123+ let client_config = ClientConfig :: builder ( )
124+ . with_root_certificates ( root_store)
125+ . with_no_client_auth ( ) ;
126+
127+ let client =
128+ Self :: new_with_tls_config ( client_config. into ( ) , address, server_address, server_name)
129+ . await ;
130+ client
131+ }
132+
133+ pub async fn new_with_tls_config (
95134 client_config : Arc < ClientConfig > ,
96135 local : impl ToSocketAddrs ,
97136 target : SocketAddr ,
@@ -103,8 +142,9 @@ impl ProxyClient {
103142 Self {
104143 connector,
105144 listener,
106- target,
145+ target : target . into ( ) ,
107146 target_name,
147+ attestation_platform : MockAttestation ,
108148 }
109149 }
110150
@@ -114,6 +154,7 @@ impl ProxyClient {
114154 let connector = self . connector . clone ( ) ;
115155 let target_name = self . target_name . clone ( ) ;
116156 let target = self . target ;
157+ let attestation_platform = self . attestation_platform . clone ( ) ;
117158
118159 tokio:: spawn ( async move {
119160 let out = TcpStream :: connect ( target) . await . unwrap ( ) ;
@@ -134,7 +175,7 @@ impl ProxyClient {
134175 let mut buf = [ 0 ; 64 ] ;
135176 tls_stream. read_exact ( & mut buf) . await . unwrap ( ) ;
136177
137- if !verify_attestation ( buf, & cert_chain, exporter) {
178+ if !attestation_platform . verify_attestation ( buf, & cert_chain, exporter) {
138179 panic ! ( "Cannot verify attestation" ) ;
139180 } ;
140181
@@ -150,27 +191,44 @@ impl ProxyClient {
150191 }
151192}
152193
153- /// Mocks creating an attestation
154- fn create_attestation ( cert_chain : & [ CertificateDer < ' _ > ] , exporter : [ u8 ; 32 ] ) -> Vec < u8 > {
155- let mut quote_input = [ 0u8 ; 64 ] ;
156- let pki_hash = get_pki_hash_from_certificate_chain ( cert_chain) . unwrap ( ) ;
157- quote_input[ ..32 ] . copy_from_slice ( & pki_hash) ;
158- quote_input[ 32 ..] . copy_from_slice ( & exporter) ;
159- quote_input. to_vec ( )
194+ pub trait AttestationPlatform {
195+ fn create_attestation ( & self , cert_chain : & [ CertificateDer < ' _ > ] , exporter : [ u8 ; 32 ] ) -> Vec < u8 > ;
196+
197+ fn verify_attestation (
198+ & self ,
199+ input : [ u8 ; 64 ] ,
200+ cert_chain : & [ CertificateDer < ' _ > ] ,
201+ exporter : [ u8 ; 32 ] ,
202+ ) -> bool ;
160203}
161204
162- /// Mocks verifying an attestation
163- fn verify_attestation (
164- input : [ u8 ; 64 ] ,
165- cert_chain : & [ CertificateDer < ' _ > ] ,
166- exporter : [ u8 ; 32 ] ,
167- ) -> bool {
168- let mut quote_input = [ 0u8 ; 64 ] ;
169- let pki_hash = get_pki_hash_from_certificate_chain ( cert_chain) . unwrap ( ) ;
170- quote_input[ ..32 ] . copy_from_slice ( & pki_hash) ;
171- quote_input[ 32 ..] . copy_from_slice ( & exporter) ;
172-
173- input == quote_input
205+ #[ derive( Clone ) ]
206+ struct MockAttestation ;
207+
208+ impl AttestationPlatform for MockAttestation {
209+ /// Mocks creating an attestation
210+ fn create_attestation ( & self , cert_chain : & [ CertificateDer < ' _ > ] , exporter : [ u8 ; 32 ] ) -> Vec < u8 > {
211+ let mut quote_input = [ 0u8 ; 64 ] ;
212+ let pki_hash = get_pki_hash_from_certificate_chain ( cert_chain) . unwrap ( ) ;
213+ quote_input[ ..32 ] . copy_from_slice ( & pki_hash) ;
214+ quote_input[ 32 ..] . copy_from_slice ( & exporter) ;
215+ quote_input. to_vec ( )
216+ }
217+
218+ /// Mocks verifying an attestation
219+ fn verify_attestation (
220+ & self ,
221+ input : [ u8 ; 64 ] ,
222+ cert_chain : & [ CertificateDer < ' _ > ] ,
223+ exporter : [ u8 ; 32 ] ,
224+ ) -> bool {
225+ let mut quote_input = [ 0u8 ; 64 ] ;
226+ let pki_hash = get_pki_hash_from_certificate_chain ( cert_chain) . unwrap ( ) ;
227+ quote_input[ ..32 ] . copy_from_slice ( & pki_hash) ;
228+ quote_input[ 32 ..] . copy_from_slice ( & exporter) ;
229+
230+ input == quote_input
231+ }
174232}
175233
176234/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate
@@ -284,14 +342,15 @@ mod tests {
284342 let ( server_config, client_config) = generate_tls_config ( cert_chain. clone ( ) , private_key) ;
285343
286344 let proxy_server =
287- ProxyServer :: new ( cert_chain, server_config, "127.0.0.1:0" , target_addr) . await ;
345+ ProxyServer :: new_with_tls_config ( cert_chain, server_config, "127.0.0.1:0" , target_addr)
346+ . await ;
288347 let proxy_addr = proxy_server. listener . local_addr ( ) . unwrap ( ) ;
289348
290349 tokio:: spawn ( async move {
291350 proxy_server. accept ( ) . await . unwrap ( ) ;
292351 } ) ;
293352
294- let proxy_client = ProxyClient :: new (
353+ let proxy_client = ProxyClient :: new_with_tls_config (
295354 client_config,
296355 "127.0.0.1:0" ,
297356 proxy_addr,
@@ -323,14 +382,15 @@ mod tests {
323382 let ( server_config, client_config) = generate_tls_config ( cert_chain. clone ( ) , private_key) ;
324383
325384 let proxy_server =
326- ProxyServer :: new ( cert_chain, server_config, "127.0.0.1:0" , target_addr) . await ;
385+ ProxyServer :: new_with_tls_config ( cert_chain, server_config, "127.0.0.1:0" , target_addr)
386+ . await ;
327387 let proxy_server_addr = proxy_server. listener . local_addr ( ) . unwrap ( ) ;
328388
329389 tokio:: spawn ( async move {
330390 proxy_server. accept ( ) . await . unwrap ( ) ;
331391 } ) ;
332392
333- let proxy_client = ProxyClient :: new (
393+ let proxy_client = ProxyClient :: new_with_tls_config (
334394 client_config,
335395 "127.0.0.1:0" ,
336396 proxy_server_addr,
0 commit comments