Skip to content

Commit 889b43c

Browse files
committed
Allow giving measurement file as url, and improve measurement checking logic
1 parent 2fe13db commit 889b43c

File tree

2 files changed

+61
-9
lines changed

2 files changed

+61
-9
lines changed

src/attestation/measurements.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ pub enum MeasurementFormatError {
162162
BadRegisterIndex,
163163
#[error("ParseInt: {0}")]
164164
ParseInt(#[from] std::num::ParseIntError),
165+
#[error("Failed to read measurements from URL: {0}")]
166+
Reqwest(#[from] reqwest::Error),
165167
}
166168

167169
/// An accepted measurement value given in the measurements file
@@ -261,9 +263,11 @@ impl MeasurementPolicy {
261263
.any(|measurement_record| match measurements {
262264
MultiMeasurements::Dcap(dcap_measurements) => {
263265
if let MultiMeasurements::Dcap(d) = measurement_record.measurements.clone() {
264-
for (k, v) in dcap_measurements.iter() {
265-
if d.get(k).is_some_and(|x| x != v) {
266-
return false;
266+
// All measurements in our policy must be given and must match
267+
for (k, v) in d.iter() {
268+
match dcap_measurements.get(k) {
269+
Some(value) if value == v => {}
270+
_ => return false,
267271
}
268272
}
269273
return true;
@@ -272,9 +276,10 @@ impl MeasurementPolicy {
272276
}
273277
MultiMeasurements::Azure(azure_measurements) => {
274278
if let MultiMeasurements::Azure(a) = measurement_record.measurements.clone() {
275-
for (k, v) in azure_measurements.iter() {
276-
if a.get(k).is_some_and(|x| x != v) {
277-
return false;
279+
for (k, v) in a.iter() {
280+
match azure_measurements.get(k) {
281+
Some(value) if value == v => {}
282+
_ => return false,
278283
}
279284
}
280285
return true;
@@ -303,6 +308,16 @@ impl MeasurementPolicy {
303308
.any(|a| a.measurements == MultiMeasurements::NoAttestation)
304309
}
305310

311+
/// Given either a URL or the path to a file, parse the measurement policy from JSON
312+
pub async fn from_file_or_url(file_or_url: String) -> Result<Self, MeasurementFormatError> {
313+
if file_or_url.starts_with("https://") || file_or_url.starts_with("http://") {
314+
let measurements_json = reqwest::get(file_or_url).await?.bytes().await?;
315+
Self::from_json_bytes(measurements_json.to_vec()).await
316+
} else {
317+
Self::from_file(file_or_url.into()).await
318+
}
319+
}
320+
306321
/// Given the path to a JSON file containing measurements, return a [MeasurementPolicy]
307322
pub async fn from_file(measurement_file: PathBuf) -> Result<Self, MeasurementFormatError> {
308323
let measurements_json = tokio::fs::read(measurement_file).await?;
@@ -449,6 +464,14 @@ mod tests {
449464
.unwrap_err(),
450465
AttestationError::MeasurementsNotAccepted
451466
));
467+
468+
// A non-specific measurement fails
469+
assert!(matches!(
470+
specific_measurements
471+
.check_measurement(&MultiMeasurements::Azure(HashMap::new()))
472+
.unwrap_err(),
473+
AttestationError::MeasurementsNotAccepted
474+
));
452475
}
453476

454477
#[tokio::test]
@@ -471,4 +494,31 @@ mod tests {
471494
AttestationError::MeasurementsNotAccepted
472495
));
473496
}
497+
498+
#[tokio::test]
499+
async fn test_read_remote_buildernet_measurements() {
500+
// Check that the buildernet measurements are available and parse correctly
501+
let policy = MeasurementPolicy::from_file_or_url(
502+
"https://measurements.builder.flashbots.net".to_string(),
503+
)
504+
.await
505+
.unwrap();
506+
507+
assert!(!policy.accepted_measurements.is_empty());
508+
509+
assert!(matches!(
510+
policy
511+
.check_measurement(&MultiMeasurements::NoAttestation)
512+
.unwrap_err(),
513+
AttestationError::MeasurementsNotAccepted
514+
));
515+
516+
// A non-specific measurement fails
517+
assert!(matches!(
518+
policy
519+
.check_measurement(&MultiMeasurements::Azure(HashMap::new()))
520+
.unwrap_err(),
521+
AttestationError::MeasurementsNotAccepted
522+
));
523+
}
474524
}

src/main.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ use attested_tls_proxy::{
1717
struct Cli {
1818
#[clap(subcommand)]
1919
command: CliCommand,
20-
/// Optional path to file containing JSON measurements to be enforced on the remote party
20+
/// Path to file, or URL, containing JSON measurements to be enforced on the remote party
2121
#[arg(long, global = true, env = "MEASUREMENTS_FILE")]
22-
measurements_file: Option<PathBuf>,
22+
measurements_file: Option<String>,
2323
/// If no measurements file is specified, a single attestion type to allow
2424
#[arg(long, global = true)]
2525
allowed_remote_attestation_type: Option<String>,
@@ -171,7 +171,9 @@ async fn main() -> anyhow::Result<()> {
171171
}
172172

173173
let measurement_policy = match cli.measurements_file {
174-
Some(server_measurements) => MeasurementPolicy::from_file(server_measurements).await?,
174+
Some(server_measurements) => {
175+
MeasurementPolicy::from_file_or_url(server_measurements).await?
176+
}
175177
None => {
176178
let allowed_server_attestation_type: AttestationType = serde_json::from_value(
177179
serde_json::Value::String(cli.allowed_remote_attestation_type.ok_or(anyhow!(

0 commit comments

Comments
 (0)