Skip to content

Commit 0c9ee2a

Browse files
committed
Split trait into generator and verifier
1 parent 8da3255 commit 0c9ee2a

File tree

3 files changed

+109
-58
lines changed

3 files changed

+109
-58
lines changed

src/attestation.rs

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ use x509_parser::prelude::*;
1010

1111
const PCS_URL: &str = "https://api.trustedservices.intel.com";
1212

13-
/// Represents a CVM technology with quote generation and verification
14-
pub trait AttestationPlatform: Clone + Send + 'static {
13+
pub trait QuoteGenerator: Clone + Send + 'static {
1514
/// Whether this is CVM attestation. This should always return true except for the [NoAttestation] case.
1615
///
1716
/// When false, allows TLS client to be configured without client authentication
@@ -23,6 +22,13 @@ pub trait AttestationPlatform: Clone + Send + 'static {
2322
cert_chain: &[CertificateDer<'_>],
2423
exporter: [u8; 32],
2524
) -> Result<Vec<u8>, AttestationError>;
25+
}
26+
27+
pub trait QuoteVerifier: Clone + Send + 'static {
28+
/// Whether this is CVM attestation. This should always return true except for the [NoAttestation] case.
29+
///
30+
/// When false, allows TLS client to be configured without client authentication
31+
fn is_cvm(&self) -> bool;
2632

2733
/// Verify the given attestation payload
2834
fn verify_attestation(
@@ -33,10 +39,33 @@ pub trait AttestationPlatform: Clone + Send + 'static {
3339
) -> impl Future<Output = Result<(), AttestationError>> + Send;
3440
}
3541

42+
// /// Represents a CVM technology with quote generation and verification
43+
// pub trait AttestationPlatform: Clone + Send + 'static {
44+
// /// Whether this is CVM attestation. This should always return true except for the [NoAttestation] case.
45+
// ///
46+
// /// When false, allows TLS client to be configured without client authentication
47+
// fn is_cvm(&self) -> bool;
48+
//
49+
// /// Generate an attestation
50+
// fn create_attestation(
51+
// &self,
52+
// cert_chain: &[CertificateDer<'_>],
53+
// exporter: [u8; 32],
54+
// ) -> Result<Vec<u8>, AttestationError>;
55+
//
56+
// /// Verify the given attestation payload
57+
// fn verify_attestation(
58+
// &self,
59+
// input: Vec<u8>,
60+
// cert_chain: &[CertificateDer<'_>],
61+
// exporter: [u8; 32],
62+
// ) -> impl Future<Output = Result<(), AttestationError>> + Send;
63+
// }
64+
3665
#[derive(Clone)]
37-
pub struct DcapTdxAttestation;
66+
pub struct DcapTdxQuoteGenerator;
3867

39-
impl AttestationPlatform for DcapTdxAttestation {
68+
impl QuoteGenerator for DcapTdxQuoteGenerator {
4069
fn is_cvm(&self) -> bool {
4170
true
4271
}
@@ -50,6 +79,15 @@ impl AttestationPlatform for DcapTdxAttestation {
5079

5180
Ok(generate_quote(quote_input)?)
5281
}
82+
}
83+
84+
#[derive(Clone)]
85+
pub struct DcapTdxQuoteVerifier;
86+
87+
impl QuoteVerifier for DcapTdxQuoteVerifier {
88+
fn is_cvm(&self) -> bool {
89+
true
90+
}
5391

5492
fn verify_attestation(
5593
&self,
@@ -59,25 +97,29 @@ impl AttestationPlatform for DcapTdxAttestation {
5997
) -> impl Future<Output = Result<(), AttestationError>> + Send {
6098
async move {
6199
let quote_input = compute_report_input(cert_chain, exporter)?;
62-
63-
let now = std::time::SystemTime::now()
64-
.duration_since(std::time::UNIX_EPOCH)
65-
.unwrap()
66-
.as_secs();
67-
let quote = Quote::parse(&input).unwrap();
68-
let ca = quote.ca().unwrap();
69-
let fmspc = hex::encode_upper(quote.fmspc().unwrap());
70-
let collateral = get_collateral_for_fmspc(PCS_URL, fmspc, ca, false)
71-
.await
72-
.unwrap();
73-
74100
// In tests we use mock quotes which will fail to verify
75101
if cfg!(not(test)) {
102+
let now = std::time::SystemTime::now()
103+
.duration_since(std::time::UNIX_EPOCH)
104+
.unwrap()
105+
.as_secs();
106+
let quote = Quote::parse(&input).unwrap();
107+
let ca = quote.ca().unwrap();
108+
let fmspc = hex::encode_upper(quote.fmspc().unwrap());
109+
let collateral = get_collateral_for_fmspc(PCS_URL, fmspc, ca, false)
110+
.await
111+
.unwrap();
76112
let _verified_report = dcap_qvl::verify::verify(&input, &collateral, now).unwrap();
77-
}
78-
let quote = Quote::parse(&input).unwrap();
79-
if get_quote_input_data(quote.report) != quote_input {
80-
return Err(AttestationError::InputMismatch);
113+
114+
let quote = Quote::parse(&input).unwrap();
115+
if get_quote_input_data(quote.report) != quote_input {
116+
return Err(AttestationError::InputMismatch);
117+
}
118+
} else {
119+
let quote = tdx_quote::Quote::from_bytes(&input).unwrap();
120+
if quote.report_input_data() != quote_input {
121+
return Err(AttestationError::InputMismatch);
122+
}
81123
}
82124

83125
Ok(())
@@ -109,9 +151,9 @@ pub fn compute_report_input(
109151

110152
/// For no CVM platform (eg: for one-sided remote-attested TLS)
111153
#[derive(Clone)]
112-
pub struct NoAttestation;
154+
pub struct NoQuoteGenerator;
113155

114-
impl AttestationPlatform for NoAttestation {
156+
impl QuoteGenerator for NoQuoteGenerator {
115157
fn is_cvm(&self) -> bool {
116158
false
117159
}
@@ -124,7 +166,16 @@ impl AttestationPlatform for NoAttestation {
124166
) -> Result<Vec<u8>, AttestationError> {
125167
Ok(Vec::new())
126168
}
169+
}
127170

171+
/// For no CVM platform (eg: for one-sided remote-attested TLS)
172+
#[derive(Clone)]
173+
pub struct NoQuoteVerifier;
174+
175+
impl QuoteVerifier for NoQuoteVerifier {
176+
fn is_cvm(&self) -> bool {
177+
false
178+
}
128179
/// Ensure that an empty attestation is given
129180
async fn verify_attestation(
130181
&self,

src/lib.rs

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
mod attestation;
22

33
use attestation::AttestationError;
4-
pub use attestation::{AttestationPlatform, DcapTdxAttestation, NoAttestation};
4+
pub use attestation::{
5+
DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator,
6+
QuoteVerifier,
7+
};
58
use thiserror::Error;
69
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
710

@@ -29,8 +32,8 @@ pub struct TlsCertAndKey {
2932

3033
struct Proxy<L, R>
3134
where
32-
L: AttestationPlatform,
33-
R: AttestationPlatform,
35+
L: QuoteGenerator,
36+
R: QuoteVerifier,
3437
{
3538
/// The underlying TCP listener
3639
listener: TcpListener,
@@ -43,8 +46,8 @@ where
4346
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
4447
pub struct ProxyServer<L, R>
4548
where
46-
L: AttestationPlatform,
47-
R: AttestationPlatform,
49+
L: QuoteGenerator,
50+
R: QuoteVerifier,
4851
{
4952
inner: Proxy<L, R>,
5053
/// The certificate chain
@@ -55,7 +58,7 @@ where
5558
target: SocketAddr,
5659
}
5760

58-
impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
61+
impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
5962
pub async fn new(
6063
cert_and_key: TlsCertAndKey,
6164
local: impl ToSocketAddrs,
@@ -215,8 +218,8 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
215218

216219
pub struct ProxyClient<L, R>
217220
where
218-
L: AttestationPlatform,
219-
R: AttestationPlatform,
221+
L: QuoteGenerator,
222+
R: QuoteVerifier,
220223
{
221224
inner: Proxy<L, R>,
222225
connector: TlsConnector,
@@ -226,7 +229,7 @@ where
226229
cert_chain: Option<Vec<CertificateDer<'static>>>,
227230
}
228231

229-
impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
232+
impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
230233
pub async fn new(
231234
cert_and_key: Option<TlsCertAndKey>,
232235
address: impl ToSocketAddrs,
@@ -390,7 +393,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
390393
}
391394

392395
/// Just get the attested remote certificate, with no client authentication
393-
pub async fn get_tls_cert<R: AttestationPlatform>(
396+
pub async fn get_tls_cert<R: QuoteVerifier>(
394397
server_name: String,
395398
remote_attestation_platform: R,
396399
) -> Result<Vec<CertificateDer<'static>>, ProxyError> {
@@ -406,7 +409,7 @@ pub async fn get_tls_cert<R: AttestationPlatform>(
406409
.await
407410
}
408411

409-
async fn get_tls_cert_with_config<R: AttestationPlatform>(
412+
async fn get_tls_cert_with_config<R: QuoteVerifier>(
410413
server_name: String,
411414
remote_attestation_platform: R,
412415
client_config: Arc<ClientConfig>,
@@ -516,8 +519,8 @@ mod tests {
516519
server_config,
517520
"127.0.0.1:0",
518521
target_addr,
519-
DcapTdxAttestation,
520-
NoAttestation,
522+
DcapTdxQuoteGenerator,
523+
NoQuoteVerifier,
521524
)
522525
.await
523526
.unwrap();
@@ -532,8 +535,8 @@ mod tests {
532535
client_config,
533536
"127.0.0.1:0".to_string(),
534537
proxy_addr.to_string(),
535-
NoAttestation,
536-
DcapTdxAttestation,
538+
NoQuoteGenerator,
539+
DcapTdxQuoteVerifier,
537540
None,
538541
)
539542
.await
@@ -579,8 +582,8 @@ mod tests {
579582
server_tls_server_config,
580583
"127.0.0.1:0",
581584
target_addr,
582-
DcapTdxAttestation,
583-
DcapTdxAttestation,
585+
DcapTdxQuoteGenerator,
586+
DcapTdxQuoteVerifier,
584587
)
585588
.await
586589
.unwrap();
@@ -595,8 +598,8 @@ mod tests {
595598
client_tls_client_config,
596599
"127.0.0.1:0",
597600
proxy_addr.to_string(),
598-
DcapTdxAttestation,
599-
DcapTdxAttestation,
601+
DcapTdxQuoteGenerator,
602+
DcapTdxQuoteVerifier,
600603
Some(client_cert_chain),
601604
)
602605
.await
@@ -625,15 +628,13 @@ mod tests {
625628
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
626629
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
627630

628-
let local_attestation_platform = DcapTdxAttestation;
629-
630631
let proxy_server = ProxyServer::new_with_tls_config(
631632
cert_chain,
632633
server_config,
633634
"127.0.0.1:0",
634635
target_addr,
635-
local_attestation_platform,
636-
NoAttestation,
636+
DcapTdxQuoteGenerator,
637+
NoQuoteVerifier,
637638
)
638639
.await
639640
.unwrap();
@@ -648,8 +649,8 @@ mod tests {
648649
client_config,
649650
"127.0.0.1:0",
650651
proxy_server_addr.to_string(),
651-
NoAttestation,
652-
DcapTdxAttestation,
652+
NoQuoteGenerator,
653+
DcapTdxQuoteVerifier,
653654
None,
654655
)
655656
.await
@@ -676,15 +677,13 @@ mod tests {
676677
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
677678
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
678679

679-
let local_attestation_platform = DcapTdxAttestation;
680-
681680
let proxy_server = ProxyServer::new_with_tls_config(
682681
cert_chain.clone(),
683682
server_config,
684683
"127.0.0.1:0",
685684
target_addr,
686-
local_attestation_platform,
687-
NoAttestation,
685+
DcapTdxQuoteGenerator,
686+
NoQuoteVerifier,
688687
)
689688
.await
690689
.unwrap();
@@ -697,7 +696,7 @@ mod tests {
697696

698697
let retrieved_chain = get_tls_cert_with_config(
699698
proxy_server_addr.to_string(),
700-
DcapTdxAttestation,
699+
DcapTdxQuoteVerifier,
701700
client_config,
702701
)
703702
.await

src/main.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ use std::{fs::File, net::SocketAddr, path::PathBuf};
44
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
55

66
use attested_tls_proxy::{
7-
get_tls_cert, DcapTdxAttestation, NoAttestation, ProxyClient, ProxyServer, TlsCertAndKey,
7+
get_tls_cert, DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier,
8+
ProxyClient, ProxyServer, TlsCertAndKey,
89
};
910

1011
#[derive(Parser, Debug, Clone)]
@@ -83,8 +84,8 @@ async fn main() -> anyhow::Result<()> {
8384
tls_cert_and_chain,
8485
address,
8586
server,
86-
NoAttestation,
87-
DcapTdxAttestation,
87+
NoQuoteGenerator,
88+
DcapTdxQuoteVerifier,
8889
)
8990
.await?;
9091

@@ -102,8 +103,8 @@ async fn main() -> anyhow::Result<()> {
102103
client_auth,
103104
} => {
104105
let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key)?;
105-
let local_attestation = DcapTdxAttestation;
106-
let remote_attestation = NoAttestation;
106+
let local_attestation = DcapTdxQuoteGenerator;
107+
let remote_attestation = NoQuoteVerifier;
107108

108109
let server = ProxyServer::new(
109110
tls_cert_and_chain,
@@ -122,7 +123,7 @@ async fn main() -> anyhow::Result<()> {
122123
}
123124
}
124125
CliCommand::GetTlsCert { server } => {
125-
let cert_chain = get_tls_cert(server, DcapTdxAttestation).await?;
126+
let cert_chain = get_tls_cert(server, DcapTdxQuoteVerifier).await?;
126127
println!("{}", certs_to_pem_string(&cert_chain)?);
127128
}
128129
}

0 commit comments

Comments
 (0)