Skip to content

Commit 2ab2086

Browse files
committed
Switch to http proxy server (rather than raw TCP)
1 parent 0540ba2 commit 2ab2086

File tree

3 files changed

+147
-87
lines changed

3 files changed

+147
-87
lines changed

Cargo.lock

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ configfs-tsm = "0.0.2"
2020
rand_core = { version = "0.6.4", features = ["getrandom"] }
2121
dcap-qvl = "0.3.4"
2222
hex = "0.4.3"
23+
hyper = { version = "1.7.0", features = ["server"] }
24+
hyper-util = "0.1.17"
25+
http-body-util = "0.1.3"
26+
bytes = "1.10.1"
2327

2428
[dev-dependencies]
2529
rcgen = "0.14.5"

src/lib.rs

Lines changed: 139 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ pub use attestation::{
55
DcapTdxQuoteGenerator, DcapTdxQuoteVerifier, NoQuoteGenerator, NoQuoteVerifier, QuoteGenerator,
66
QuoteVerifier,
77
};
8+
use hyper::server::conn::http1::Builder;
9+
use hyper::service::service_fn;
10+
use hyper::Response;
11+
use hyper_util::rt::TokioIo;
812
use thiserror::Error;
913
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
1014

@@ -204,16 +208,52 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
204208
.await?;
205209
}
206210

207-
let outbound = TcpStream::connect(target).await?;
211+
let http = Builder::new();
212+
let service =
213+
service_fn(move |req| async move { Self::handle_http_request(req, target).await });
208214

209-
let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream);
210-
let (mut outbound_reader, mut outbound_writer) = outbound.into_split();
215+
let io = TokioIo::new(tls_stream);
216+
http.serve_connection(io, service).await.unwrap();
211217

212-
let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
213-
let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
214-
tokio::try_join!(client_to_server, server_to_client)?;
218+
// let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream);
219+
// let (mut outbound_reader, mut outbound_writer) = outbound.into_split();
220+
//
221+
// let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
222+
// let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
223+
// tokio::try_join!(client_to_server, server_to_client)?;
215224
Ok(())
216225
}
226+
227+
// Handle a request from the proxy client to the target server
228+
async fn handle_http_request(
229+
req: hyper::Request<hyper::body::Incoming>,
230+
target: SocketAddr,
231+
) -> Result<Response<hyper::body::Incoming>, hyper::Error> {
232+
let outbound = TcpStream::connect(target).await.unwrap();
233+
let outbound_io = TokioIo::new(outbound);
234+
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
235+
.handshake::<_, hyper::body::Incoming>(outbound_io)
236+
.await
237+
.unwrap();
238+
239+
// Drive the connection
240+
tokio::spawn(async move {
241+
if let Err(e) = conn.await {
242+
eprintln!("client conn error: {e}");
243+
}
244+
});
245+
246+
match sender.send_request(req).await {
247+
Ok(resp) => Ok(resp),
248+
Err(e) => {
249+
eprintln!("send_request error: {e}");
250+
// let mut resp = Response::new(hyper::body::Incoming::empty());
251+
// *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
252+
// Ok(resp)
253+
panic!("todo");
254+
}
255+
}
256+
}
217257
}
218258

219259
pub struct ProxyClient<L, R>
@@ -337,58 +377,129 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
337377
local_attestation_platform: L,
338378
remote_attestation_platform: R,
339379
) -> Result<(), ProxyError> {
340-
let out = TcpStream::connect(&target).await?;
380+
let http = Builder::new();
381+
let service = service_fn(move |req| {
382+
let connector = connector.clone();
383+
let target = target.clone();
384+
let cert_chain = cert_chain.clone();
385+
let local_attestation_platform = local_attestation_platform.clone();
386+
let remote_attestation_platform = remote_attestation_platform.clone();
387+
async move {
388+
Self::handle_http_request(
389+
req,
390+
connector,
391+
target,
392+
cert_chain,
393+
local_attestation_platform,
394+
remote_attestation_platform,
395+
)
396+
.await
397+
}
398+
});
399+
400+
let io = TokioIo::new(inbound);
401+
http.serve_connection(io, service).await.unwrap();
402+
403+
// let (mut inbound_reader, mut inbound_writer) = inbound.into_split();
404+
// let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream);
405+
//
406+
// let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
407+
// let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
408+
// tokio::try_join!(client_to_server, server_to_client)?;
409+
Ok(())
410+
}
411+
412+
// Handle a request from the source client to the proxy server
413+
async fn handle_http_request(
414+
req: hyper::Request<hyper::body::Incoming>,
415+
connector: TlsConnector,
416+
target: String,
417+
cert_chain: Option<Vec<CertificateDer<'static>>>,
418+
local_attestation_platform: L,
419+
remote_attestation_platform: R,
420+
) -> Result<Response<hyper::body::Incoming>, hyper::Error> {
421+
let out = TcpStream::connect(&target).await.unwrap();
341422
let mut tls_stream = connector
342-
.connect(server_name_from_host(&target)?, out)
343-
.await?;
423+
.connect(server_name_from_host(&target).unwrap(), out)
424+
.await
425+
.unwrap();
344426

345427
let (_io, server_connection) = tls_stream.get_ref();
346428

347429
let mut exporter = [0u8; 32];
348-
server_connection.export_keying_material(
349-
&mut exporter,
350-
EXPORTER_LABEL,
351-
None, // context
352-
)?;
430+
server_connection
431+
.export_keying_material(
432+
&mut exporter,
433+
EXPORTER_LABEL,
434+
None, // context
435+
)
436+
.unwrap();
353437

354438
let remote_cert_chain = server_connection
355439
.peer_certificates()
356-
.ok_or(ProxyError::NoCertificate)?
440+
.ok_or(ProxyError::NoCertificate)
441+
.unwrap()
357442
.to_owned();
358443

359444
let mut length_bytes = [0; 4];
360-
tls_stream.read_exact(&mut length_bytes).await?;
361-
let length: usize = u32::from_be_bytes(length_bytes).try_into()?;
445+
tls_stream.read_exact(&mut length_bytes).await.unwrap();
446+
let length: usize = u32::from_be_bytes(length_bytes).try_into().unwrap();
362447

363448
let mut buf = vec![0; length];
364-
tls_stream.read_exact(&mut buf).await?;
449+
tls_stream.read_exact(&mut buf).await.unwrap();
365450

366451
if remote_attestation_platform.is_cvm() {
367452
remote_attestation_platform
368453
.verify_attestation(buf, &remote_cert_chain, exporter)
369-
.await?;
454+
.await
455+
.unwrap();
370456
}
371457

372458
let attestation = if local_attestation_platform.is_cvm() {
373459
local_attestation_platform
374-
.create_attestation(&cert_chain.ok_or(ProxyError::NoClientAuth)?, exporter)?
460+
.create_attestation(
461+
&cert_chain.ok_or(ProxyError::NoClientAuth).unwrap(),
462+
exporter,
463+
)
464+
.unwrap()
375465
} else {
376466
Vec::new()
377467
};
378468

379469
let attestation_length_prefix = length_prefix(&attestation);
380470

381-
tls_stream.write_all(&attestation_length_prefix).await?;
471+
tls_stream
472+
.write_all(&attestation_length_prefix)
473+
.await
474+
.unwrap();
382475

383-
tls_stream.write_all(&attestation).await?;
476+
tls_stream.write_all(&attestation).await.unwrap();
384477

385-
let (mut inbound_reader, mut inbound_writer) = inbound.into_split();
386-
let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream);
478+
// Now the attestation is done, forward the connection to the proxy server
479+
// let outbound = TcpStream::connect(target).await.unwrap();
480+
let outbound_io = TokioIo::new(tls_stream);
481+
let (mut sender, conn) = hyper::client::conn::http1::Builder::new()
482+
.handshake::<_, hyper::body::Incoming>(outbound_io)
483+
.await
484+
.unwrap();
387485

388-
let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
389-
let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
390-
tokio::try_join!(client_to_server, server_to_client)?;
391-
Ok(())
486+
// Drive the connection
487+
tokio::spawn(async move {
488+
if let Err(e) = conn.await {
489+
eprintln!("client conn error: {e}");
490+
}
491+
});
492+
493+
match sender.send_request(req).await {
494+
Ok(resp) => Ok(resp),
495+
Err(e) => {
496+
eprintln!("send_request error: {e}");
497+
// let mut resp = Response::new(hyper::body::Incoming::empty());
498+
// *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
499+
// Ok(resp)
500+
panic!("todo");
501+
}
502+
}
392503
}
393504
}
394505

@@ -643,65 +754,6 @@ mod tests {
643754
assert_eq!(res, "foobar");
644755
}
645756

646-
#[tokio::test]
647-
async fn raw_tcp_proxy() {
648-
let target_addr = example_service().await;
649-
650-
let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap());
651-
let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key);
652-
653-
let proxy_server = ProxyServer::new_with_tls_config(
654-
cert_chain,
655-
server_config,
656-
"127.0.0.1:0",
657-
target_addr,
658-
DcapTdxQuoteGenerator,
659-
NoQuoteVerifier,
660-
)
661-
.await
662-
.unwrap();
663-
664-
let proxy_server_addr = proxy_server.local_addr().unwrap();
665-
666-
tokio::spawn(async move {
667-
proxy_server.accept().await.unwrap();
668-
});
669-
670-
let quote_verifier = DcapTdxQuoteVerifier {
671-
accepted_platform_measurements: None,
672-
accepted_cvm_image_measurements: vec![CvmImageMeasurements {
673-
rtmr1: [0u8; 48],
674-
rtmr2: [0u8; 48],
675-
rtmr3: [0u8; 48],
676-
}],
677-
pccs_url: None,
678-
};
679-
680-
let proxy_client = ProxyClient::new_with_tls_config(
681-
client_config,
682-
"127.0.0.1:0",
683-
proxy_server_addr.to_string(),
684-
NoQuoteGenerator,
685-
quote_verifier,
686-
None,
687-
)
688-
.await
689-
.unwrap();
690-
691-
let proxy_client_addr = proxy_client.local_addr().unwrap();
692-
693-
tokio::spawn(async move {
694-
proxy_client.accept().await.unwrap();
695-
});
696-
697-
let mut out = TcpStream::connect(proxy_client_addr).await.unwrap();
698-
699-
let mut buf = [0; 9];
700-
out.read(&mut buf).await.unwrap();
701-
702-
assert_eq!(buf[..], b"some data"[..]);
703-
}
704-
705757
#[tokio::test]
706758
async fn test_get_tls_cert() {
707759
let target_addr = example_service().await;

0 commit comments

Comments
 (0)