Skip to content

Commit 69cf7dc

Browse files
committed
Error handling for connections
1 parent 2ab2086 commit 69cf7dc

File tree

1 file changed

+20
-26
lines changed

1 file changed

+20
-26
lines changed

src/lib.rs

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ pub use attestation::{
55
DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator,
66
QuoteVerifier,
77
};
8+
use bytes::Bytes;
9+
use http_body_util::combinators::BoxBody;
10+
use http_body_util::BodyExt;
811
use hyper::server::conn::http1::Builder;
912
use hyper::service::service_fn;
1013
use hyper::Response;
@@ -215,20 +218,14 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
215218
let io = TokioIo::new(tls_stream);
216219
http.serve_connection(io, service).await.unwrap();
217220

218-
// let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream);
219-
// let (mut outbound_reader, mut outbound_writer) = outbound.into_split();
220-
//
221-
// let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
222-
// let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
223-
// tokio::try_join!(client_to_server, server_to_client)?;
224221
Ok(())
225222
}
226223

227224
// Handle a request from the proxy client to the target server
228225
async fn handle_http_request(
229226
req: hyper::Request<hyper::body::Incoming>,
230227
target: SocketAddr,
231-
) -> Result<Response<hyper::body::Incoming>, hyper::Error> {
228+
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error> {
232229
let outbound = TcpStream::connect(target).await.unwrap();
233230
let outbound_io = TokioIo::new(outbound);
234231
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
@@ -244,18 +241,23 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
244241
});
245242

246243
match sender.send_request(req).await {
247-
Ok(resp) => Ok(resp),
244+
Ok(resp) => Ok(resp.map(|b| b.boxed())),
248245
Err(e) => {
249246
eprintln!("send_request error: {e}");
250-
// let mut resp = Response::new(hyper::body::Incoming::empty());
251-
// *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
252-
// Ok(resp)
253-
panic!("todo");
247+
let mut resp = Response::new(full(format!("Request failed: {e}")));
248+
*resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
249+
Ok(resp)
254250
}
255251
}
256252
}
257253
}
258254

255+
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
256+
http_body_util::Full::new(chunk.into())
257+
.map_err(|never| match never {})
258+
.boxed()
259+
}
260+
259261
pub struct ProxyClient<L, R>
260262
where
261263
L: QuoteGenerator,
@@ -400,12 +402,6 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
400402
let io = TokioIo::new(inbound);
401403
http.serve_connection(io, service).await.unwrap();
402404

403-
// let (mut inbound_reader, mut inbound_writer) = inbound.into_split();
404-
// let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream);
405-
//
406-
// let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
407-
// let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
408-
// tokio::try_join!(client_to_server, server_to_client)?;
409405
Ok(())
410406
}
411407

@@ -417,7 +413,7 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
417413
cert_chain: Option<Vec<CertificateDer<'static>>>,
418414
local_attestation_platform: L,
419415
remote_attestation_platform: R,
420-
) -> Result<Response<hyper::body::Incoming>, hyper::Error> {
416+
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, hyper::Error> {
421417
let out = TcpStream::connect(&target).await.unwrap();
422418
let mut tls_stream = connector
423419
.connect(server_name_from_host(&target).unwrap(), out)
@@ -475,8 +471,7 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
475471

476472
tls_stream.write_all(&attestation).await.unwrap();
477473

478-
// Now the attestation is done, forward the connection to the proxy server
479-
// let outbound = TcpStream::connect(target).await.unwrap();
474+
// Now the attestation is done, forward the request to the proxy server
480475
let outbound_io = TokioIo::new(tls_stream);
481476
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
482477
.handshake::<_, hyper::body::Incoming>(outbound_io)
@@ -491,13 +486,12 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
491486
});
492487

493488
match sender.send_request(req).await {
494-
Ok(resp) => Ok(resp),
489+
Ok(resp) => Ok(resp.map(|b| b.boxed())),
495490
Err(e) => {
496491
eprintln!("send_request error: {e}");
497-
// let mut resp = Response::new(hyper::body::Incoming::empty());
498-
// *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
499-
// Ok(resp)
500-
panic!("todo");
492+
let mut resp = Response::new(full(format!("Request failed: {e}")));
493+
*resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
494+
Ok(resp)
501495
}
502496
}
503497
}

0 commit comments

Comments
 (0)