Skip to content

Commit d872a00

Browse files
committed
Deduplicate get_tls_cert
1 parent 82b9bf5 commit d872a00

File tree

2 files changed

+60
-88
lines changed

2 files changed

+60
-88
lines changed

src/attested_tls.rs

Lines changed: 49 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1-
use crate::attestation::{
2-
measurements::MultiMeasurements, AttestationError, AttestationGenerator, AttestationType,
1+
use crate::{
2+
attestation::{
3+
measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage,
4+
AttestationGenerator, AttestationType, AttestationVerifier,
5+
},
6+
host_to_host_with_port,
37
};
48
use parity_scale_codec::{Decode, Encode};
59
use sha2::{Digest, Sha256};
@@ -18,8 +22,6 @@ use tokio_rustls::{
1822
TlsAcceptor, TlsConnector,
1923
};
2024

21-
use crate::attestation::{AttestationExchangeMessage, AttestationVerifier};
22-
2325
/// This makes it possible to add breaking protocol changes and provide backwards compatibility.
2426
/// When adding more supported versions, note that ordering is important. ALPN will pick the first
2527
/// protocol which both parties support - so newer supported versions should come first.
@@ -297,7 +299,7 @@ impl AttestedTlsClient {
297299
/// Connect to an attested-tls-server, do TLS handshake and attestation exchange
298300
pub async fn connect(
299301
&self,
300-
target: String,
302+
target: &str,
301303
) -> Result<
302304
(
303305
tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
@@ -310,7 +312,7 @@ impl AttestedTlsClient {
310312
let out = TcpStream::connect(&target).await?;
311313
let mut tls_stream = self
312314
.connector
313-
.connect(server_name_from_host(&target)?, out)
315+
.connect(server_name_from_host(target)?, out)
314316
.await?;
315317

316318
let (_io, server_connection) = tls_stream.get_ref();
@@ -371,82 +373,61 @@ impl AttestedTlsClient {
371373

372374
Ok((tls_stream, measurements, remote_attestation_type))
373375
}
376+
377+
/// Connect to an attested TLS server, retrieve the remote TLS certificate and return it
378+
pub async fn get_tls_cert(
379+
&self,
380+
server_name: &str,
381+
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
382+
let (mut tls_stream, _, _) = self.connect(server_name).await?;
383+
384+
let (_io, server_connection) = tls_stream.get_ref();
385+
386+
let remote_cert_chain = server_connection
387+
.peer_certificates()
388+
.ok_or(AttestedTlsError::NoCertificate)?
389+
.to_owned();
390+
391+
tls_stream.shutdown().await?;
392+
393+
Ok(remote_cert_chain)
394+
}
374395
}
375396

376397
/// A client which just gets the attested remote certificate, with no client authentication
377398
pub async fn get_tls_cert(
378399
server_name: String,
379400
attestation_verifier: AttestationVerifier,
380-
remote_certificate: Option<CertificateDer<'_>>,
401+
remote_certificate: Option<CertificateDer<'static>>,
381402
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
382403
tracing::debug!("Getting remote TLS cert");
383-
// If a remote CA cert was given, use it as the root store, otherwise use webpki_roots
384-
let root_store = match remote_certificate {
385-
Some(remote_certificate) => {
386-
let mut root_store = RootCertStore::empty();
387-
root_store.add(remote_certificate)?;
388-
root_store
389-
}
390-
None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()),
391-
};
392-
393-
let mut client_config = ClientConfig::builder()
394-
.with_root_certificates(root_store)
395-
.with_no_client_auth();
396-
397-
client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS
398-
.into_iter()
399-
.map(|p| p.to_vec())
400-
.collect();
401-
402-
get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await
404+
let attested_tls_client = AttestedTlsClient::new(
405+
None,
406+
AttestationGenerator::with_no_attestation(),
407+
attestation_verifier,
408+
remote_certificate,
409+
)
410+
.await?;
411+
attested_tls_client
412+
.get_tls_cert(&host_to_host_with_port(&server_name))
413+
.await
403414
}
404415

405-
// TODO this could use AttestedTlsClient to avoid repeating code
416+
/// Helper for testing getting remote certificate
417+
#[cfg(test)]
406418
pub(crate) async fn get_tls_cert_with_config(
407-
server_name: String,
419+
server_name: &str,
408420
attestation_verifier: AttestationVerifier,
409421
client_config: Arc<ClientConfig>,
410422
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
411-
let connector = TlsConnector::from(client_config);
412-
413-
let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?;
414-
let mut tls_stream = connector
415-
.connect(server_name_from_host(&server_name)?, out)
416-
.await?;
417-
418-
let (_io, server_connection) = tls_stream.get_ref();
419-
420-
let mut exporter = [0u8; 32];
421-
server_connection.export_keying_material(
422-
&mut exporter,
423-
EXPORTER_LABEL,
424-
None, // context
425-
)?;
426-
427-
let remote_cert_chain = server_connection
428-
.peer_certificates()
429-
.ok_or(AttestedTlsError::NoCertificate)?
430-
.to_owned();
431-
432-
let mut length_bytes = [0; 4];
433-
tls_stream.read_exact(&mut length_bytes).await?;
434-
let length: usize = u32::from_be_bytes(length_bytes).try_into()?;
435-
436-
let mut buf = vec![0; length];
437-
tls_stream.read_exact(&mut buf).await?;
438-
439-
let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?;
440-
441-
let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?;
442-
443-
let _measurements = attestation_verifier
444-
.verify_attestation(remote_attestation_message, remote_input_data)
445-
.await?;
446-
447-
tls_stream.shutdown().await?;
448-
449-
Ok(remote_cert_chain)
423+
let attested_tls_client = AttestedTlsClient::new_with_tls_config(
424+
client_config,
425+
AttestationGenerator::with_no_attestation(),
426+
attestation_verifier,
427+
None,
428+
)
429+
.await?;
430+
attested_tls_client.get_tls_cert(server_name).await
450431
}
451432

452433
/// Given a certificate chain and an exporter (session key material), build the quote input value
@@ -507,15 +488,6 @@ fn length_prefix(input: &[u8]) -> [u8; 4] {
507488
len.to_be_bytes()
508489
}
509490

510-
/// If no port was provided, default to 443
511-
fn host_to_host_with_port(host: &str) -> String {
512-
if host.contains(':') {
513-
host.to_string()
514-
} else {
515-
format!("{host}:443")
516-
}
517-
}
518-
519491
/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part
520492
fn server_name_from_host(
521493
host: &str,

src/lib.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ impl ProxyClient {
267267
)
268268
.await?;
269269

270-
Self::new_with_inner(address, attested_tls_client, server_name).await
270+
Self::new_with_inner(address, attested_tls_client, &server_name).await
271271
}
272272

273273
/// Create a new proxy client with given TLS configuration
@@ -290,7 +290,7 @@ impl ProxyClient {
290290
)
291291
.await?;
292292

293-
Self::new_with_inner(address, attested_tls_client, target_name).await
293+
Self::new_with_inner(address, attested_tls_client, &target_name).await
294294
}
295295

296296
/// Create a new proxy client with given TLS configuration
@@ -299,12 +299,12 @@ impl ProxyClient {
299299
async fn new_with_inner(
300300
address: impl ToSocketAddrs,
301301
attested_tls_client: AttestedTlsClient,
302-
target_name: String,
302+
target_name: &str,
303303
) -> Result<Self, ProxyError> {
304304
let listener = TcpListener::bind(address).await?;
305305

306306
// Process the hostname / port provided by the user
307-
let target = host_to_host_with_port(&target_name);
307+
let target = host_to_host_with_port(target_name);
308308

309309
// Channel for getting incoming requests from the source client
310310
let (requests_tx, mut requests_rx) = mpsc::channel::<(
@@ -316,7 +316,7 @@ impl ProxyClient {
316316

317317
// Connect to the proxy server and provide / verify attestation
318318
let (mut sender, mut measurements, mut remote_attestation_type) =
319-
Self::setup_connection_with_backoff(target.clone(), &attested_tls_client, true).await?;
319+
Self::setup_connection_with_backoff(&target, &attested_tls_client, true).await?;
320320

321321
let attested_tls_client_clone = attested_tls_client.clone();
322322
tokio::spawn(async move {
@@ -367,7 +367,7 @@ impl ProxyClient {
367367
// Reconnect to the server - retrying indefinately with a backoff
368368
(sender, measurements, remote_attestation_type) =
369369
Self::setup_connection_with_backoff(
370-
target.clone(),
370+
&target,
371371
&attested_tls_client_clone,
372372
false,
373373
)
@@ -438,15 +438,15 @@ impl ProxyClient {
438438
// Attempt connection and handshake with the proxy-server
439439
// If it fails retry with a backoff (indefinately)
440440
async fn setup_connection_with_backoff(
441-
target: String,
441+
target: &str,
442442
attested_tls_client: &AttestedTlsClient,
443443
should_bail: bool,
444444
) -> Result<(Http2Sender, Option<MultiMeasurements>, AttestationType), ProxyError> {
445445
let mut delay = Duration::from_secs(1);
446446
let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS);
447447

448448
loop {
449-
match Self::setup_connection(attested_tls_client, target.clone()).await {
449+
match Self::setup_connection(attested_tls_client, target).await {
450450
Ok(output) => {
451451
return Ok(output);
452452
}
@@ -469,7 +469,7 @@ impl ProxyClient {
469469
/// Connect to the proxy-server, do TLS handshake and remote attestation
470470
async fn setup_connection(
471471
inner: &AttestedTlsClient,
472-
target: String,
472+
target: &str,
473473
) -> Result<(Http2Sender, Option<MultiMeasurements>, AttestationType), ProxyError> {
474474
let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?;
475475

@@ -567,7 +567,7 @@ impl From<mpsc::error::SendError<RequestWithResponseSender>> for ProxyError {
567567
}
568568

569569
/// If no port was provided, default to 443
570-
fn host_to_host_with_port(host: &str) -> String {
570+
pub(crate) fn host_to_host_with_port(host: &str) -> String {
571571
if host.contains(':') {
572572
host.to_string()
573573
} else {
@@ -962,7 +962,7 @@ mod tests {
962962
});
963963

964964
let retrieved_chain = get_tls_cert_with_config(
965-
proxy_server_addr.to_string(),
965+
&proxy_server_addr.to_string(),
966966
AttestationVerifier::mock(),
967967
client_config,
968968
)

0 commit comments

Comments
 (0)