Skip to content

Commit 4ffcf95

Browse files
committed
Use attested-tls-client from refactored module
1 parent a4690d3 commit 4ffcf95

File tree

4 files changed

+114
-326
lines changed

4 files changed

+114
-326
lines changed

src/attested_tls.rs

Lines changed: 23 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ pub struct TlsCertAndKey {
3838
}
3939

4040
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
41+
#[derive(Clone)]
4142
pub struct AttestedTlsServer {
4243
/// The underlying TCP listener
43-
listener: TcpListener,
44+
pub listener: Arc<TcpListener>,
4445
/// Quote generation type to use (including none)
4546
attestation_generator: AttestationGenerator,
4647
/// Verifier for remote attestation (including none)
@@ -102,15 +103,15 @@ impl AttestedTlsServer {
102103
let listener = TcpListener::bind(local).await?;
103104

104105
Ok(Self {
105-
listener,
106+
listener: listener.into(),
106107
attestation_generator,
107108
attestation_verifier,
108109
acceptor,
109110
cert_chain,
110111
})
111112
}
112113

113-
/// Accept an incoming connection and handle it in a seperate task
114+
/// Accept an incoming connection and do an attestation exchange
114115
pub async fn accept(
115116
&self,
116117
) -> Result<
@@ -123,18 +124,7 @@ impl AttestedTlsServer {
123124
> {
124125
let (inbound, _client_addr) = self.listener.accept().await?;
125126

126-
let acceptor = self.acceptor.clone();
127-
let cert_chain = self.cert_chain.clone();
128-
let attestation_generator = self.attestation_generator.clone();
129-
let attestation_verifier = self.attestation_verifier.clone();
130-
Ok(Self::handle_connection(
131-
inbound,
132-
acceptor,
133-
cert_chain,
134-
attestation_generator,
135-
attestation_verifier,
136-
)
137-
.await?)
127+
self.handle_connection(inbound).await
138128
}
139129

140130
/// Helper to get the socket address of the underlying TCP listener
@@ -143,12 +133,13 @@ impl AttestedTlsServer {
143133
}
144134

145135
/// Handle an incoming connection from a proxy-client
146-
async fn handle_connection(
136+
pub async fn handle_connection(
137+
&self,
147138
inbound: TcpStream,
148-
acceptor: TlsAcceptor,
149-
cert_chain: Vec<CertificateDer<'static>>,
150-
attestation_generator: AttestationGenerator,
151-
attestation_verifier: AttestationVerifier,
139+
// acceptor: TlsAcceptor,
140+
// cert_chain: Vec<CertificateDer<'static>>,
141+
// attestation_generator: AttestationGenerator,
142+
// attestation_verifier: AttestationVerifier,
152143
) -> Result<
153144
(
154145
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
@@ -160,7 +151,7 @@ impl AttestedTlsServer {
160151
tracing::debug!("attested-tls-server accepted connection");
161152

162153
// Do TLS handshake
163-
let mut tls_stream = acceptor.accept(inbound).await?;
154+
let mut tls_stream = self.acceptor.accept(inbound).await?;
164155
let (_io, connection) = tls_stream.get_ref();
165156

166157
// Ensure that we agreed a protocol
@@ -176,13 +167,14 @@ impl AttestedTlsServer {
176167
None, // context
177168
)?;
178169

179-
let input_data = compute_report_input(Some(&cert_chain), exporter)?;
170+
let input_data = compute_report_input(Some(&self.cert_chain), exporter)?;
180171

181172
// Get the TLS certficate chain of the client, if there is one
182173
let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned());
183174

184175
// If we are in a CVM, generate an attestation
185-
let attestation = attestation_generator
176+
let attestation = self
177+
.attestation_generator
186178
.generate_attestation(input_data)
187179
.await?
188180
.encode();
@@ -205,10 +197,10 @@ impl AttestedTlsServer {
205197
let remote_attestation_type = remote_attestation_message.attestation_type;
206198

207199
// If we expect an attestaion from the client, verify it and get measurements
208-
let measurements = if attestation_verifier.has_remote_attestion() {
200+
let measurements = if self.attestation_verifier.has_remote_attestion() {
209201
let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?;
210202

211-
attestation_verifier
203+
self.attestation_verifier
212204
.verify_attestation(remote_attestation_message, remote_input_data)
213205
.await?
214206
} else {
@@ -220,9 +212,8 @@ impl AttestedTlsServer {
220212
}
221213

222214
/// A proxy client which forwards http traffic to a proxy-server
215+
#[derive(Clone)]
223216
pub struct AttestedTlsClient {
224-
/// The underlying TCP listener
225-
listener: TcpListener,
226217
/// The connector for making TLS connections with out configuration
227218
connector: TlsConnector,
228219
/// Quote generation type to use (including none)
@@ -234,9 +225,10 @@ pub struct AttestedTlsClient {
234225
}
235226

236227
impl std::fmt::Debug for AttestedTlsClient {
228+
// TODO add other fields
237229
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238230
f.debug_struct("AttestedTlsClient")
239-
.field("listener", &self.listener)
231+
.field("attestation_verifier", &self.attestation_verifier)
240232
.finish()
241233
}
242234
}
@@ -245,7 +237,6 @@ impl AttestedTlsClient {
245237
/// Start with optional TLS client auth
246238
pub async fn new(
247239
cert_and_key: Option<TlsCertAndKey>,
248-
address: impl ToSocketAddrs,
249240
attestation_generator: AttestationGenerator,
250241
attestation_verifier: AttestationVerifier,
251242
remote_certificate: Option<CertificateDer<'static>>,
@@ -281,7 +272,6 @@ impl AttestedTlsClient {
281272

282273
Self::new_with_tls_config(
283274
client_config.into(),
284-
address,
285275
attestation_generator,
286276
attestation_verifier,
287277
cert_and_key.map(|c| c.cert_chain),
@@ -291,32 +281,23 @@ impl AttestedTlsClient {
291281

292282
/// Create a new proxy client with given TLS configuration
293283
///
294-
/// This is private as it allows dangerous configuration but is used in tests
295-
async fn new_with_tls_config(
284+
/// This not fully public as it allows dangerous configuration but is used in tests
285+
pub(crate) async fn new_with_tls_config(
296286
client_config: Arc<ClientConfig>,
297-
local: impl ToSocketAddrs,
298287
attestation_generator: AttestationGenerator,
299288
attestation_verifier: AttestationVerifier,
300289
cert_chain: Option<Vec<CertificateDer<'static>>>,
301290
) -> Result<Self, AttestedTlsError> {
302-
// Setup TCP server and TLS client
303-
let listener = TcpListener::bind(local).await?;
304291
let connector = TlsConnector::from(client_config.clone());
305292

306293
Ok(Self {
307-
listener,
308294
connector,
309295
attestation_generator,
310296
attestation_verifier,
311297
cert_chain,
312298
})
313299
}
314300

315-
/// Helper to return the local socket address from the underlying TCP listener
316-
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
317-
self.listener.local_addr()
318-
}
319-
320301
/// Connect to the attested-tls-server, do TLS handshake and remote attestation
321302
pub async fn connect(
322303
&self,
@@ -425,7 +406,7 @@ pub async fn get_tls_cert(
425406
get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await
426407
}
427408

428-
async fn get_tls_cert_with_config(
409+
pub(crate) async fn get_tls_cert_with_config(
429410
server_name: String,
430411
attestation_verifier: AttestationVerifier,
431412
client_config: Arc<ClientConfig>,

0 commit comments

Comments
 (0)