1- use crate :: attestation:: {
2- measurements:: MultiMeasurements , AttestationError , AttestationGenerator , AttestationType ,
1+ use crate :: {
2+ attestation:: {
3+ measurements:: MultiMeasurements , AttestationError , AttestationExchangeMessage ,
4+ AttestationGenerator , AttestationType , AttestationVerifier ,
5+ } ,
6+ host_to_host_with_port,
37} ;
48use parity_scale_codec:: { Decode , Encode } ;
59use sha2:: { Digest , Sha256 } ;
@@ -18,8 +22,6 @@ use tokio_rustls::{
1822 TlsAcceptor , TlsConnector ,
1923} ;
2024
21- use crate :: attestation:: { AttestationExchangeMessage , AttestationVerifier } ;
22-
2325/// This makes it possible to add breaking protocol changes and provide backwards compatibility.
2426/// When adding more supported versions, note that ordering is important. ALPN will pick the first
2527/// protocol which both parties support - so newer supported versions should come first.
@@ -297,7 +299,7 @@ impl AttestedTlsClient {
297299 /// Connect to an attested-tls-server, do TLS handshake and attestation exchange
298300 pub async fn connect (
299301 & self ,
300- target : String ,
302+ target : & str ,
301303 ) -> Result <
302304 (
303305 tokio_rustls:: client:: TlsStream < tokio:: net:: TcpStream > ,
@@ -310,7 +312,7 @@ impl AttestedTlsClient {
310312 let out = TcpStream :: connect ( & target) . await ?;
311313 let mut tls_stream = self
312314 . connector
313- . connect ( server_name_from_host ( & target) ?, out)
315+ . connect ( server_name_from_host ( target) ?, out)
314316 . await ?;
315317
316318 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
@@ -371,82 +373,61 @@ impl AttestedTlsClient {
371373
372374 Ok ( ( tls_stream, measurements, remote_attestation_type) )
373375 }
376+
377+ /// Connect to an attested TLS server, retrieve the remote TLS certificate and return it
378+ pub async fn get_tls_cert (
379+ & self ,
380+ server_name : & str ,
381+ ) -> Result < Vec < CertificateDer < ' static > > , AttestedTlsError > {
382+ let ( mut tls_stream, _, _) = self . connect ( server_name) . await ?;
383+
384+ let ( _io, server_connection) = tls_stream. get_ref ( ) ;
385+
386+ let remote_cert_chain = server_connection
387+ . peer_certificates ( )
388+ . ok_or ( AttestedTlsError :: NoCertificate ) ?
389+ . to_owned ( ) ;
390+
391+ tls_stream. shutdown ( ) . await ?;
392+
393+ Ok ( remote_cert_chain)
394+ }
374395}
375396
376397/// A client which just gets the attested remote certificate, with no client authentication
377398pub async fn get_tls_cert (
378399 server_name : String ,
379400 attestation_verifier : AttestationVerifier ,
380- remote_certificate : Option < CertificateDer < ' _ > > ,
401+ remote_certificate : Option < CertificateDer < ' static > > ,
381402) -> Result < Vec < CertificateDer < ' static > > , AttestedTlsError > {
382403 tracing:: debug!( "Getting remote TLS cert" ) ;
383- // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots
384- let root_store = match remote_certificate {
385- Some ( remote_certificate) => {
386- let mut root_store = RootCertStore :: empty ( ) ;
387- root_store. add ( remote_certificate) ?;
388- root_store
389- }
390- None => RootCertStore :: from_iter ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ,
391- } ;
392-
393- let mut client_config = ClientConfig :: builder ( )
394- . with_root_certificates ( root_store)
395- . with_no_client_auth ( ) ;
396-
397- client_config. alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS
398- . into_iter ( )
399- . map ( |p| p. to_vec ( ) )
400- . collect ( ) ;
401-
402- get_tls_cert_with_config ( server_name, attestation_verifier, client_config. into ( ) ) . await
404+ let attested_tls_client = AttestedTlsClient :: new (
405+ None ,
406+ AttestationGenerator :: with_no_attestation ( ) ,
407+ attestation_verifier,
408+ remote_certificate,
409+ )
410+ . await ?;
411+ attested_tls_client
412+ . get_tls_cert ( & host_to_host_with_port ( & server_name) )
413+ . await
403414}
404415
405- // TODO this could use AttestedTlsClient to avoid repeating code
416+ /// Helper for testing getting remote certificate
417+ #[ cfg( test) ]
406418pub ( crate ) async fn get_tls_cert_with_config (
407- server_name : String ,
419+ server_name : & str ,
408420 attestation_verifier : AttestationVerifier ,
409421 client_config : Arc < ClientConfig > ,
410422) -> Result < Vec < CertificateDer < ' static > > , AttestedTlsError > {
411- let connector = TlsConnector :: from ( client_config) ;
412-
413- let out = TcpStream :: connect ( host_to_host_with_port ( & server_name) ) . await ?;
414- let mut tls_stream = connector
415- . connect ( server_name_from_host ( & server_name) ?, out)
416- . await ?;
417-
418- let ( _io, server_connection) = tls_stream. get_ref ( ) ;
419-
420- let mut exporter = [ 0u8 ; 32 ] ;
421- server_connection. export_keying_material (
422- & mut exporter,
423- EXPORTER_LABEL ,
424- None , // context
425- ) ?;
426-
427- let remote_cert_chain = server_connection
428- . peer_certificates ( )
429- . ok_or ( AttestedTlsError :: NoCertificate ) ?
430- . to_owned ( ) ;
431-
432- let mut length_bytes = [ 0 ; 4 ] ;
433- tls_stream. read_exact ( & mut length_bytes) . await ?;
434- let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) ?;
435-
436- let mut buf = vec ! [ 0 ; length] ;
437- tls_stream. read_exact ( & mut buf) . await ?;
438-
439- let remote_attestation_message = AttestationExchangeMessage :: decode ( & mut & buf[ ..] ) ?;
440-
441- let remote_input_data = compute_report_input ( Some ( & remote_cert_chain) , exporter) ?;
442-
443- let _measurements = attestation_verifier
444- . verify_attestation ( remote_attestation_message, remote_input_data)
445- . await ?;
446-
447- tls_stream. shutdown ( ) . await ?;
448-
449- Ok ( remote_cert_chain)
423+ let attested_tls_client = AttestedTlsClient :: new_with_tls_config (
424+ client_config,
425+ AttestationGenerator :: with_no_attestation ( ) ,
426+ attestation_verifier,
427+ None ,
428+ )
429+ . await ?;
430+ attested_tls_client. get_tls_cert ( server_name) . await
450431}
451432
452433/// Given a certificate chain and an exporter (session key material), build the quote input value
@@ -507,15 +488,6 @@ fn length_prefix(input: &[u8]) -> [u8; 4] {
507488 len. to_be_bytes ( )
508489}
509490
510- /// If no port was provided, default to 443
511- fn host_to_host_with_port ( host : & str ) -> String {
512- if host. contains ( ':' ) {
513- host. to_string ( )
514- } else {
515- format ! ( "{host}:443" )
516- }
517- }
518-
519491/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part
520492fn server_name_from_host (
521493 host : & str ,
0 commit comments