diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index 0cb9b0e..daf1791 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -145,9 +145,9 @@ pub struct AttestationVerifier { /// /// If this is empty, anything will be accepted - but measurements are always injected into HTTP /// headers, so that they can be verified upstream - accepted_measurements: Vec, + pub accepted_measurements: Vec, /// A PCCS service to use - defaults to Intel PCS - pccs_url: Option, + pub pccs_url: Option, } impl AttestationVerifier { @@ -202,6 +202,9 @@ impl AttestationVerifier { .await? } AttestationType::None => { + if self.has_remote_attestion() { + return Err(AttestationError::AttestationTypeNotAccepted); + } if attestation_payload.attestation.is_empty() { return Ok(None); } else { @@ -216,7 +219,8 @@ impl AttestationVerifier { // look through all our accepted measurements self.accepted_measurements .iter() - .find(|a| a.attestation_type == attestation_type && a.measurements == measurements); + .find(|a| a.attestation_type == attestation_type && a.measurements == measurements) + .ok_or(AttestationError::MeasurementsNotAccepted)?; Ok(Some(measurements)) } @@ -409,4 +413,8 @@ pub enum AttestationError { QuoteParse(#[from] QuoteParseError), #[error("Attestation type not supported")] AttestationTypeNotSupported, + #[error("Attestation type not accepted")] + AttestationTypeNotAccepted, + #[error("Measurements not accepted")] + MeasurementsNotAccepted, } diff --git a/src/lib.rs b/src/lib.rs index 9bb9105..3b67981 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -248,9 +248,8 @@ impl ProxyServer { let service = service_fn(move |mut req| { // If we have measurements, from the remote peer, add them to the request header let measurements = measurements.clone(); + let headers = req.headers_mut(); if let Some(measurements) = measurements { - let headers = req.headers_mut(); - match measurements.to_header_format() { Ok(header_value) => { headers.insert(MEASUREMENT_HEADER, header_value); @@ -261,12 +260,12 @@ impl ProxyServer { error!("Failed to encode measurement values: {e}"); } } - headers.insert( - ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()) - .expect("Attestation type should be able to be encoded as a header value"), - ); } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()) + .expect("Attestation type should be able to be encoded as a header value"), + ); async move { match Self::handle_http_request(req, target).await { @@ -330,6 +329,7 @@ fn full>(chunk: T) -> BoxBody { } /// A proxy client which forwards http traffic to a proxy-server +#[derive(Debug)] pub struct ProxyClient { /// The underlying TCP listener listener: TcpListener, @@ -438,8 +438,8 @@ impl ProxyClient { Ok(mut resp) => { // If we have measurements from the proxy-server, inject them into the // response header + let headers = resp.headers_mut(); if let Some(measurements) = measurements.clone() { - let headers = resp.headers_mut(); match measurements.to_header_format() { Ok(header_value) => { headers.insert(MEASUREMENT_HEADER, header_value); @@ -450,12 +450,13 @@ impl ProxyClient { error!("Failed to encode measurement values: {e}"); } } - headers.insert( - ATTESTATION_TYPE_HEADER, - HeaderValue::from_str(remote_attestation_type.as_str()) - .expect("Attestation type should be able to be encoded as a header value"), - ); } + headers.insert( + ATTESTATION_TYPE_HEADER, + HeaderValue::from_str(remote_attestation_type.as_str()).expect( + "Attestation type should be able to be encoded as a header value", + ), + ); (Ok(resp.map(|b| b.boxed())), false) } Err(e) => { @@ -817,14 +818,19 @@ where #[cfg(test)] mod tests { + use crate::attestation::measurements::{ + CvmImageMeasurements, MeasurementRecord, PlatformMeasurements, + }; + use super::*; use test_helpers::{ default_measurements, example_http_service, example_service, generate_certificate_chain, generate_tls_config, generate_tls_config_with_client_auth, }; + // Server has mock DCAP, client has no attestation and no client auth #[tokio::test] - async fn http_proxy() { + async fn http_proxy_with_server_attestation() { let target_addr = example_http_service().await; let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); @@ -886,6 +892,89 @@ mod tests { assert_eq!(res_body, "No measurements"); } + // Server has no attestation, client has mock DCAP and client auth + #[tokio::test] + async fn http_proxy_client_attestation() { + let target_addr = example_http_service().await; + + let (server_cert_chain, server_private_key) = + generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (client_cert_chain, client_private_key) = + generate_certificate_chain("127.0.0.1".parse().unwrap()); + + let ( + (_client_tls_server_config, client_tls_client_config), + (server_tls_server_config, _server_tls_client_config), + ) = generate_tls_config_with_client_auth( + client_cert_chain.clone(), + client_private_key, + server_cert_chain.clone(), + server_private_key, + ); + + let proxy_server = ProxyServer::new_with_tls_config( + server_cert_chain, + server_tls_server_config, + "127.0.0.1:0", + target_addr, + Arc::new(NoQuoteGenerator), + AttestationVerifier::mock(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + // Accept one connection, then finish + proxy_server.accept().await.unwrap(); + }); + + let proxy_client = ProxyClient::new_with_tls_config( + client_tls_client_config, + "127.0.0.1:0", + proxy_addr.to_string(), + Arc::new(DcapTdxQuoteGenerator { + attestation_type: AttestationType::DcapTdx, + }), + AttestationVerifier::do_not_verify(), + Some(client_cert_chain), + ) + .await + .unwrap(); + + let proxy_client_addr = proxy_client.local_addr().unwrap(); + + tokio::spawn(async move { + // Accept two connections, then finish + proxy_client.accept().await.unwrap(); + proxy_client.accept().await.unwrap(); + }); + + let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string())) + .await + .unwrap(); + + // We expect no measurements from the server + let headers = res.headers(); + assert!(headers.get(MEASUREMENT_HEADER).is_none()); + + let attestation_type = headers + .get(ATTESTATION_TYPE_HEADER) + .unwrap() + .to_str() + .unwrap(); + assert_eq!(attestation_type, AttestationType::None.as_str()); + + let res_body = res.text().await.unwrap(); + + // The response body shows us what was in the request header (as the test http server + // handler puts them there) + let measurements = Measurements::from_header_format(&res_body).unwrap(); + assert_eq!(measurements, default_measurements()); + } + + // Server has mock DCAP, client has mock DCAP and client auth #[tokio::test] async fn http_proxy_mutual_attestation() { let target_addr = example_http_service().await; @@ -994,6 +1083,7 @@ mod tests { assert_eq!(measurements, default_measurements()); } + // Server has mock DCAP, client no attestation - just get the server certificate #[tokio::test] async fn test_get_tls_cert() { let target_addr = example_service().await; @@ -1030,4 +1120,109 @@ mod tests { assert_eq!(retrieved_chain, cert_chain); } + + // Negative test - server does not provide attestation but client requires it + // Server has no attestaion, client has no attestation and no client auth + #[tokio::test] + async fn fails_on_no_attestation_when_expected() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr, + Arc::new(NoQuoteGenerator), + AttestationVerifier::do_not_verify(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let proxy_client_result = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), + Arc::new(NoQuoteGenerator), + AttestationVerifier::mock(), + None, + ) + .await; + + assert!(matches!( + proxy_client_result.unwrap_err(), + ProxyError::Attestation(AttestationError::AttestationTypeNotAccepted) + )); + } + + // Negative test - server does not provide attestation but client requires it + // Server has no attestaion, client has no attestation and no client auth + #[tokio::test] + async fn fails_on_bad_measurements() { + let target_addr = example_http_service().await; + + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let proxy_server = ProxyServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + target_addr, + Arc::new(DcapTdxQuoteGenerator { + attestation_type: AttestationType::DcapTdx, + }), + AttestationVerifier::do_not_verify(), + ) + .await + .unwrap(); + + let proxy_addr = proxy_server.local_addr().unwrap(); + + tokio::spawn(async move { + proxy_server.accept().await.unwrap(); + }); + + let attestation_verifier = AttestationVerifier { + accepted_measurements: vec![MeasurementRecord { + attestation_type: AttestationType::DcapTdx, + measurement_id: "test".to_string(), + measurements: Measurements { + platform: PlatformMeasurements { + mrtd: [0; 48], + rtmr0: [0; 48], + }, + cvm_image: CvmImageMeasurements { + rtmr1: [1; 48], // This differs from the mock measurements given + rtmr2: [0; 48], + rtmr3: [0; 48], + }, + }, + }], + pccs_url: None, + }; + + let proxy_client_result = ProxyClient::new_with_tls_config( + client_config, + "127.0.0.1:0".to_string(), + proxy_addr.to_string(), + Arc::new(NoQuoteGenerator), + attestation_verifier, + None, + ) + .await; + + assert!(matches!( + proxy_client_result.unwrap_err(), + ProxyError::Attestation(AttestationError::MeasurementsNotAccepted) + )); + } }