Skip to content

Commit 65c96e7

Browse files
committed
Error handling
1 parent 52068f8 commit 65c96e7

File tree

2 files changed

+60
-34
lines changed

2 files changed

+60
-34
lines changed

src/attestation.rs

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,40 @@
1-
use crate::AttestationError;
21
use sha2::{Digest, Sha256};
2+
use thiserror::Error;
33
use tokio_rustls::rustls::pki_types::CertificateDer;
44
use x509_parser::prelude::*;
55

6+
/// Represents a CVM technology with quote generation and verification
67
pub trait AttestationPlatform: Clone + Send + 'static {
7-
fn create_attestation(&self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec<u8>;
8+
fn create_attestation(
9+
&self,
10+
cert_chain: &[CertificateDer<'_>],
11+
exporter: [u8; 32],
12+
) -> Result<Vec<u8>, AttestationError>;
813

914
fn verify_attestation(
1015
&self,
1116
input: Vec<u8>,
1217
cert_chain: &[CertificateDer<'_>],
1318
exporter: [u8; 32],
14-
) -> bool;
19+
) -> Result<(), AttestationError>;
1520
}
1621

22+
/// For testing
1723
#[derive(Clone)]
1824
pub struct MockAttestation;
1925

2026
impl AttestationPlatform for MockAttestation {
2127
/// Mocks creating an attestation
22-
fn create_attestation(&self, cert_chain: &[CertificateDer<'_>], exporter: [u8; 32]) -> Vec<u8> {
28+
fn create_attestation(
29+
&self,
30+
cert_chain: &[CertificateDer<'_>],
31+
exporter: [u8; 32],
32+
) -> Result<Vec<u8>, AttestationError> {
2333
let mut quote_input = [0u8; 64];
24-
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap();
34+
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?;
2535
quote_input[..32].copy_from_slice(&pki_hash);
2636
quote_input[32..].copy_from_slice(&exporter);
27-
quote_input.to_vec()
37+
Ok(quote_input.to_vec())
2838
}
2939

3040
/// Mocks verifying an attestation
@@ -33,16 +43,20 @@ impl AttestationPlatform for MockAttestation {
3343
input: Vec<u8>,
3444
cert_chain: &[CertificateDer<'_>],
3545
exporter: [u8; 32],
36-
) -> bool {
46+
) -> Result<(), AttestationError> {
3747
let mut quote_input = [0u8; 64];
38-
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain).unwrap();
48+
let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?;
3949
quote_input[..32].copy_from_slice(&pki_hash);
4050
quote_input[32..].copy_from_slice(&exporter);
4151

42-
input == quote_input
52+
if input != quote_input {
53+
return Err(AttestationError::InputMismatch);
54+
}
55+
Ok(())
4356
}
4457
}
4558

59+
/// For no CVM platform (eg: for one-sided remote-attested TLS)
4660
#[derive(Clone)]
4761
pub struct NoAttestation;
4862

@@ -52,18 +66,22 @@ impl AttestationPlatform for NoAttestation {
5266
&self,
5367
_cert_chain: &[CertificateDer<'_>],
5468
_exporter: [u8; 32],
55-
) -> Vec<u8> {
56-
Vec::new()
69+
) -> Result<Vec<u8>, AttestationError> {
70+
Ok(Vec::new())
5771
}
5872

5973
/// Mocks verifying an attestation
6074
fn verify_attestation(
6175
&self,
62-
_input: Vec<u8>,
76+
input: Vec<u8>,
6377
_cert_chain: &[CertificateDer<'_>],
6478
_exporter: [u8; 32],
65-
) -> bool {
66-
true
79+
) -> Result<(), AttestationError> {
80+
if input.is_empty() {
81+
Ok(())
82+
} else {
83+
Err(AttestationError::AttestationGivenWhenNoneExpected)
84+
}
6785
}
6886
}
6987

@@ -80,3 +98,18 @@ fn get_pki_hash_from_certificate_chain(
8098
hasher.update(key_bytes);
8199
Ok(hasher.finalize().into())
82100
}
101+
102+
/// An error when generating or verifying an attestation
103+
#[derive(Error, Debug)]
104+
pub enum AttestationError {
105+
#[error("Certificate chain is empty")]
106+
NoCertificate,
107+
#[error("X509 parse: {0}")]
108+
X509Parse(#[from] x509_parser::asn1_rs::Err<x509_parser::error::X509Error>),
109+
#[error("X509: {0}")]
110+
X509(#[from] x509_parser::error::X509Error),
111+
#[error("Quote input is not as expected")]
112+
InputMismatch,
113+
#[error("Configuration mismatch - expected no remote attestation")]
114+
AttestationGivenWhenNoneExpected,
115+
}

src/lib.rs

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ mod attestation;
33
pub use attestation::{AttestationPlatform, MockAttestation, NoAttestation};
44

55
use std::{net::SocketAddr, sync::Arc};
6-
use thiserror::Error;
76
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
87
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
98
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
@@ -118,7 +117,9 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
118117
)
119118
.unwrap();
120119

121-
let attestation = local_attestation_platform.create_attestation(&cert_chain, exporter);
120+
let attestation = local_attestation_platform
121+
.create_attestation(&cert_chain, exporter)
122+
.unwrap();
122123
let attestation_length_prefix = length_prefix(&attestation);
123124

124125
tls_stream
@@ -135,9 +136,9 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
135136
let mut buf = vec![0; length];
136137
tls_stream.read_exact(&mut buf).await.unwrap();
137138

138-
if !remote_attestation_platform.verify_attestation(buf, &cert_chain, exporter) {
139-
panic!("Cannot verify attestation");
140-
};
139+
remote_attestation_platform
140+
.verify_attestation(buf, &cert_chain, exporter)
141+
.unwrap();
141142

142143
let outbound = TcpStream::connect(target).await.unwrap();
143144

@@ -252,11 +253,14 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
252253
let mut buf = vec![0; length];
253254
tls_stream.read_exact(&mut buf).await.unwrap();
254255

255-
if !remote_attestation_platform.verify_attestation(buf, &cert_chain, exporter) {
256-
panic!("Cannot verify attestation");
257-
};
256+
remote_attestation_platform
257+
.verify_attestation(buf, &cert_chain, exporter)
258+
.unwrap();
259+
260+
let attestation = local_attestation_platform
261+
.create_attestation(&cert_chain, exporter)
262+
.unwrap();
258263

259-
let attestation = local_attestation_platform.create_attestation(&cert_chain, exporter);
260264
let attestation_length_prefix = length_prefix(&attestation);
261265

262266
tls_stream
@@ -287,17 +291,6 @@ fn length_prefix(input: &[u8]) -> [u8; 4] {
287291
len.to_be_bytes()
288292
}
289293

290-
/// An error when generating an attestation
291-
#[derive(Error, Debug)]
292-
pub enum AttestationError {
293-
#[error("Certificate chain is empty")]
294-
NoCertificate,
295-
#[error("X509 parse: {0}")]
296-
X509Parse(#[from] x509_parser::asn1_rs::Err<x509_parser::error::X509Error>),
297-
#[error("X509: {0}")]
298-
X509(#[from] x509_parser::error::X509Error),
299-
}
300-
301294
#[cfg(test)]
302295
mod tests {
303296
use super::*;

0 commit comments

Comments
 (0)