Skip to content

Commit e8ab377

Browse files
committed
Error handling
1 parent 0fe56d5 commit e8ab377

File tree

1 file changed

+33
-15
lines changed

1 file changed

+33
-15
lines changed

src/lib.rs

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,20 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
212212
}
213213

214214
let http = Builder::new();
215-
let service =
216-
service_fn(move |req| async move { Self::handle_http_request(req, target).await });
215+
let service = service_fn(move |req| async move {
216+
match Self::handle_http_request(req, target).await {
217+
Ok(res) => Ok::<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error>(res),
218+
Err(e) => {
219+
eprintln!("send_request error: {e}");
220+
let mut resp = Response::new(full(format!("Request failed: {e}")));
221+
*resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
222+
Ok(resp)
223+
}
224+
}
225+
});
217226

218227
let io = TokioIo::new(tls_stream);
219-
http.serve_connection(io, service).await.unwrap();
228+
http.serve_connection(io, service).await?;
220229

221230
Ok(())
222231
}
@@ -225,14 +234,12 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
225234
async fn handle_http_request(
226235
req: hyper::Request<hyper::body::Incoming>,
227236
target: SocketAddr,
228-
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error> {
229-
let outbound = TcpStream::connect(target).await.unwrap();
237+
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, ProxyError> {
238+
let outbound = TcpStream::connect(target).await?;
230239
let outbound_io = TokioIo::new(outbound);
231240
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
232241
.handshake::<_, hyper::body::Incoming>(outbound_io)
233-
.await
234-
.unwrap();
235-
242+
.await?;
236243
// Drive the connection
237244
tokio::spawn(async move {
238245
if let Err(e) = conn.await {
@@ -387,7 +394,7 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
387394
let local_attestation_platform = local_attestation_platform.clone();
388395
let remote_attestation_platform = remote_attestation_platform.clone();
389396
async move {
390-
Self::handle_http_request(
397+
match Self::handle_http_request(
391398
req,
392399
connector,
393400
target,
@@ -396,11 +403,22 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
396403
remote_attestation_platform,
397404
)
398405
.await
406+
{
407+
Ok(res) => {
408+
Ok::<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error>(res)
409+
}
410+
Err(e) => {
411+
eprintln!("send_request error: {e}");
412+
let mut resp = Response::new(full(format!("Request failed: {e}")));
413+
*resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
414+
Ok(resp)
415+
}
416+
}
399417
}
400418
});
401419

402420
let io = TokioIo::new(inbound);
403-
http.serve_connection(io, service).await.unwrap();
421+
http.serve_connection(io, service).await?;
404422

405423
Ok(())
406424
}
@@ -468,23 +486,21 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
468486
cert_chain: Option<Vec<CertificateDer<'static>>>,
469487
local_attestation_platform: L,
470488
remote_attestation_platform: R,
471-
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error> {
489+
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, ProxyError> {
472490
let tls_stream = Self::setup_connection(
473491
connector,
474492
target,
475493
cert_chain,
476494
local_attestation_platform,
477495
remote_attestation_platform,
478496
)
479-
.await
480-
.unwrap();
497+
.await?;
481498

482499
// Now the attestation is done, forward the request to the proxy server
483500
let outbound_io = TokioIo::new(tls_stream);
484501
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
485502
.handshake::<_, hyper::body::Incoming>(outbound_io)
486-
.await
487-
.unwrap();
503+
.await?;
488504

489505
// Drive the connection
490506
tokio::spawn(async move {
@@ -583,6 +599,8 @@ pub enum ProxyError {
583599
IntConversion(#[from] TryFromIntError),
584600
#[error("Bad host name: {0}")]
585601
BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError),
602+
#[error("HTTP: {0}")]
603+
Hyper(#[from] hyper::Error),
586604
}
587605

588606
/// Given a byte array, encode its length as a 4 byte big endian u32

0 commit comments

Comments
 (0)