11mod attestation;
22
3+ use attestation:: AttestationError ;
34pub use attestation:: { AttestationPlatform , MockAttestation , NoAttestation } ;
45use thiserror:: Error ;
56use tokio_rustls:: rustls:: server:: { VerifierBuilderError , WebPkiClientVerifier } ;
67
78#[ cfg( test) ]
89mod test_helpers;
910
11+ use std:: num:: TryFromIntError ;
1012use std:: { net:: SocketAddr , sync:: Arc } ;
1113use tokio:: io:: { self , AsyncReadExt , AsyncWriteExt } ;
1214use tokio:: net:: { TcpListener , TcpStream , ToSocketAddrs } ;
@@ -129,58 +131,18 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
129131 let local_attestation_platform = self . inner . local_attestation_platform . clone ( ) ;
130132 let remote_attestation_platform = self . inner . remote_attestation_platform . clone ( ) ;
131133 tokio:: spawn ( async move {
132- let mut tls_stream = acceptor. accept ( inbound) . await . unwrap ( ) ;
133- let ( _io, connection) = tls_stream. get_ref ( ) ;
134-
135- let mut exporter = [ 0u8 ; 32 ] ;
136- connection
137- . export_keying_material (
138- & mut exporter,
139- EXPORTER_LABEL ,
140- None , // context
141- )
142- . unwrap ( ) ;
143-
144- let remote_cert_chain = connection. peer_certificates ( ) . map ( |c| c. to_owned ( ) ) ;
145-
146- let attestation = if local_attestation_platform. is_cvm ( ) {
147- local_attestation_platform
148- . create_attestation ( & cert_chain, exporter)
149- . unwrap ( )
150- } else {
151- Vec :: new ( )
152- } ;
153-
154- let attestation_length_prefix = length_prefix ( & attestation) ;
155-
156- tls_stream
157- . write_all ( & attestation_length_prefix)
158- . await
159- . unwrap ( ) ;
160-
161- tls_stream. write_all ( & attestation) . await . unwrap ( ) ;
162-
163- let mut length_bytes = [ 0 ; 4 ] ;
164- tls_stream. read_exact ( & mut length_bytes) . await . unwrap ( ) ;
165- let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) . unwrap ( ) ;
166-
167- let mut buf = vec ! [ 0 ; length] ;
168- tls_stream. read_exact ( & mut buf) . await . unwrap ( ) ;
169-
170- if remote_attestation_platform. is_cvm ( ) {
171- remote_attestation_platform
172- . verify_attestation ( buf, & remote_cert_chain. unwrap ( ) , exporter)
173- . unwrap ( ) ;
134+ if let Err ( err) = Self :: handle_connection (
135+ inbound,
136+ acceptor,
137+ target,
138+ cert_chain,
139+ local_attestation_platform,
140+ remote_attestation_platform,
141+ )
142+ . await
143+ {
144+ eprintln ! ( "Failed to handle connection: {err}" ) ;
174145 }
175-
176- let outbound = TcpStream :: connect ( target) . await . unwrap ( ) ;
177-
178- let ( mut inbound_reader, mut inbound_writer) = tokio:: io:: split ( tls_stream) ;
179- let ( mut outbound_reader, mut outbound_writer) = outbound. into_split ( ) ;
180-
181- let client_to_server = tokio:: io:: copy ( & mut inbound_reader, & mut outbound_writer) ;
182- let server_to_client = tokio:: io:: copy ( & mut outbound_reader, & mut inbound_writer) ;
183- tokio:: try_join!( client_to_server, server_to_client) . unwrap ( ) ;
184146 } ) ;
185147
186148 Ok ( ( ) )
@@ -189,6 +151,64 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
189151 pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
190152 self . inner . listener . local_addr ( )
191153 }
154+
155+ async fn handle_connection (
156+ inbound : TcpStream ,
157+ acceptor : TlsAcceptor ,
158+ target : SocketAddr ,
159+ cert_chain : Vec < CertificateDer < ' static > > ,
160+ local_attestation_platform : L ,
161+ remote_attestation_platform : R ,
162+ ) -> Result < ( ) , ProxyError > {
163+ let mut tls_stream = acceptor. accept ( inbound) . await ?;
164+ let ( _io, connection) = tls_stream. get_ref ( ) ;
165+
166+ let mut exporter = [ 0u8 ; 32 ] ;
167+ connection. export_keying_material (
168+ & mut exporter,
169+ EXPORTER_LABEL ,
170+ None , // context
171+ ) ?;
172+
173+ let remote_cert_chain = connection. peer_certificates ( ) . map ( |c| c. to_owned ( ) ) ;
174+
175+ let attestation = if local_attestation_platform. is_cvm ( ) {
176+ local_attestation_platform. create_attestation ( & cert_chain, exporter) ?
177+ } else {
178+ Vec :: new ( )
179+ } ;
180+
181+ let attestation_length_prefix = length_prefix ( & attestation) ;
182+
183+ tls_stream. write_all ( & attestation_length_prefix) . await ?;
184+
185+ tls_stream. write_all ( & attestation) . await ?;
186+
187+ let mut length_bytes = [ 0 ; 4 ] ;
188+ tls_stream. read_exact ( & mut length_bytes) . await ?;
189+ let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) ?;
190+
191+ let mut buf = vec ! [ 0 ; length] ;
192+ tls_stream. read_exact ( & mut buf) . await ?;
193+
194+ if remote_attestation_platform. is_cvm ( ) {
195+ remote_attestation_platform. verify_attestation (
196+ buf,
197+ & remote_cert_chain. ok_or ( ProxyError :: NoClientAuth ) ?,
198+ exporter,
199+ ) ?;
200+ }
201+
202+ let outbound = TcpStream :: connect ( target) . await ?;
203+
204+ let ( mut inbound_reader, mut inbound_writer) = tokio:: io:: split ( tls_stream) ;
205+ let ( mut outbound_reader, mut outbound_writer) = outbound. into_split ( ) ;
206+
207+ let client_to_server = tokio:: io:: copy ( & mut inbound_reader, & mut outbound_writer) ;
208+ let server_to_client = tokio:: io:: copy ( & mut outbound_reader, & mut inbound_writer) ;
209+ tokio:: try_join!( client_to_server, server_to_client) ?;
210+ Ok ( ( ) )
211+ }
192212}
193213
194214pub struct ProxyClient < L , R >
@@ -284,78 +304,104 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
284304 let cert_chain = self . cert_chain . clone ( ) ;
285305
286306 tokio:: spawn ( async move {
287- let out = TcpStream :: connect ( target) . await . unwrap ( ) ;
288- let mut tls_stream = connector. connect ( target_name, out) . await . unwrap ( ) ;
307+ if let Err ( err) = Self :: handle_connection (
308+ inbound,
309+ connector,
310+ target,
311+ target_name,
312+ cert_chain,
313+ local_attestation_platform,
314+ remote_attestation_platform,
315+ )
316+ . await
317+ {
318+ eprintln ! ( "Failed to handle connection: {err}" ) ;
319+ }
320+ } ) ;
289321
290- let ( _io, server_connection) = tls_stream. get_ref ( ) ;
322+ Ok ( ( ) )
323+ }
291324
292- let mut exporter = [ 0u8 ; 32 ] ;
293- server_connection
294- . export_keying_material (
295- & mut exporter,
296- EXPORTER_LABEL ,
297- None , // context
298- )
299- . unwrap ( ) ;
325+ pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
326+ self . inner . listener . local_addr ( )
327+ }
300328
301- let remote_cert_chain = server_connection. peer_certificates ( ) . unwrap ( ) . to_owned ( ) ;
329+ async fn handle_connection (
330+ inbound : TcpStream ,
331+ connector : TlsConnector ,
332+ target : SocketAddr ,
333+ target_name : ServerName < ' static > ,
334+ cert_chain : Option < Vec < CertificateDer < ' static > > > ,
335+ local_attestation_platform : L ,
336+ remote_attestation_platform : R ,
337+ ) -> Result < ( ) , ProxyError > {
338+ let out = TcpStream :: connect ( target) . await ?;
339+ let mut tls_stream = connector. connect ( target_name, out) . await ?;
302340
303- let mut length_bytes = [ 0 ; 4 ] ;
304- tls_stream. read_exact ( & mut length_bytes) . await . unwrap ( ) ;
305- let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) . unwrap ( ) ;
341+ let ( _io, server_connection) = tls_stream. get_ref ( ) ;
306342
307- let mut buf = vec ! [ 0 ; length] ;
308- tls_stream. read_exact ( & mut buf) . await . unwrap ( ) ;
343+ let mut exporter = [ 0u8 ; 32 ] ;
344+ server_connection. export_keying_material (
345+ & mut exporter,
346+ EXPORTER_LABEL ,
347+ None , // context
348+ ) ?;
309349
310- if remote_attestation_platform. is_cvm ( ) {
311- remote_attestation_platform
312- . verify_attestation ( buf, & remote_cert_chain, exporter)
313- . unwrap ( ) ;
314- }
350+ let remote_cert_chain = server_connection
351+ . peer_certificates ( )
352+ . ok_or ( ProxyError :: NoCertificate ) ?
353+ . to_owned ( ) ;
315354
316- let attestation = if local_attestation_platform. is_cvm ( ) {
317- local_attestation_platform
318- . create_attestation ( & cert_chain. unwrap ( ) , exporter)
319- . unwrap ( )
320- } else {
321- Vec :: new ( )
322- } ;
355+ let mut length_bytes = [ 0 ; 4 ] ;
356+ tls_stream. read_exact ( & mut length_bytes) . await ?;
357+ let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) ?;
323358
324- let attestation_length_prefix = length_prefix ( & attestation) ;
359+ let mut buf = vec ! [ 0 ; length] ;
360+ tls_stream. read_exact ( & mut buf) . await ?;
325361
326- tls_stream
327- . write_all ( & attestation_length_prefix)
328- . await
329- . unwrap ( ) ;
362+ if remote_attestation_platform. is_cvm ( ) {
363+ remote_attestation_platform. verify_attestation ( buf, & remote_cert_chain, exporter) ?;
364+ }
330365
331- tls_stream. write_all ( & attestation) . await . unwrap ( ) ;
366+ let attestation = if local_attestation_platform. is_cvm ( ) {
367+ local_attestation_platform
368+ . create_attestation ( & cert_chain. ok_or ( ProxyError :: NoClientAuth ) ?, exporter) ?
369+ } else {
370+ Vec :: new ( )
371+ } ;
332372
333- let ( mut inbound_reader, mut inbound_writer) = inbound. into_split ( ) ;
334- let ( mut outbound_reader, mut outbound_writer) = tokio:: io:: split ( tls_stream) ;
373+ let attestation_length_prefix = length_prefix ( & attestation) ;
335374
336- let client_to_server = tokio:: io:: copy ( & mut inbound_reader, & mut outbound_writer) ;
337- let server_to_client = tokio:: io:: copy ( & mut outbound_reader, & mut inbound_writer) ;
338- tokio:: try_join!( client_to_server, server_to_client) . unwrap ( ) ;
339- } ) ;
375+ tls_stream. write_all ( & attestation_length_prefix) . await ?;
340376
341- Ok ( ( ) )
342- }
377+ tls_stream. write_all ( & attestation) . await ?;
343378
344- pub fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > {
345- self . inner . listener . local_addr ( )
379+ let ( mut inbound_reader, mut inbound_writer) = inbound. into_split ( ) ;
380+ let ( mut outbound_reader, mut outbound_writer) = tokio:: io:: split ( tls_stream) ;
381+
382+ let client_to_server = tokio:: io:: copy ( & mut inbound_reader, & mut outbound_writer) ;
383+ let server_to_client = tokio:: io:: copy ( & mut outbound_reader, & mut inbound_writer) ;
384+ tokio:: try_join!( client_to_server, server_to_client) ?;
385+ Ok ( ( ) )
346386 }
347387}
348388
349389#[ derive( Error , Debug ) ]
350390pub enum ProxyError {
351391 #[ error( "Client auth is required when the client is running in a CVM" ) ]
352392 NoClientAuth ,
393+ #[ error( "Failed to get server ceritifcate" ) ]
394+ NoCertificate ,
353395 #[ error( "TLS: {0}" ) ]
354396 Rustls ( #[ from] tokio_rustls:: rustls:: Error ) ,
355397 #[ error( "Verifier builder: {0}" ) ]
356398 VerifierBuilder ( #[ from] VerifierBuilderError ) ,
357399 #[ error( "IO: {0}" ) ]
358400 Io ( #[ from] std:: io:: Error ) ,
401+ #[ error( "Attestation: {0}" ) ]
402+ Attestation ( #[ from] AttestationError ) ,
403+ #[ error( "Integer conversion: {0}" ) ]
404+ IntConversion ( #[ from] TryFromIntError ) ,
359405}
360406
361407fn length_prefix ( input : & [ u8 ] ) -> [ u8 ; 4 ] {
0 commit comments