Skip to content

Commit 809fe09

Browse files
committed
Error handling
1 parent 156386e commit 809fe09

File tree

2 files changed

+144
-101
lines changed

2 files changed

+144
-101
lines changed

src/lib.rs

Lines changed: 142 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
mod attestation;
22

3+
use attestation::AttestationError;
34
pub use attestation::{AttestationPlatform, MockAttestation, NoAttestation};
45
use thiserror::Error;
56
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
67

78
#[cfg(test)]
89
mod test_helpers;
910

11+
use std::num::TryFromIntError;
1012
use std::{net::SocketAddr, sync::Arc};
1113
use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
1214
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
@@ -129,58 +131,18 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
129131
let local_attestation_platform = self.inner.local_attestation_platform.clone();
130132
let remote_attestation_platform = self.inner.remote_attestation_platform.clone();
131133
tokio::spawn(async move {
132-
let mut tls_stream = acceptor.accept(inbound).await.unwrap();
133-
let (_io, connection) = tls_stream.get_ref();
134-
135-
let mut exporter = [0u8; 32];
136-
connection
137-
.export_keying_material(
138-
&mut exporter,
139-
EXPORTER_LABEL,
140-
None, // context
141-
)
142-
.unwrap();
143-
144-
let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned());
145-
146-
let attestation = if local_attestation_platform.is_cvm() {
147-
local_attestation_platform
148-
.create_attestation(&cert_chain, exporter)
149-
.unwrap()
150-
} else {
151-
Vec::new()
152-
};
153-
154-
let attestation_length_prefix = length_prefix(&attestation);
155-
156-
tls_stream
157-
.write_all(&attestation_length_prefix)
158-
.await
159-
.unwrap();
160-
161-
tls_stream.write_all(&attestation).await.unwrap();
162-
163-
let mut length_bytes = [0; 4];
164-
tls_stream.read_exact(&mut length_bytes).await.unwrap();
165-
let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap();
166-
167-
let mut buf = vec![0; length];
168-
tls_stream.read_exact(&mut buf).await.unwrap();
169-
170-
if remote_attestation_platform.is_cvm() {
171-
remote_attestation_platform
172-
.verify_attestation(buf, &remote_cert_chain.unwrap(), exporter)
173-
.unwrap();
134+
if let Err(err) = Self::handle_connection(
135+
inbound,
136+
acceptor,
137+
target,
138+
cert_chain,
139+
local_attestation_platform,
140+
remote_attestation_platform,
141+
)
142+
.await
143+
{
144+
eprintln!("Failed to handle connection: {err}");
174145
}
175-
176-
let outbound = TcpStream::connect(target).await.unwrap();
177-
178-
let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream);
179-
let (mut outbound_reader, mut outbound_writer) = outbound.into_split();
180-
181-
let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
182-
let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
183-
tokio::try_join!(client_to_server, server_to_client).unwrap();
184146
});
185147

186148
Ok(())
@@ -189,6 +151,64 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
189151
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
190152
self.inner.listener.local_addr()
191153
}
154+
155+
async fn handle_connection(
156+
inbound: TcpStream,
157+
acceptor: TlsAcceptor,
158+
target: SocketAddr,
159+
cert_chain: Vec<CertificateDer<'static>>,
160+
local_attestation_platform: L,
161+
remote_attestation_platform: R,
162+
) -> Result<(), ProxyError> {
163+
let mut tls_stream = acceptor.accept(inbound).await?;
164+
let (_io, connection) = tls_stream.get_ref();
165+
166+
let mut exporter = [0u8; 32];
167+
connection.export_keying_material(
168+
&mut exporter,
169+
EXPORTER_LABEL,
170+
None, // context
171+
)?;
172+
173+
let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned());
174+
175+
let attestation = if local_attestation_platform.is_cvm() {
176+
local_attestation_platform.create_attestation(&cert_chain, exporter)?
177+
} else {
178+
Vec::new()
179+
};
180+
181+
let attestation_length_prefix = length_prefix(&attestation);
182+
183+
tls_stream.write_all(&attestation_length_prefix).await?;
184+
185+
tls_stream.write_all(&attestation).await?;
186+
187+
let mut length_bytes = [0; 4];
188+
tls_stream.read_exact(&mut length_bytes).await?;
189+
let length: usize = u32::from_be_bytes(length_bytes).try_into()?;
190+
191+
let mut buf = vec![0; length];
192+
tls_stream.read_exact(&mut buf).await?;
193+
194+
if remote_attestation_platform.is_cvm() {
195+
remote_attestation_platform.verify_attestation(
196+
buf,
197+
&remote_cert_chain.ok_or(ProxyError::NoClientAuth)?,
198+
exporter,
199+
)?;
200+
}
201+
202+
let outbound = TcpStream::connect(target).await?;
203+
204+
let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream);
205+
let (mut outbound_reader, mut outbound_writer) = outbound.into_split();
206+
207+
let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
208+
let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
209+
tokio::try_join!(client_to_server, server_to_client)?;
210+
Ok(())
211+
}
192212
}
193213

194214
pub struct ProxyClient<L, R>
@@ -284,78 +304,104 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
284304
let cert_chain = self.cert_chain.clone();
285305

286306
tokio::spawn(async move {
287-
let out = TcpStream::connect(target).await.unwrap();
288-
let mut tls_stream = connector.connect(target_name, out).await.unwrap();
307+
if let Err(err) = Self::handle_connection(
308+
inbound,
309+
connector,
310+
target,
311+
target_name,
312+
cert_chain,
313+
local_attestation_platform,
314+
remote_attestation_platform,
315+
)
316+
.await
317+
{
318+
eprintln!("Failed to handle connection: {err}");
319+
}
320+
});
289321

290-
let (_io, server_connection) = tls_stream.get_ref();
322+
Ok(())
323+
}
291324

292-
let mut exporter = [0u8; 32];
293-
server_connection
294-
.export_keying_material(
295-
&mut exporter,
296-
EXPORTER_LABEL,
297-
None, // context
298-
)
299-
.unwrap();
325+
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
326+
self.inner.listener.local_addr()
327+
}
300328

301-
let remote_cert_chain = server_connection.peer_certificates().unwrap().to_owned();
329+
async fn handle_connection(
330+
inbound: TcpStream,
331+
connector: TlsConnector,
332+
target: SocketAddr,
333+
target_name: ServerName<'static>,
334+
cert_chain: Option<Vec<CertificateDer<'static>>>,
335+
local_attestation_platform: L,
336+
remote_attestation_platform: R,
337+
) -> Result<(), ProxyError> {
338+
let out = TcpStream::connect(target).await?;
339+
let mut tls_stream = connector.connect(target_name, out).await?;
302340

303-
let mut length_bytes = [0; 4];
304-
tls_stream.read_exact(&mut length_bytes).await.unwrap();
305-
let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap();
341+
let (_io, server_connection) = tls_stream.get_ref();
306342

307-
let mut buf = vec![0; length];
308-
tls_stream.read_exact(&mut buf).await.unwrap();
343+
let mut exporter = [0u8; 32];
344+
server_connection.export_keying_material(
345+
&mut exporter,
346+
EXPORTER_LABEL,
347+
None, // context
348+
)?;
309349

310-
if remote_attestation_platform.is_cvm() {
311-
remote_attestation_platform
312-
.verify_attestation(buf, &remote_cert_chain, exporter)
313-
.unwrap();
314-
}
350+
let remote_cert_chain = server_connection
351+
.peer_certificates()
352+
.ok_or(ProxyError::NoCertificate)?
353+
.to_owned();
315354

316-
let attestation = if local_attestation_platform.is_cvm() {
317-
local_attestation_platform
318-
.create_attestation(&cert_chain.unwrap(), exporter)
319-
.unwrap()
320-
} else {
321-
Vec::new()
322-
};
355+
let mut length_bytes = [0; 4];
356+
tls_stream.read_exact(&mut length_bytes).await?;
357+
let length: usize = u32::from_be_bytes(length_bytes).try_into()?;
323358

324-
let attestation_length_prefix = length_prefix(&attestation);
359+
let mut buf = vec![0; length];
360+
tls_stream.read_exact(&mut buf).await?;
325361

326-
tls_stream
327-
.write_all(&attestation_length_prefix)
328-
.await
329-
.unwrap();
362+
if remote_attestation_platform.is_cvm() {
363+
remote_attestation_platform.verify_attestation(buf, &remote_cert_chain, exporter)?;
364+
}
330365

331-
tls_stream.write_all(&attestation).await.unwrap();
366+
let attestation = if local_attestation_platform.is_cvm() {
367+
local_attestation_platform
368+
.create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)?
369+
} else {
370+
Vec::new()
371+
};
332372

333-
let (mut inbound_reader, mut inbound_writer) = inbound.into_split();
334-
let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream);
373+
let attestation_length_prefix = length_prefix(&attestation);
335374

336-
let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
337-
let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
338-
tokio::try_join!(client_to_server, server_to_client).unwrap();
339-
});
375+
tls_stream.write_all(&attestation_length_prefix).await?;
340376

341-
Ok(())
342-
}
377+
tls_stream.write_all(&attestation).await?;
343378

344-
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
345-
self.inner.listener.local_addr()
379+
let (mut inbound_reader, mut inbound_writer) = inbound.into_split();
380+
let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream);
381+
382+
let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
383+
let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
384+
tokio::try_join!(client_to_server, server_to_client)?;
385+
Ok(())
346386
}
347387
}
348388

349389
#[derive(Error, Debug)]
350390
pub enum ProxyError {
351391
#[error("Client auth is required when the client is running in a CVM")]
352392
NoClientAuth,
393+
#[error("Failed to get server ceritifcate")]
394+
NoCertificate,
353395
#[error("TLS: {0}")]
354396
Rustls(#[from] tokio_rustls::rustls::Error),
355397
#[error("Verifier builder: {0}")]
356398
VerifierBuilder(#[from] VerifierBuilderError),
357399
#[error("IO: {0}")]
358400
Io(#[from] std::io::Error),
401+
#[error("Attestation: {0}")]
402+
Attestation(#[from] AttestationError),
403+
#[error("Integer conversion: {0}")]
404+
IntConversion(#[from] TryFromIntError),
359405
}
360406

361407
fn length_prefix(input: &[u8]) -> [u8; 4] {

src/main.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,8 @@ fn load_tls_cert_and_key(
125125
}
126126

127127
pub fn load_certs_pem(path: PathBuf) -> std::io::Result<Vec<CertificateDer<'static>>> {
128-
Ok(
129-
rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?))
130-
.map(|res| res.unwrap()) //TODO
131-
.collect(),
132-
)
128+
rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?))
129+
.collect::<Result<Vec<_>, _>>()
133130
}
134131

135132
pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result<PrivateKeyDer<'static>> {

0 commit comments

Comments
 (0)