Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/attestation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ pub struct AttestationVerifier {
///
/// If this is empty, anything will be accepted - but measurements are always injected into HTTP
/// headers, so that they can be verified upstream
accepted_measurements: Vec<MeasurementRecord>,
pub accepted_measurements: Vec<MeasurementRecord>,
/// A PCCS service to use - defaults to Intel PCS
pccs_url: Option<String>,
pub pccs_url: Option<String>,
}

impl AttestationVerifier {
Expand Down Expand Up @@ -202,6 +202,9 @@ impl AttestationVerifier {
.await?
}
AttestationType::None => {
if self.has_remote_attestion() {
return Err(AttestationError::AttestationTypeNotAccepted);
}
if attestation_payload.attestation.is_empty() {
return Ok(None);
} else {
Expand All @@ -216,7 +219,8 @@ impl AttestationVerifier {
// look through all our accepted measurements
self.accepted_measurements
.iter()
.find(|a| a.attestation_type == attestation_type && a.measurements == measurements);
.find(|a| a.attestation_type == attestation_type && a.measurements == measurements)
.ok_or(AttestationError::MeasurementsNotAccepted)?;

Ok(Some(measurements))
}
Expand Down Expand Up @@ -409,4 +413,8 @@ pub enum AttestationError {
QuoteParse(#[from] QuoteParseError),
#[error("Attestation type not supported")]
AttestationTypeNotSupported,
#[error("Attestation type not accepted")]
AttestationTypeNotAccepted,
#[error("Measurements not accepted")]
MeasurementsNotAccepted,
}
223 changes: 209 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,8 @@ impl ProxyServer {
let service = service_fn(move |mut req| {
// If we have measurements, from the remote peer, add them to the request header
let measurements = measurements.clone();
let headers = req.headers_mut();
if let Some(measurements) = measurements {
let headers = req.headers_mut();

match measurements.to_header_format() {
Ok(header_value) => {
headers.insert(MEASUREMENT_HEADER, header_value);
Expand All @@ -261,12 +260,12 @@ impl ProxyServer {
error!("Failed to encode measurement values: {e}");
}
}
headers.insert(
ATTESTATION_TYPE_HEADER,
HeaderValue::from_str(remote_attestation_type.as_str())
.expect("Attestation type should be able to be encoded as a header value"),
);
}
headers.insert(
ATTESTATION_TYPE_HEADER,
HeaderValue::from_str(remote_attestation_type.as_str())
.expect("Attestation type should be able to be encoded as a header value"),
);

async move {
match Self::handle_http_request(req, target).await {
Expand Down Expand Up @@ -330,6 +329,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
}

/// A proxy client which forwards http traffic to a proxy-server
#[derive(Debug)]
pub struct ProxyClient {
/// The underlying TCP listener
listener: TcpListener,
Expand Down Expand Up @@ -438,8 +438,8 @@ impl ProxyClient {
Ok(mut resp) => {
// If we have measurements from the proxy-server, inject them into the
// response header
let headers = resp.headers_mut();
if let Some(measurements) = measurements.clone() {
let headers = resp.headers_mut();
match measurements.to_header_format() {
Ok(header_value) => {
headers.insert(MEASUREMENT_HEADER, header_value);
Expand All @@ -450,12 +450,13 @@ impl ProxyClient {
error!("Failed to encode measurement values: {e}");
}
}
headers.insert(
ATTESTATION_TYPE_HEADER,
HeaderValue::from_str(remote_attestation_type.as_str())
.expect("Attestation type should be able to be encoded as a header value"),
);
}
headers.insert(
ATTESTATION_TYPE_HEADER,
HeaderValue::from_str(remote_attestation_type.as_str()).expect(
"Attestation type should be able to be encoded as a header value",
),
);
(Ok(resp.map(|b| b.boxed())), false)
}
Err(e) => {
Expand Down Expand Up @@ -817,14 +818,19 @@ where

#[cfg(test)]
mod tests {
use crate::attestation::measurements::{
CvmImageMeasurements, MeasurementRecord, PlatformMeasurements,
};

use super::*;
use test_helpers::{
default_measurements, example_http_service, example_service, generate_certificate_chain,
generate_tls_config, generate_tls_config_with_client_auth,
};

// Server has mock DCAP, client has no attestation and no client auth
#[tokio::test]
async fn http_proxy() {
async fn http_proxy_with_server_attestation() {
let target_addr = example_http_service().await;

let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
Expand Down Expand Up @@ -886,6 +892,89 @@ mod tests {
assert_eq!(res_body, "No measurements");
}

// Server has no attestation, client has mock DCAP and client auth
#[tokio::test]
async fn http_proxy_client_attestation() {
let target_addr = example_http_service().await;

let (server_cert_chain, server_private_key) =
generate_certificate_chain("127.0.0.1".parse().unwrap());
let (client_cert_chain, client_private_key) =
generate_certificate_chain("127.0.0.1".parse().unwrap());

let (
(_client_tls_server_config, client_tls_client_config),
(server_tls_server_config, _server_tls_client_config),
) = generate_tls_config_with_client_auth(
client_cert_chain.clone(),
client_private_key,
server_cert_chain.clone(),
server_private_key,
);

let proxy_server = ProxyServer::new_with_tls_config(
server_cert_chain,
server_tls_server_config,
"127.0.0.1:0",
target_addr,
Arc::new(NoQuoteGenerator),
AttestationVerifier::mock(),
)
.await
.unwrap();

let proxy_addr = proxy_server.local_addr().unwrap();

tokio::spawn(async move {
// Accept one connection, then finish
proxy_server.accept().await.unwrap();
});

let proxy_client = ProxyClient::new_with_tls_config(
client_tls_client_config,
"127.0.0.1:0",
proxy_addr.to_string(),
Arc::new(DcapTdxQuoteGenerator {
attestation_type: AttestationType::DcapTdx,
}),
AttestationVerifier::do_not_verify(),
Some(client_cert_chain),
)
.await
.unwrap();

let proxy_client_addr = proxy_client.local_addr().unwrap();

tokio::spawn(async move {
// Accept two connections, then finish
proxy_client.accept().await.unwrap();
proxy_client.accept().await.unwrap();
});

let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string()))
.await
.unwrap();

// We expect no measurements from the server
let headers = res.headers();
assert!(headers.get(MEASUREMENT_HEADER).is_none());

let attestation_type = headers
.get(ATTESTATION_TYPE_HEADER)
.unwrap()
.to_str()
.unwrap();
assert_eq!(attestation_type, AttestationType::None.as_str());

let res_body = res.text().await.unwrap();

// The response body shows us what was in the request header (as the test http server
// handler puts them there)
let measurements = Measurements::from_header_format(&res_body).unwrap();
assert_eq!(measurements, default_measurements());
}

// Server has mock DCAP, client has mock DCAP and client auth
#[tokio::test]
async fn http_proxy_mutual_attestation() {
let target_addr = example_http_service().await;
Expand Down Expand Up @@ -994,6 +1083,7 @@ mod tests {
assert_eq!(measurements, default_measurements());
}

// Server has mock DCAP, client no attestation - just get the server certificate
#[tokio::test]
async fn test_get_tls_cert() {
let target_addr = example_service().await;
Expand Down Expand Up @@ -1030,4 +1120,109 @@ mod tests {

assert_eq!(retrieved_chain, cert_chain);
}

// Negative test - server does not provide attestation but client requires it
// Server has no attestaion, client has no attestation and no client auth
#[tokio::test]
async fn fails_on_no_attestation_when_expected() {
let target_addr = example_http_service().await;

let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);

let proxy_server = ProxyServer::new_with_tls_config(
cert_chain,
server_config,
"127.0.0.1:0",
target_addr,
Arc::new(NoQuoteGenerator),
AttestationVerifier::do_not_verify(),
)
.await
.unwrap();

let proxy_addr = proxy_server.local_addr().unwrap();

tokio::spawn(async move {
proxy_server.accept().await.unwrap();
});

let proxy_client_result = ProxyClient::new_with_tls_config(
client_config,
"127.0.0.1:0".to_string(),
proxy_addr.to_string(),
Arc::new(NoQuoteGenerator),
AttestationVerifier::mock(),
None,
)
.await;

assert!(matches!(
proxy_client_result.unwrap_err(),
ProxyError::Attestation(AttestationError::AttestationTypeNotAccepted)
));
}

// Negative test - server does not provide attestation but client requires it
// Server has no attestaion, client has no attestation and no client auth
#[tokio::test]
async fn fails_on_bad_measurements() {
let target_addr = example_http_service().await;

let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);

let proxy_server = ProxyServer::new_with_tls_config(
cert_chain,
server_config,
"127.0.0.1:0",
target_addr,
Arc::new(DcapTdxQuoteGenerator {
attestation_type: AttestationType::DcapTdx,
}),
AttestationVerifier::do_not_verify(),
)
.await
.unwrap();

let proxy_addr = proxy_server.local_addr().unwrap();

tokio::spawn(async move {
proxy_server.accept().await.unwrap();
});

let attestation_verifier = AttestationVerifier {
accepted_measurements: vec![MeasurementRecord {
attestation_type: AttestationType::DcapTdx,
measurement_id: "test".to_string(),
measurements: Measurements {
platform: PlatformMeasurements {
mrtd: [0; 48],
rtmr0: [0; 48],
},
cvm_image: CvmImageMeasurements {
rtmr1: [1; 48], // This differs from the mock measurements given
rtmr2: [0; 48],
rtmr3: [0; 48],
},
},
}],
pccs_url: None,
};

let proxy_client_result = ProxyClient::new_with_tls_config(
client_config,
"127.0.0.1:0".to_string(),
proxy_addr.to_string(),
Arc::new(NoQuoteGenerator),
attestation_verifier,
None,
)
.await;

assert!(matches!(
proxy_client_result.unwrap_err(),
ProxyError::Attestation(AttestationError::MeasurementsNotAccepted)
));
}
}