Skip to content

Commit e34befc

Browse files
committed
Custom error type for MAA
1 parent 654f23d commit e34befc

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed

src/attestation/azure.rs

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
//! Microsoft Azure Attestation (MAA) evidence generation and verification
2+
use std::string::FromUtf8Error;
3+
24
use az_tdx_vtpm::{hcl, imds, report, vtpm};
35
use tokio_rustls::rustls::pki_types::CertificateDer;
46
// use openssl::pkey::{PKey, Public};
57
use base64::{engine::general_purpose::URL_SAFE as BASE64_URL_SAFE, Engine as _};
68
use reqwest::Client;
79
use serde::Serialize;
10+
use thiserror::Error;
811

912
use crate::attestation::{compute_report_input, AttestationError};
1013

@@ -15,12 +18,13 @@ use crate::attestation::{compute_report_input, AttestationError};
1518
pub async fn create_azure_attestation(
1619
cert_chain: &[CertificateDer<'_>],
1720
exporter: [u8; 32],
18-
) -> Result<Vec<u8>, AttestationError> {
21+
) -> Result<Vec<u8>, MaaError> {
1922
let maa_endpoint = "todo".to_string();
2023
let aad_access_token = "todo".to_string();
21-
let input_data = compute_report_input(cert_chain, exporter)?;
24+
let input_data = compute_report_input(cert_chain, exporter)
25+
.map_err(|e| MaaError::InputData(e.to_string()))?;
2226

23-
let td_report = report::get_report().unwrap();
27+
let td_report = report::get_report()?;
2428

2529
// let mrtd = td_report.tdinfo.mrtd;
2630
// let rtmr0 = td_report.tdinfo.rtrm[0];
@@ -29,10 +33,10 @@ pub async fn create_azure_attestation(
2933
// let rtmr3 = td_report.tdinfo.rtrm[3];
3034

3135
// This makes a request to Azure Instance metadata service and gives us a binary response
32-
let td_quote_bytes = imds::get_td_quote(&td_report).unwrap();
36+
let td_quote_bytes = imds::get_td_quote(&td_report)?;
3337

34-
let hcl_report_bytes = vtpm::get_report_with_report_data(&input_data).unwrap();
35-
let hcl_report = hcl::HclReport::new(hcl_report_bytes).unwrap();
38+
let hcl_report_bytes = vtpm::get_report_with_report_data(&input_data)?;
39+
let hcl_report = hcl::HclReport::new(hcl_report_bytes)?;
3640
let hcl_var_data = hcl_report.var_data();
3741

3842
// let bytes = vtpm::get_report().unwrap();
@@ -61,10 +65,8 @@ pub async fn create_azure_attestation(
6165
}),
6266
nonce: Some("my-app-nonce-or-session-id".to_string()),
6367
};
64-
let body_bytes = serde_json::to_vec(&body).unwrap();
65-
let jwt_token = call_tdxvm_attestation(maa_endpoint, aad_access_token, body_bytes)
66-
.await
67-
.unwrap();
68+
let body_bytes = serde_json::to_vec(&body)?;
69+
let jwt_token = call_tdxvm_attestation(maa_endpoint, aad_access_token, body_bytes).await?;
6870
Ok(jwt_token.as_bytes().to_vec())
6971
}
7072

@@ -73,7 +75,7 @@ async fn call_tdxvm_attestation(
7375
maa_endpoint: String,
7476
aad_access_token: String,
7577
body_bytes: Vec<u8>,
76-
) -> Result<String, Box<dyn std::error::Error>> {
78+
) -> Result<String, MaaError> {
7779
let url = format!("{}/attest/TdxVm?api-version=2025-06-01", maa_endpoint);
7880

7981
let client = Client::new();
@@ -89,7 +91,7 @@ async fn call_tdxvm_attestation(
8991
let text = res.text().await?;
9092

9193
if !status.is_success() {
92-
return Err(format!("MAA attestation failed: {status} {text}").into());
94+
return Err(MaaError::MaaProvider(status, text));
9395
}
9496

9597
#[derive(serde::Deserialize)]
@@ -105,9 +107,10 @@ pub async fn verify_azure_attestation(
105107
input: Vec<u8>,
106108
cert_chain: &[CertificateDer<'_>],
107109
exporter: [u8; 32],
108-
) -> Result<super::measurements::Measurements, AttestationError> {
109-
let _input_data = compute_report_input(cert_chain, exporter)?;
110-
let token = String::from_utf8(input).unwrap();
110+
) -> Result<super::measurements::Measurements, MaaError> {
111+
let _input_data = compute_report_input(cert_chain, exporter)
112+
.map_err(|e| MaaError::InputData(e.to_string()))?;
113+
let token = String::from_utf8(input)?;
111114

112115
decode_jwt(&token).await.unwrap();
113116

@@ -141,6 +144,28 @@ struct TdxVmRequest<'a> {
141144
nonce: Option<String>,
142145
}
143146

147+
#[derive(Error, Debug)]
148+
pub enum MaaError {
149+
#[error("Failed to build input data: {0}")]
150+
InputData(String),
151+
#[error("Report: {0}")]
152+
Report(#[from] az_tdx_vtpm::report::ReportError),
153+
#[error("IMDS: {0}")]
154+
Imds(#[from] imds::ImdsError),
155+
#[error("vTPM report: {0}")]
156+
VtpmReport(#[from] az_tdx_vtpm::vtpm::ReportError),
157+
#[error("HCL: {0}")]
158+
Hcl(#[from] hcl::HclError),
159+
#[error("JSON: {0}")]
160+
Json(#[from] serde_json::Error),
161+
#[error("HTTP Client: {0}")]
162+
HttpClient(#[from] reqwest::Error),
163+
#[error("MAA provider response: {0} - {1}")]
164+
MaaProvider(http::StatusCode, String),
165+
#[error("Token is bad UTF8: {0}")]
166+
BadUtf8(#[from] FromUtf8Error),
167+
}
168+
144169
#[cfg(test)]
145170
mod tests {
146171
use super::*;

src/attestation/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl AttestationGenerator {
122122
match self.attestation_type {
123123
AttestationType::None => Ok(Vec::new()),
124124
AttestationType::AzureTdx => {
125-
azure::create_azure_attestation(cert_chain, exporter).await
125+
Ok(azure::create_azure_attestation(cert_chain, exporter).await?)
126126
}
127127
AttestationType::Dummy => Err(AttestationError::AttestationTypeNotSupported),
128128
_ => dcap::create_dcap_attestation(cert_chain, exporter).await,
@@ -291,4 +291,6 @@ pub enum AttestationError {
291291
AttestationTypeNotAccepted,
292292
#[error("Measurements not accepted")]
293293
MeasurementsNotAccepted,
294+
#[error("MAA: {0}")]
295+
Maa(#[from] azure::MaaError),
294296
}

0 commit comments

Comments
 (0)