Skip to content

Commit a4690d3

Browse files
committed
Use attested-tls-server from refactored module
1 parent 1aed046 commit a4690d3

File tree

3 files changed

+56
-133
lines changed

3 files changed

+56
-133
lines changed

src/attested_tls.rs

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use parity_scale_codec::{Decode, Encode};
55
use sha2::{Digest, Sha256};
66
use thiserror::Error;
77
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
8-
use tracing::{error, warn};
8+
use tracing::error;
99
use x509_parser::parse_x509_certificate;
1010

1111
use std::num::TryFromIntError;
@@ -27,7 +27,7 @@ use crate::attestation::{AttestationExchangeMessage, AttestationVerifier};
2727
pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"];
2828

2929
/// The label used when exporting key material from a TLS session
30-
const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding";
30+
pub(crate) const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding";
3131

3232
/// TLS Credentials
3333
pub struct TlsCertAndKey {
@@ -90,8 +90,8 @@ impl AttestedTlsServer {
9090

9191
/// Start with preconfigured TLS
9292
///
93-
/// This is not public as it allows dangerous configuration
94-
async fn new_with_tls_config(
93+
/// This is not fully public as it allows dangerous configuration
94+
pub(crate) async fn new_with_tls_config(
9595
cert_chain: Vec<CertificateDer<'static>>,
9696
server_config: Arc<ServerConfig>,
9797
local: impl ToSocketAddrs,
@@ -111,28 +111,30 @@ impl AttestedTlsServer {
111111
}
112112

113113
/// Accept an incoming connection and handle it in a seperate task
114-
pub async fn accept(&self) -> Result<(), AttestedTlsError> {
114+
pub async fn accept(
115+
&self,
116+
) -> Result<
117+
(
118+
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
119+
Option<MultiMeasurements>,
120+
AttestationType,
121+
),
122+
AttestedTlsError,
123+
> {
115124
let (inbound, _client_addr) = self.listener.accept().await?;
116125

117126
let acceptor = self.acceptor.clone();
118127
let cert_chain = self.cert_chain.clone();
119128
let attestation_generator = self.attestation_generator.clone();
120129
let attestation_verifier = self.attestation_verifier.clone();
121-
tokio::spawn(async move {
122-
if let Err(err) = Self::handle_connection(
123-
inbound,
124-
acceptor,
125-
cert_chain,
126-
attestation_generator,
127-
attestation_verifier,
128-
)
129-
.await
130-
{
131-
warn!("Failed to handle connection: {err}");
132-
}
133-
});
134-
135-
Ok(())
130+
Ok(Self::handle_connection(
131+
inbound,
132+
acceptor,
133+
cert_chain,
134+
attestation_generator,
135+
attestation_verifier,
136+
)
137+
.await?)
136138
}
137139

138140
/// Helper to get the socket address of the underlying TCP listener

src/lib.rs

Lines changed: 32 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ pub mod file_server;
55
pub mod health_check;
66

77
pub use attestation::AttestationGenerator;
8-
use attestation::{measurements::MultiMeasurements, AttestationError, AttestationType};
8+
99
use bytes::Bytes;
1010
use http::HeaderValue;
1111
use http_body_util::{combinators::BoxBody, BodyExt};
@@ -27,22 +27,21 @@ use std::time::Duration;
2727
use std::{net::SocketAddr, sync::Arc};
2828
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
2929
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
30-
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
30+
use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName};
3131
use tokio_rustls::rustls::RootCertStore;
3232
use tokio_rustls::{
3333
rustls::{ClientConfig, ServerConfig},
34-
TlsAcceptor, TlsConnector,
34+
TlsConnector,
3535
};
3636

37-
use crate::attestation::{AttestationExchangeMessage, AttestationVerifier};
38-
39-
/// This makes it possible to add breaking protocol changes and provide backwards compatibility.
40-
/// When adding more supported versions, note that ordering is important. ALPN will pick the first
41-
/// protocol which both parties support - so newer supported versions should come first.
42-
pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"];
43-
44-
/// The label used when exporting key material from a TLS session
45-
const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding";
37+
use crate::attestation::{
38+
measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, AttestationType,
39+
AttestationVerifier,
40+
};
41+
use crate::attested_tls::{
42+
AttestedTlsError, AttestedTlsServer, TlsCertAndKey, EXPORTER_LABEL,
43+
SUPPORTED_ALPN_PROTOCOL_VERSIONS,
44+
};
4645

4746
/// The header name for giving attestation type
4847
const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type";
@@ -59,26 +58,9 @@ type RequestWithResponseSender = (
5958
);
6059
type Http2Sender = hyper::client::conn::http2::SendRequest<hyper::body::Incoming>;
6160

62-
/// TLS Credentials
63-
pub struct TlsCertAndKey {
64-
/// Der-encoded TLS certificate chain
65-
pub cert_chain: Vec<CertificateDer<'static>>,
66-
/// Der-encoded TLS private key
67-
pub key: PrivateKeyDer<'static>,
68-
}
69-
7061
/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
7162
pub struct ProxyServer {
72-
/// The underlying TCP listener
73-
listener: TcpListener,
74-
/// Quote generation type to use (including none)
75-
attestation_generator: AttestationGenerator,
76-
/// Verifier for remote attestation (including none)
77-
attestation_verifier: AttestationVerifier,
78-
/// The certificate chain
79-
cert_chain: Vec<CertificateDer<'static>>,
80-
/// For accepting TLS connections
81-
acceptor: TlsAcceptor,
63+
attested_tls_server: AttestedTlsServer,
8264
/// The address of the target service we are proxying to
8365
target: SocketAddr,
8466
}
@@ -133,38 +115,30 @@ impl ProxyServer {
133115
attestation_generator: AttestationGenerator,
134116
attestation_verifier: AttestationVerifier,
135117
) -> Result<Self, ProxyError> {
136-
let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
137-
let listener = TcpListener::bind(local).await?;
138-
139-
Ok(Self {
140-
listener,
118+
let attested_tls_server = AttestedTlsServer::new_with_tls_config(
119+
cert_chain,
120+
server_config,
121+
local,
141122
attestation_generator,
142123
attestation_verifier,
143-
acceptor,
124+
)
125+
.await?;
126+
127+
Ok(Self {
128+
attested_tls_server,
144129
target,
145-
cert_chain,
146130
})
147131
}
148132

149133
/// Accept an incoming connection and handle it in a seperate task
150134
pub async fn accept(&self) -> Result<(), ProxyError> {
151-
let (inbound, _client_addr) = self.listener.accept().await?;
135+
let target = self.target.clone();
136+
let (tls_stream, measurements, attestation_type) =
137+
self.attested_tls_server.accept().await?;
152138

153-
let acceptor = self.acceptor.clone();
154-
let target = self.target;
155-
let cert_chain = self.cert_chain.clone();
156-
let attestation_generator = self.attestation_generator.clone();
157-
let attestation_verifier = self.attestation_verifier.clone();
158139
tokio::spawn(async move {
159-
if let Err(err) = Self::handle_connection(
160-
inbound,
161-
acceptor,
162-
target,
163-
cert_chain,
164-
attestation_generator,
165-
attestation_verifier,
166-
)
167-
.await
140+
if let Err(err) =
141+
Self::handle_connection(tls_stream, measurements, attestation_type, target).await
168142
{
169143
warn!("Failed to handle connection: {err}");
170144
}
@@ -175,74 +149,18 @@ impl ProxyServer {
175149

176150
/// Helper to get the socket address of the underlying TCP listener
177151
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
178-
self.listener.local_addr()
152+
self.attested_tls_server.local_addr()
179153
}
180154

181155
/// Handle an incoming connection from a proxy-client
182156
async fn handle_connection(
183-
inbound: TcpStream,
184-
acceptor: TlsAcceptor,
157+
tls_stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
158+
measurements: Option<MultiMeasurements>,
159+
remote_attestation_type: AttestationType,
185160
target: SocketAddr,
186-
cert_chain: Vec<CertificateDer<'static>>,
187-
attestation_generator: AttestationGenerator,
188-
attestation_verifier: AttestationVerifier,
189161
) -> Result<(), ProxyError> {
190162
tracing::debug!("proxy-server accepted connection");
191163

192-
// Do TLS handshake
193-
let mut tls_stream = acceptor.accept(inbound).await?;
194-
let (_io, connection) = tls_stream.get_ref();
195-
196-
// Ensure that we agreed a protocol
197-
let _negotiated_protocol = connection.alpn_protocol().ok_or(ProxyError::AlpnFailed)?;
198-
199-
// Compute an exporter unique to the session
200-
let mut exporter = [0u8; 32];
201-
connection.export_keying_material(
202-
&mut exporter,
203-
EXPORTER_LABEL,
204-
None, // context
205-
)?;
206-
207-
let input_data = compute_report_input(Some(&cert_chain), exporter)?;
208-
209-
// Get the TLS certficate chain of the client, if there is one
210-
let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned());
211-
212-
// If we are in a CVM, generate an attestation
213-
let attestation = attestation_generator
214-
.generate_attestation(input_data)
215-
.await?
216-
.encode();
217-
218-
// Write our attestation to the channel, with length prefix
219-
let attestation_length_prefix = length_prefix(&attestation);
220-
tls_stream.write_all(&attestation_length_prefix).await?;
221-
tls_stream.write_all(&attestation).await?;
222-
223-
// Now read a length-prefixed attestation from the remote peer
224-
// In the case of no client attestation this will be zero bytes
225-
let mut length_bytes = [0; 4];
226-
tls_stream.read_exact(&mut length_bytes).await?;
227-
let length: usize = u32::from_be_bytes(length_bytes).try_into()?;
228-
229-
let mut buf = vec![0; length];
230-
tls_stream.read_exact(&mut buf).await?;
231-
232-
let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?;
233-
let remote_attestation_type = remote_attestation_message.attestation_type;
234-
235-
// If we expect an attestaion from the client, verify it and get measurements
236-
let measurements = if attestation_verifier.has_remote_attestion() {
237-
let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?;
238-
239-
attestation_verifier
240-
.verify_attestation(remote_attestation_message, remote_input_data)
241-
.await?
242-
} else {
243-
None
244-
};
245-
246164
// Setup an HTTP server
247165
let http = hyper::server::conn::http2::Builder::new(TokioExecutor);
248166

@@ -819,6 +737,8 @@ pub enum ProxyError {
819737
Serialization(#[from] parity_scale_codec::Error),
820738
#[error("Protocol negotiation failed - remote peer does not support this protocol")]
821739
AlpnFailed,
740+
#[error("Attested TLS: {0}")]
741+
AttestedTls(#[from] AttestedTlsError),
822742
}
823743

824744
impl From<mpsc::error::SendError<RequestWithResponseSender>> for ProxyError {

src/main.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ use tracing::level_filters::LevelFilter;
88
use attested_tls_proxy::{
99
attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier},
1010
attested_get::attested_get,
11+
attested_tls::TlsCertAndKey,
1112
file_server::attested_file_server,
12-
get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey,
13+
get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer,
1314
};
1415

1516
#[derive(Parser, Debug, Clone)]

0 commit comments

Comments
 (0)