Skip to content

Commit b5b3742

Browse files
authored
Merge pull request #40 from flashbots/peg/test-client-only-attestaion
Add additional tests and fix attesation verification
2 parents e1c92e5 + 4c899f1 commit b5b3742

File tree

2 files changed

+220
-17
lines changed

2 files changed

+220
-17
lines changed

src/attestation/mod.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ pub struct AttestationVerifier {
145145
///
146146
/// If this is empty, anything will be accepted - but measurements are always injected into HTTP
147147
/// headers, so that they can be verified upstream
148-
accepted_measurements: Vec<MeasurementRecord>,
148+
pub accepted_measurements: Vec<MeasurementRecord>,
149149
/// A PCCS service to use - defaults to Intel PCS
150-
pccs_url: Option<String>,
150+
pub pccs_url: Option<String>,
151151
}
152152

153153
impl AttestationVerifier {
@@ -202,6 +202,9 @@ impl AttestationVerifier {
202202
.await?
203203
}
204204
AttestationType::None => {
205+
if self.has_remote_attestion() {
206+
return Err(AttestationError::AttestationTypeNotAccepted);
207+
}
205208
if attestation_payload.attestation.is_empty() {
206209
return Ok(None);
207210
} else {
@@ -216,7 +219,8 @@ impl AttestationVerifier {
216219
// look through all our accepted measurements
217220
self.accepted_measurements
218221
.iter()
219-
.find(|a| a.attestation_type == attestation_type && a.measurements == measurements);
222+
.find(|a| a.attestation_type == attestation_type && a.measurements == measurements)
223+
.ok_or(AttestationError::MeasurementsNotAccepted)?;
220224

221225
Ok(Some(measurements))
222226
}
@@ -409,4 +413,8 @@ pub enum AttestationError {
409413
QuoteParse(#[from] QuoteParseError),
410414
#[error("Attestation type not supported")]
411415
AttestationTypeNotSupported,
416+
#[error("Attestation type not accepted")]
417+
AttestationTypeNotAccepted,
418+
#[error("Measurements not accepted")]
419+
MeasurementsNotAccepted,
412420
}

src/lib.rs

Lines changed: 209 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,8 @@ impl ProxyServer {
248248
let service = service_fn(move |mut req| {
249249
// If we have measurements, from the remote peer, add them to the request header
250250
let measurements = measurements.clone();
251+
let headers = req.headers_mut();
251252
if let Some(measurements) = measurements {
252-
let headers = req.headers_mut();
253-
254253
match measurements.to_header_format() {
255254
Ok(header_value) => {
256255
headers.insert(MEASUREMENT_HEADER, header_value);
@@ -261,12 +260,12 @@ impl ProxyServer {
261260
error!("Failed to encode measurement values: {e}");
262261
}
263262
}
264-
headers.insert(
265-
ATTESTATION_TYPE_HEADER,
266-
HeaderValue::from_str(remote_attestation_type.as_str())
267-
.expect("Attestation type should be able to be encoded as a header value"),
268-
);
269263
}
264+
headers.insert(
265+
ATTESTATION_TYPE_HEADER,
266+
HeaderValue::from_str(remote_attestation_type.as_str())
267+
.expect("Attestation type should be able to be encoded as a header value"),
268+
);
270269

271270
async move {
272271
match Self::handle_http_request(req, target).await {
@@ -330,6 +329,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
330329
}
331330

332331
/// A proxy client which forwards http traffic to a proxy-server
332+
#[derive(Debug)]
333333
pub struct ProxyClient {
334334
/// The underlying TCP listener
335335
listener: TcpListener,
@@ -438,8 +438,8 @@ impl ProxyClient {
438438
Ok(mut resp) => {
439439
// If we have measurements from the proxy-server, inject them into the
440440
// response header
441+
let headers = resp.headers_mut();
441442
if let Some(measurements) = measurements.clone() {
442-
let headers = resp.headers_mut();
443443
match measurements.to_header_format() {
444444
Ok(header_value) => {
445445
headers.insert(MEASUREMENT_HEADER, header_value);
@@ -450,12 +450,13 @@ impl ProxyClient {
450450
error!("Failed to encode measurement values: {e}");
451451
}
452452
}
453-
headers.insert(
454-
ATTESTATION_TYPE_HEADER,
455-
HeaderValue::from_str(remote_attestation_type.as_str())
456-
.expect("Attestation type should be able to be encoded as a header value"),
457-
);
458453
}
454+
headers.insert(
455+
ATTESTATION_TYPE_HEADER,
456+
HeaderValue::from_str(remote_attestation_type.as_str()).expect(
457+
"Attestation type should be able to be encoded as a header value",
458+
),
459+
);
459460
(Ok(resp.map(|b| b.boxed())), false)
460461
}
461462
Err(e) => {
@@ -817,14 +818,19 @@ where
817818

818819
#[cfg(test)]
819820
mod tests {
821+
use crate::attestation::measurements::{
822+
CvmImageMeasurements, MeasurementRecord, PlatformMeasurements,
823+
};
824+
820825
use super::*;
821826
use test_helpers::{
822827
default_measurements, example_http_service, example_service, generate_certificate_chain,
823828
generate_tls_config, generate_tls_config_with_client_auth,
824829
};
825830

831+
// Server has mock DCAP, client has no attestation and no client auth
826832
#[tokio::test]
827-
async fn http_proxy() {
833+
async fn http_proxy_with_server_attestation() {
828834
let target_addr = example_http_service().await;
829835

830836
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
@@ -886,6 +892,89 @@ mod tests {
886892
assert_eq!(res_body, "No measurements");
887893
}
888894

895+
// Server has no attestation, client has mock DCAP and client auth
896+
#[tokio::test]
897+
async fn http_proxy_client_attestation() {
898+
let target_addr = example_http_service().await;
899+
900+
let (server_cert_chain, server_private_key) =
901+
generate_certificate_chain("127.0.0.1".parse().unwrap());
902+
let (client_cert_chain, client_private_key) =
903+
generate_certificate_chain("127.0.0.1".parse().unwrap());
904+
905+
let (
906+
(_client_tls_server_config, client_tls_client_config),
907+
(server_tls_server_config, _server_tls_client_config),
908+
) = generate_tls_config_with_client_auth(
909+
client_cert_chain.clone(),
910+
client_private_key,
911+
server_cert_chain.clone(),
912+
server_private_key,
913+
);
914+
915+
let proxy_server = ProxyServer::new_with_tls_config(
916+
server_cert_chain,
917+
server_tls_server_config,
918+
"127.0.0.1:0",
919+
target_addr,
920+
Arc::new(NoQuoteGenerator),
921+
AttestationVerifier::mock(),
922+
)
923+
.await
924+
.unwrap();
925+
926+
let proxy_addr = proxy_server.local_addr().unwrap();
927+
928+
tokio::spawn(async move {
929+
// Accept one connection, then finish
930+
proxy_server.accept().await.unwrap();
931+
});
932+
933+
let proxy_client = ProxyClient::new_with_tls_config(
934+
client_tls_client_config,
935+
"127.0.0.1:0",
936+
proxy_addr.to_string(),
937+
Arc::new(DcapTdxQuoteGenerator {
938+
attestation_type: AttestationType::DcapTdx,
939+
}),
940+
AttestationVerifier::do_not_verify(),
941+
Some(client_cert_chain),
942+
)
943+
.await
944+
.unwrap();
945+
946+
let proxy_client_addr = proxy_client.local_addr().unwrap();
947+
948+
tokio::spawn(async move {
949+
// Accept two connections, then finish
950+
proxy_client.accept().await.unwrap();
951+
proxy_client.accept().await.unwrap();
952+
});
953+
954+
let res = reqwest::get(format!("http://{}", proxy_client_addr.to_string()))
955+
.await
956+
.unwrap();
957+
958+
// We expect no measurements from the server
959+
let headers = res.headers();
960+
assert!(headers.get(MEASUREMENT_HEADER).is_none());
961+
962+
let attestation_type = headers
963+
.get(ATTESTATION_TYPE_HEADER)
964+
.unwrap()
965+
.to_str()
966+
.unwrap();
967+
assert_eq!(attestation_type, AttestationType::None.as_str());
968+
969+
let res_body = res.text().await.unwrap();
970+
971+
// The response body shows us what was in the request header (as the test http server
972+
// handler puts them there)
973+
let measurements = Measurements::from_header_format(&res_body).unwrap();
974+
assert_eq!(measurements, default_measurements());
975+
}
976+
977+
// Server has mock DCAP, client has mock DCAP and client auth
889978
#[tokio::test]
890979
async fn http_proxy_mutual_attestation() {
891980
let target_addr = example_http_service().await;
@@ -994,6 +1083,7 @@ mod tests {
9941083
assert_eq!(measurements, default_measurements());
9951084
}
9961085

1086+
// Server has mock DCAP, client no attestation - just get the server certificate
9971087
#[tokio::test]
9981088
async fn test_get_tls_cert() {
9991089
let target_addr = example_service().await;
@@ -1030,4 +1120,109 @@ mod tests {
10301120

10311121
assert_eq!(retrieved_chain, cert_chain);
10321122
}
1123+
1124+
// Negative test - server does not provide attestation but client requires it
1125+
// Server has no attestaion, client has no attestation and no client auth
1126+
#[tokio::test]
1127+
async fn fails_on_no_attestation_when_expected() {
1128+
let target_addr = example_http_service().await;
1129+
1130+
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
1131+
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
1132+
1133+
let proxy_server = ProxyServer::new_with_tls_config(
1134+
cert_chain,
1135+
server_config,
1136+
"127.0.0.1:0",
1137+
target_addr,
1138+
Arc::new(NoQuoteGenerator),
1139+
AttestationVerifier::do_not_verify(),
1140+
)
1141+
.await
1142+
.unwrap();
1143+
1144+
let proxy_addr = proxy_server.local_addr().unwrap();
1145+
1146+
tokio::spawn(async move {
1147+
proxy_server.accept().await.unwrap();
1148+
});
1149+
1150+
let proxy_client_result = ProxyClient::new_with_tls_config(
1151+
client_config,
1152+
"127.0.0.1:0".to_string(),
1153+
proxy_addr.to_string(),
1154+
Arc::new(NoQuoteGenerator),
1155+
AttestationVerifier::mock(),
1156+
None,
1157+
)
1158+
.await;
1159+
1160+
assert!(matches!(
1161+
proxy_client_result.unwrap_err(),
1162+
ProxyError::Attestation(AttestationError::AttestationTypeNotAccepted)
1163+
));
1164+
}
1165+
1166+
// Negative test - server does not provide attestation but client requires it
1167+
// Server has no attestaion, client has no attestation and no client auth
1168+
#[tokio::test]
1169+
async fn fails_on_bad_measurements() {
1170+
let target_addr = example_http_service().await;
1171+
1172+
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
1173+
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
1174+
1175+
let proxy_server = ProxyServer::new_with_tls_config(
1176+
cert_chain,
1177+
server_config,
1178+
"127.0.0.1:0",
1179+
target_addr,
1180+
Arc::new(DcapTdxQuoteGenerator {
1181+
attestation_type: AttestationType::DcapTdx,
1182+
}),
1183+
AttestationVerifier::do_not_verify(),
1184+
)
1185+
.await
1186+
.unwrap();
1187+
1188+
let proxy_addr = proxy_server.local_addr().unwrap();
1189+
1190+
tokio::spawn(async move {
1191+
proxy_server.accept().await.unwrap();
1192+
});
1193+
1194+
let attestation_verifier = AttestationVerifier {
1195+
accepted_measurements: vec![MeasurementRecord {
1196+
attestation_type: AttestationType::DcapTdx,
1197+
measurement_id: "test".to_string(),
1198+
measurements: Measurements {
1199+
platform: PlatformMeasurements {
1200+
mrtd: [0; 48],
1201+
rtmr0: [0; 48],
1202+
},
1203+
cvm_image: CvmImageMeasurements {
1204+
rtmr1: [1; 48], // This differs from the mock measurements given
1205+
rtmr2: [0; 48],
1206+
rtmr3: [0; 48],
1207+
},
1208+
},
1209+
}],
1210+
pccs_url: None,
1211+
};
1212+
1213+
let proxy_client_result = ProxyClient::new_with_tls_config(
1214+
client_config,
1215+
"127.0.0.1:0".to_string(),
1216+
proxy_addr.to_string(),
1217+
Arc::new(NoQuoteGenerator),
1218+
attestation_verifier,
1219+
None,
1220+
)
1221+
.await;
1222+
1223+
assert!(matches!(
1224+
proxy_client_result.unwrap_err(),
1225+
ProxyError::Attestation(AttestationError::MeasurementsNotAccepted)
1226+
));
1227+
}
10331228
}

0 commit comments

Comments
 (0)