@@ -12,9 +12,8 @@ use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
1212use x509_parser:: parse_x509_certificate;
1313
1414use std:: num:: TryFromIntError ;
15- use std:: { net:: SocketAddr , sync:: Arc } ;
16- use tokio:: io:: { AsyncReadExt , AsyncWriteExt } ;
17- use tokio:: net:: { TcpListener , TcpStream , ToSocketAddrs } ;
15+ use std:: sync:: Arc ;
16+ use tokio:: io:: { AsyncRead , AsyncReadExt , AsyncWrite , AsyncWriteExt } ;
1817use tokio_rustls:: rustls:: pki_types:: { CertificateDer , PrivateKeyDer , ServerName } ;
1918use tokio_rustls:: rustls:: RootCertStore ;
2019use tokio_rustls:: {
@@ -38,16 +37,14 @@ pub struct TlsCertAndKey {
3837 pub key : PrivateKeyDer < ' static > ,
3938}
4039
41- /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
40+ /// A TLS server which makes an attestation exchange following the TLS handshake
4241#[ derive( Clone ) ]
4342pub struct AttestedTlsServer {
44- /// The underlying TCP listener
45- pub listener : Arc < TcpListener > ,
4643 /// Quote generation type to use (including none)
4744 attestation_generator : AttestationGenerator ,
4845 /// Verifier for remote attestation (including none)
4946 attestation_verifier : AttestationVerifier ,
50- /// The certificate chain
47+ /// The TLS certificate chain
5148 cert_chain : Vec < CertificateDer < ' static > > ,
5249 /// For accepting TLS connections
5350 acceptor : TlsAcceptor ,
@@ -56,7 +53,6 @@ pub struct AttestedTlsServer {
5653impl AttestedTlsServer {
5754 pub async fn new (
5855 cert_and_key : TlsCertAndKey ,
59- local : impl ToSocketAddrs ,
6056 attestation_generator : AttestationGenerator ,
6157 attestation_verifier : AttestationVerifier ,
6258 client_auth : bool ,
@@ -83,7 +79,6 @@ impl AttestedTlsServer {
8379 Self :: new_with_tls_config (
8480 cert_and_key. cert_chain ,
8581 server_config. into ( ) ,
86- local,
8782 attestation_generator,
8883 attestation_verifier,
8984 )
@@ -96,55 +91,36 @@ impl AttestedTlsServer {
9691 pub ( crate ) async fn new_with_tls_config (
9792 cert_chain : Vec < CertificateDer < ' static > > ,
9893 server_config : Arc < ServerConfig > ,
99- local : impl ToSocketAddrs ,
10094 attestation_generator : AttestationGenerator ,
10195 attestation_verifier : AttestationVerifier ,
10296 ) -> Result < Self , AttestedTlsError > {
10397 let acceptor = tokio_rustls:: TlsAcceptor :: from ( server_config) ;
104- let listener = TcpListener :: bind ( local) . await ?;
10598
10699 Ok ( Self {
107- listener : listener. into ( ) ,
108100 attestation_generator,
109101 attestation_verifier,
110102 acceptor,
111103 cert_chain,
112104 } )
113105 }
114106
115- /// Accept an incoming connection and do an attestation exchange
116- pub async fn accept (
117- & self ,
118- ) -> Result <
119- (
120- tokio_rustls:: server:: TlsStream < tokio:: net:: TcpStream > ,
121- Option < MultiMeasurements > ,
122- AttestationType ,
123- ) ,
124- AttestedTlsError ,
125- > {
126- let ( inbound, _client_addr) = self . listener . accept ( ) . await ?;
127-
128- self . handle_connection ( inbound) . await
129- }
130-
131- /// Helper to get the socket address of the underlying TCP listener
132- pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
133- self . listener . local_addr ( )
134- }
135-
136107 /// Handle an incoming connection from a proxy-client
137- pub async fn handle_connection (
108+ ///
109+ /// This is transport agnostic and will work with any asynchronous stream
110+ pub async fn handle_connection < IO > (
138111 & self ,
139- inbound : TcpStream ,
112+ inbound : IO ,
140113 ) -> Result <
141114 (
142- tokio_rustls:: server:: TlsStream < tokio :: net :: TcpStream > ,
115+ tokio_rustls:: server:: TlsStream < IO > ,
143116 Option < MultiMeasurements > ,
144117 AttestationType ,
145118 ) ,
146119 AttestedTlsError ,
147- > {
120+ >
121+ where
122+ IO : AsyncRead + AsyncWrite + Unpin ,
123+ {
148124 tracing:: debug!( "attested-tls-server accepted connection" ) ;
149125
150126 // Do TLS handshake
@@ -296,23 +272,29 @@ impl AttestedTlsClient {
296272 } )
297273 }
298274
299- /// Connect to an attested-tls-server, do TLS handshake and attestation exchange
300- pub async fn connect (
275+ /// Given a connection to an attested TLS server, do a TLS handshake and attestation exchange, and return the TLS
276+ /// stream together with measurement details
277+ ///
278+ /// This is transport agnostic and will work with any asynchronous stream
279+ pub async fn connect < IO > (
301280 & self ,
302281 target : & str ,
282+ outbound : IO ,
303283 ) -> Result <
304284 (
305- tokio_rustls:: client:: TlsStream < tokio :: net :: TcpStream > ,
285+ tokio_rustls:: client:: TlsStream < IO > ,
306286 Option < MultiMeasurements > ,
307287 AttestationType ,
308288 ) ,
309289 AttestedTlsError ,
310- > {
311- // Make a TCP client connection and TLS handshake
312- let out = TcpStream :: connect ( & target) . await ?;
290+ >
291+ where
292+ IO : AsyncRead + AsyncWrite + Unpin ,
293+ {
294+ // Make a TLS handshake with the given connection
313295 let mut tls_stream = self
314296 . connector
315- . connect ( server_name_from_host ( target) ?, out )
297+ . connect ( server_name_from_host ( target) ?, outbound )
316298 . await ?;
317299
318300 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
@@ -374,12 +356,29 @@ impl AttestedTlsClient {
374356 Ok ( ( tls_stream, measurements, remote_attestation_type) )
375357 }
376358
377- /// Connect to an attested TLS server, retrieve the remote TLS certificate and return it
359+ /// Make a TCP connection, do a TLS handshake and attestation exchange, and return the TLS
360+ /// stream together with measurement details
361+ pub async fn connect_tcp (
362+ & self ,
363+ target : & str ,
364+ ) -> Result <
365+ (
366+ tokio_rustls:: client:: TlsStream < tokio:: net:: TcpStream > ,
367+ Option < MultiMeasurements > ,
368+ AttestationType ,
369+ ) ,
370+ AttestedTlsError ,
371+ > {
372+ let out = tokio:: net:: TcpStream :: connect ( & target) . await ?;
373+ self . connect ( target, out) . await
374+ }
375+
376+ /// Connect to an attested TLS server using TCP, retrieve the remote TLS certificate and return it
378377 pub async fn get_tls_cert (
379378 & self ,
380379 server_name : & str ,
381380 ) -> Result < Vec < CertificateDer < ' static > > , AttestedTlsError > {
382- let ( mut tls_stream, _, _) = self . connect ( server_name) . await ?;
381+ let ( mut tls_stream, _, _) = self . connect_tcp ( server_name) . await ?;
383382
384383 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
385384
0 commit comments