Skip to content

Commit ef5d5f6

Browse files
committed
Error handling
1 parent 9139814 commit ef5d5f6

File tree

4 files changed

+56
-91
lines changed

4 files changed

+56
-91
lines changed

src/attestation/measurements.rs

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,15 @@ impl Measurements {
9292
let measurements_map: HashMap<u32, String> = serde_json::from_str(input)?;
9393
let measurements_map: HashMap<u32, [u8; 48]> = measurements_map
9494
.into_iter()
95-
.map(|(k, v)| (k, hex::decode(v).unwrap().try_into().unwrap()))
96-
.collect();
95+
.map(|(k, v)| {
96+
Ok((
97+
k,
98+
hex::decode(v)?
99+
.try_into()
100+
.map_err(|_| MeasurementFormatError::BadLength)?,
101+
))
102+
})
103+
.collect::<Result<_, MeasurementFormatError>>()?;
97104

98105
Ok(Self {
99106
platform: PlatformMeasurements {
@@ -127,6 +134,14 @@ pub enum MeasurementFormatError {
127134
MissingValue(String),
128135
#[error("Invalid header value: {0}")]
129136
BadHeaderValue(#[from] InvalidHeaderValue),
137+
#[error("IO: {0}")]
138+
Io(#[from] std::io::Error),
139+
#[error("Attestation type not valid")]
140+
AttestationTypeNotValid,
141+
#[error("Hex: {0}")]
142+
Hex(#[from] hex::FromHexError),
143+
#[error("Expected 48 byte value")]
144+
BadLength,
130145
}
131146

132147
#[derive(Clone, Debug)]
@@ -137,7 +152,9 @@ pub struct MeasurementRecord {
137152
}
138153

139154
/// Given the path to a JSON file containing measurements, return a [Vec<MeasurementRecord>]
140-
pub async fn get_measurements_from_file(measurement_file: PathBuf) -> AttestationVerifier {
155+
pub async fn get_measurements_from_file(
156+
measurement_file: PathBuf,
157+
) -> Result<AttestationVerifier, MeasurementFormatError> {
141158
#[derive(Debug, Deserialize)]
142159
struct MeasurementRecordSimple {
143160
measurement_id: String,
@@ -150,47 +167,42 @@ pub async fn get_measurements_from_file(measurement_file: PathBuf) -> Attestatio
150167
expected: String,
151168
}
152169

153-
let measurements_json = tokio::fs::read(measurement_file).await.unwrap();
170+
let measurements_json = tokio::fs::read(measurement_file).await?;
154171
let measurements_simple: Vec<MeasurementRecordSimple> =
155-
serde_json::from_slice(&measurements_json).unwrap();
172+
serde_json::from_slice(&measurements_json)?;
156173
let mut measurements = Vec::new();
157174
for measurement in measurements_simple {
158175
measurements.push(MeasurementRecord {
159176
measurement_id: measurement.measurement_id,
160177
attestation_type: AttestationType::parse_from_str(&measurement.attestation_type)
161-
.unwrap(),
178+
.map_err(|_| MeasurementFormatError::AttestationTypeNotValid)?,
162179
measurements: Measurements {
163180
platform: PlatformMeasurements {
164-
mrtd: hex::decode(&measurement.measurements["0"].expected)
165-
.unwrap()
181+
mrtd: hex::decode(&measurement.measurements["0"].expected)?
166182
.try_into()
167-
.unwrap(),
168-
rtmr0: hex::decode(&measurement.measurements["1"].expected)
169-
.unwrap()
183+
.map_err(|_| MeasurementFormatError::BadLength)?,
184+
rtmr0: hex::decode(&measurement.measurements["1"].expected)?
170185
.try_into()
171-
.unwrap(),
186+
.map_err(|_| MeasurementFormatError::BadLength)?,
172187
},
173188
cvm_image: CvmImageMeasurements {
174-
rtmr1: hex::decode(&measurement.measurements["2"].expected)
175-
.unwrap()
189+
rtmr1: hex::decode(&measurement.measurements["2"].expected)?
176190
.try_into()
177-
.unwrap(),
178-
rtmr2: hex::decode(&measurement.measurements["3"].expected)
179-
.unwrap()
191+
.map_err(|_| MeasurementFormatError::BadLength)?,
192+
rtmr2: hex::decode(&measurement.measurements["3"].expected)?
180193
.try_into()
181-
.unwrap(),
182-
rtmr3: hex::decode(&measurement.measurements["4"].expected)
183-
.unwrap()
194+
.map_err(|_| MeasurementFormatError::BadLength)?,
195+
rtmr3: hex::decode(&measurement.measurements["4"].expected)?
184196
.try_into()
185-
.unwrap(),
197+
.map_err(|_| MeasurementFormatError::BadLength)?,
186198
},
187199
},
188200
});
189201
}
190202

191-
AttestationVerifier {
203+
Ok(AttestationVerifier {
192204
accepted_measurements: measurements,
193-
}
205+
})
194206
}
195207

196208
#[cfg(test)]
@@ -199,6 +211,8 @@ mod tests {
199211

200212
#[tokio::test]
201213
async fn test_read_measurements_file() {
202-
get_measurements_from_file("test-assets/measurements.json".into()).await;
214+
get_measurements_from_file("test-assets/measurements.json".into())
215+
.await
216+
.unwrap();
203217
}
204218
}

src/attestation/mod.rs

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ impl AttestationVerifier {
155155
exporter: [u8; 32],
156156
) -> Result<Option<Measurements>, AttestationError> {
157157
let attestation_type =
158-
AttestationType::parse_from_str(&attestation_payload.attestation_type).unwrap();
159-
println!("handling {attestation_type}");
158+
AttestationType::parse_from_str(&attestation_payload.attestation_type)?;
160159

161160
let measurements = match attestation_type {
162161
AttestationType::DcapTdx => {
@@ -188,22 +187,6 @@ impl AttestationVerifier {
188187
}
189188
}
190189

191-
// /// Defines how to verify a quote
192-
// pub trait QuoteVerifier: Sync + Send + 'static {
193-
// /// Type of attestation used
194-
// fn attestation_type(&self) -> AttestationType;
195-
//
196-
// /// Verify the given attestation payload
197-
// fn verify_attestation(
198-
// &self,
199-
// input: Vec<u8>,
200-
// cert_chain: &[CertificateDer<'_>],
201-
// exporter: [u8; 32],
202-
// ) -> Pin<
203-
// Box<dyn Future<Output = Result<Option<Measurements>, AttestationError>> + Send + 'static>,
204-
// >;
205-
// }
206-
207190
/// Quote generation using configfs_tsm
208191
#[derive(Clone)]
209192
pub struct DcapTdxQuoteGenerator {
@@ -323,31 +306,6 @@ impl QuoteGenerator for NoQuoteGenerator {
323306
}
324307
}
325308

326-
/// For no CVM platform (eg: for one-sided remote-attested TLS)
327-
// #[derive(Clone)]
328-
// pub struct NoQuoteVerifier;
329-
//
330-
// impl QuoteVerifier for NoQuoteVerifier {
331-
// /// Type of attestation used
332-
// fn attestation_type(&self) -> AttestationType {
333-
// AttestationType::None
334-
// }
335-
//
336-
// /// Ensure that an empty attestation is given
337-
// async fn verify_attestation(
338-
// &self,
339-
// input: Vec<u8>,
340-
// _cert_chain: &[CertificateDer<'_>],
341-
// _exporter: [u8; 32],
342-
// ) -> Result<Option<Measurements>, AttestationError> {
343-
// if input.is_empty() {
344-
// Ok(None)
345-
// } else {
346-
// Err(AttestationError::AttestationGivenWhenNoneExpected)
347-
// }
348-
// }
349-
// }
350-
351309
/// Create a mock quote for testing on non-confidential hardware
352310
#[cfg(test)]
353311
fn generate_quote(input: [u8; 64]) -> Result<Vec<u8>, QuoteGenerationError> {

src/lib.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ impl ProxyServer {
188188
&cert_chain,
189189
exporter,
190190
local_quote_generator,
191-
)?)
192-
.unwrap()
191+
)?)?
193192
} else {
194193
Vec::new()
195194
};
@@ -209,8 +208,7 @@ impl ProxyServer {
209208

210209
let (measurements, remote_attestation_type) = if attestation_verifier.has_remote_attestion()
211210
{
212-
let remote_attestation_payload: AttesationPayload =
213-
serde_json::from_slice(&buf).unwrap();
211+
let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf)?;
214212

215213
let remote_attestation_type = remote_attestation_payload.attestation_type.clone();
216214
(
@@ -502,9 +500,9 @@ impl ProxyClient {
502500
let mut buf = vec![0; length];
503501
tls_stream.read_exact(&mut buf).await?;
504502

505-
let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf).unwrap();
503+
let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf)?;
506504
let remote_attestation_type =
507-
AttestationType::parse_from_str(&remote_attestation_payload.attestation_type).unwrap();
505+
AttestationType::parse_from_str(&remote_attestation_payload.attestation_type)?;
508506

509507
let measurements = attestation_verifier
510508
.verify_attestation(remote_attestation_payload, &remote_cert_chain, exporter)
@@ -515,8 +513,7 @@ impl ProxyClient {
515513
&cert_chain.ok_or(ProxyError::NoClientAuth)?,
516514
exporter,
517515
local_quote_generator,
518-
)?)
519-
.unwrap()
516+
)?)?
520517
} else {
521518
Vec::new()
522519
};
@@ -637,7 +634,7 @@ async fn get_tls_cert_with_config(
637634
let mut buf = vec![0; length];
638635
tls_stream.read_exact(&mut buf).await?;
639636

640-
let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf).unwrap();
637+
let remote_attestation_payload: AttesationPayload = serde_json::from_slice(&buf)?;
641638

642639
let _measurements = attestation_verifier
643640
.verify_attestation(remote_attestation_payload, &remote_cert_chain, exporter)
@@ -667,6 +664,8 @@ pub enum ProxyError {
667664
BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError),
668665
#[error("HTTP: {0}")]
669666
Hyper(#[from] hyper::Error),
667+
#[error("JSON: {0}")]
668+
Json(#[from] serde_json::Error),
670669
}
671670

672671
/// Given a byte array, encode its length as a 4 byte big endian u32

src/main.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,6 @@ enum CliCommand {
5555
// Name: "tls-ca-certificate",
5656
// Usage: "additional CA certificate to verify against (PEM) [default=no additional TLS certs]. Only valid with --verify-tls.",
5757
//
58-
// Name: "override-azurev6-tcbinfo",
59-
// Value: false,
60-
// EnvVars: []string{"OVERRIDE_AZUREV6_TCBINFO"},
61-
// Usage: "Allows Azure's V6 instance outdated SEAM Loader",
62-
//
6358
// Name: "dev-dummy-dcap",
6459
// EnvVars: []string{"DEV_DUMMY_DCAP"},
6560
// Usage: "URL of the remote dummy DCAP service. Only with --client-attestation-type dummy.",
@@ -94,13 +89,6 @@ enum CliCommand {
9489
// Value: "",
9590
// Usage: "address to listen on for health checks",
9691
//
97-
//
98-
//
99-
// Name: "override-azurev6-tcbinfo",
100-
// Value: false,
101-
// EnvVars: []string{"OVERRIDE_AZUREV6_TCBINFO"},
102-
// Usage: "Allows Azure's V6 instance outdated SEAM Loader",
103-
//
10492
// Name: "dev-dummy-dcap",
10593
// EnvVars: []string{"DEV_DUMMY_DCAP"},
10694
// Usage: "URL of the remote dummy DCAP service. Only with --server-attestation-type dummy.",
@@ -143,7 +131,9 @@ async fn main() -> anyhow::Result<()> {
143131
};
144132

145133
let attestation_verifier = match server_measurements {
146-
Some(server_measurements) => get_measurements_from_file(server_measurements).await,
134+
Some(server_measurements) => {
135+
get_measurements_from_file(server_measurements).await?
136+
}
147137
None => AttestationVerifier::do_not_verify(),
148138
};
149139

@@ -187,7 +177,9 @@ async fn main() -> anyhow::Result<()> {
187177
let local_attestation_generator = server_attestation_type.get_quote_generator()?;
188178

189179
let attestation_verifier = match client_measurements {
190-
Some(client_measurements) => get_measurements_from_file(client_measurements).await,
180+
Some(client_measurements) => {
181+
get_measurements_from_file(client_measurements).await?
182+
}
191183
None => AttestationVerifier::do_not_verify(),
192184
};
193185

@@ -212,7 +204,9 @@ async fn main() -> anyhow::Result<()> {
212204
server_measurements,
213205
} => {
214206
let attestation_verifier = match server_measurements {
215-
Some(server_measurements) => get_measurements_from_file(server_measurements).await,
207+
Some(server_measurements) => {
208+
get_measurements_from_file(server_measurements).await?
209+
}
216210
None => AttestationVerifier::do_not_verify(),
217211
};
218212
let cert_chain = get_tls_cert(server, attestation_verifier).await?;

0 commit comments

Comments
 (0)