Skip to content

Commit 8aec1a1

Browse files
committed
Merge branch 'main' into peg/ws
* main: Force single test thread in CI Make transport agnostic
2 parents 9d2fabf + 8408a36 commit 8aec1a1

File tree

3 files changed

+58
-57
lines changed

3 files changed

+58
-57
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,4 @@ jobs:
3636
run: cargo clippy --workspace -- -D warnings
3737

3838
- name: Run cargo test
39-
run: cargo test --workspace --all-targets
39+
run: cargo test --workspace --all-targets -- --test-threads=1

src/attested_tls.rs

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@ use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
1212
use x509_parser::parse_x509_certificate;
1313

1414
use std::num::TryFromIntError;
15-
use std::{net::SocketAddr, sync::Arc};
16-
use tokio::io::{AsyncReadExt, AsyncWriteExt};
17-
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15+
use std::sync::Arc;
16+
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
1817
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
1918
use tokio_rustls::rustls::RootCertStore;
2019
use tokio_rustls::{
@@ -38,16 +37,14 @@ pub struct TlsCertAndKey {
3837
pub key: PrivateKeyDer<'static>,
3938
}
4039

41-
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
40+
/// A TLS server which makes an attestation exchange following the TLS handshake
4241
#[derive(Clone)]
4342
pub struct AttestedTlsServer {
44-
/// The underlying TCP listener
45-
pub listener: Arc<TcpListener>,
4643
/// Quote generation type to use (including none)
4744
attestation_generator: AttestationGenerator,
4845
/// Verifier for remote attestation (including none)
4946
attestation_verifier: AttestationVerifier,
50-
/// The certificate chain
47+
/// The TLS certificate chain
5148
cert_chain: Vec<CertificateDer<'static>>,
5249
/// For accepting TLS connections
5350
acceptor: TlsAcceptor,
@@ -56,7 +53,6 @@ pub struct AttestedTlsServer {
5653
impl AttestedTlsServer {
5754
pub async fn new(
5855
cert_and_key: TlsCertAndKey,
59-
local: impl ToSocketAddrs,
6056
attestation_generator: AttestationGenerator,
6157
attestation_verifier: AttestationVerifier,
6258
client_auth: bool,
@@ -83,7 +79,6 @@ impl AttestedTlsServer {
8379
Self::new_with_tls_config(
8480
cert_and_key.cert_chain,
8581
server_config.into(),
86-
local,
8782
attestation_generator,
8883
attestation_verifier,
8984
)
@@ -96,55 +91,36 @@ impl AttestedTlsServer {
9691
pub(crate) async fn new_with_tls_config(
9792
cert_chain: Vec<CertificateDer<'static>>,
9893
server_config: Arc<ServerConfig>,
99-
local: impl ToSocketAddrs,
10094
attestation_generator: AttestationGenerator,
10195
attestation_verifier: AttestationVerifier,
10296
) -> Result<Self, AttestedTlsError> {
10397
let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
104-
let listener = TcpListener::bind(local).await?;
10598

10699
Ok(Self {
107-
listener: listener.into(),
108100
attestation_generator,
109101
attestation_verifier,
110102
acceptor,
111103
cert_chain,
112104
})
113105
}
114106

115-
/// Accept an incoming connection and do an attestation exchange
116-
pub async fn accept(
117-
&self,
118-
) -> Result<
119-
(
120-
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
121-
Option<MultiMeasurements>,
122-
AttestationType,
123-
),
124-
AttestedTlsError,
125-
> {
126-
let (inbound, _client_addr) = self.listener.accept().await?;
127-
128-
self.handle_connection(inbound).await
129-
}
130-
131-
/// Helper to get the socket address of the underlying TCP listener
132-
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
133-
self.listener.local_addr()
134-
}
135-
136107
/// Handle an incoming connection from a proxy-client
137-
pub async fn handle_connection(
108+
///
109+
/// This is transport agnostic and will work with any asynchronous stream
110+
pub async fn handle_connection<IO>(
138111
&self,
139-
inbound: TcpStream,
112+
inbound: IO,
140113
) -> Result<
141114
(
142-
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
115+
tokio_rustls::server::TlsStream<IO>,
143116
Option<MultiMeasurements>,
144117
AttestationType,
145118
),
146119
AttestedTlsError,
147-
> {
120+
>
121+
where
122+
IO: AsyncRead + AsyncWrite + Unpin,
123+
{
148124
tracing::debug!("attested-tls-server accepted connection");
149125

150126
// Do TLS handshake
@@ -296,23 +272,29 @@ impl AttestedTlsClient {
296272
})
297273
}
298274

299-
/// Connect to an attested-tls-server, do TLS handshake and attestation exchange
300-
pub async fn connect(
275+
/// Given a connection to an attested TLS server, do a TLS handshake and attestation exchange, and return the TLS
276+
/// stream together with measurement details
277+
///
278+
/// This is transport agnostic and will work with any asynchronous stream
279+
pub async fn connect<IO>(
301280
&self,
302281
target: &str,
282+
outbound: IO,
303283
) -> Result<
304284
(
305-
tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
285+
tokio_rustls::client::TlsStream<IO>,
306286
Option<MultiMeasurements>,
307287
AttestationType,
308288
),
309289
AttestedTlsError,
310-
> {
311-
// Make a TCP client connection and TLS handshake
312-
let out = TcpStream::connect(&target).await?;
290+
>
291+
where
292+
IO: AsyncRead + AsyncWrite + Unpin,
293+
{
294+
// Make a TLS handshake with the given connection
313295
let mut tls_stream = self
314296
.connector
315-
.connect(server_name_from_host(target)?, out)
297+
.connect(server_name_from_host(target)?, outbound)
316298
.await?;
317299

318300
let (_io, server_connection) = tls_stream.get_ref();
@@ -374,12 +356,29 @@ impl AttestedTlsClient {
374356
Ok((tls_stream, measurements, remote_attestation_type))
375357
}
376358

377-
/// Connect to an attested TLS server, retrieve the remote TLS certificate and return it
359+
/// Make a TCP connection, do a TLS handshake and attestation exchange, and return the TLS
360+
/// stream together with measurement details
361+
pub async fn connect_tcp(
362+
&self,
363+
target: &str,
364+
) -> Result<
365+
(
366+
tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
367+
Option<MultiMeasurements>,
368+
AttestationType,
369+
),
370+
AttestedTlsError,
371+
> {
372+
let out = tokio::net::TcpStream::connect(&target).await?;
373+
self.connect(target, out).await
374+
}
375+
376+
/// Connect to an attested TLS server using TCP, retrieve the remote TLS certificate and return it
378377
pub async fn get_tls_cert(
379378
&self,
380379
server_name: &str,
381380
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
382-
let (mut tls_stream, _, _) = self.connect(server_name).await?;
381+
let (mut tls_stream, _, _) = self.connect_tcp(server_name).await?;
383382

384383
let (_io, server_connection) = tls_stream.get_ref();
385384

src/lib.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,11 @@ use tracing::{error, warn};
2222
#[cfg(test)]
2323
mod test_helpers;
2424

25-
use std::net::SocketAddr;
26-
use std::num::TryFromIntError;
27-
use std::time::Duration;
25+
use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration};
2826
use tokio::io;
2927
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
3028
use tokio_rustls::rustls::pki_types::CertificateDer;
3129

32-
#[cfg(test)]
33-
use std::sync::Arc;
3430
#[cfg(test)]
3531
use tokio_rustls::rustls::{ClientConfig, ServerConfig};
3632

@@ -60,6 +56,8 @@ type Http2Sender = hyper::client::conn::http2::SendRequest<hyper::body::Incoming
6056
pub struct ProxyServer {
6157
/// The underlying attested TLS server
6258
attested_tls_server: AttestedTlsServer,
59+
/// The underlying TCP listener
60+
listener: Arc<TcpListener>,
6361
/// The address of the target service we are proxying to
6462
target: SocketAddr,
6563
}
@@ -75,15 +73,17 @@ impl ProxyServer {
7573
) -> Result<Self, ProxyError> {
7674
let attested_tls_server = AttestedTlsServer::new(
7775
cert_and_key,
78-
local,
7976
attestation_generator,
8077
attestation_verifier,
8178
client_auth,
8279
)
8380
.await?;
8481

82+
let listener = TcpListener::bind(local).await?;
83+
8584
Ok(Self {
8685
attested_tls_server,
86+
listener: listener.into(),
8787
target,
8888
})
8989
}
@@ -103,22 +103,24 @@ impl ProxyServer {
103103
let attested_tls_server = AttestedTlsServer::new_with_tls_config(
104104
cert_chain,
105105
server_config,
106-
local,
107106
attestation_generator,
108107
attestation_verifier,
109108
)
110109
.await?;
111110

111+
let listener = TcpListener::bind(local).await?;
112+
112113
Ok(Self {
113114
attested_tls_server,
115+
listener: listener.into(),
114116
target,
115117
})
116118
}
117119

118120
/// Accept an incoming connection and handle it in a seperate task
119121
pub async fn accept(&self) -> Result<(), ProxyError> {
120122
let target = self.target;
121-
let (inbound, _client_addr) = self.attested_tls_server.listener.accept().await?;
123+
let (inbound, _client_addr) = self.listener.accept().await?;
122124
let attested_tls_server = self.attested_tls_server.clone();
123125

124126
tokio::spawn(async move {
@@ -142,7 +144,7 @@ impl ProxyServer {
142144

143145
/// Helper to get the socket address of the underlying TCP listener
144146
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
145-
self.attested_tls_server.local_addr()
147+
self.listener.local_addr()
146148
}
147149

148150
/// Handle an incoming connection from a proxy-client
@@ -472,7 +474,7 @@ impl ProxyClient {
472474
inner: &AttestedTlsClient,
473475
target: &str,
474476
) -> Result<(Http2Sender, Option<MultiMeasurements>, AttestationType), ProxyError> {
475-
let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?;
477+
let (tls_stream, measurements, remote_attestation_type) = inner.connect_tcp(target).await?;
476478

477479
// The attestation exchange is now complete - setup an HTTP client
478480

0 commit comments

Comments
 (0)