@@ -5,6 +5,10 @@ pub use attestation::{
55 DcapTdxQuoteGenerator , DcapTdxQuoteVerifier , NoQuoteGenerator , NoQuoteVerifier , QuoteGenerator ,
66 QuoteVerifier ,
77} ;
8+ use hyper:: server:: conn:: http1:: Builder ;
9+ use hyper:: service:: service_fn;
10+ use hyper:: Response ;
11+ use hyper_util:: rt:: TokioIo ;
812use thiserror:: Error ;
913use tokio_rustls:: rustls:: server:: { VerifierBuilderError , WebPkiClientVerifier } ;
1014
@@ -204,16 +208,52 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyServer<L, R> {
204208 . await ?;
205209 }
206210
207- let outbound = TcpStream :: connect ( target) . await ?;
211+ let http = Builder :: new ( ) ;
212+ let service =
213+ service_fn ( move |req| async move { Self :: handle_http_request ( req, target) . await } ) ;
208214
209- let ( mut inbound_reader , mut inbound_writer ) = tokio :: io :: split ( tls_stream) ;
210- let ( mut outbound_reader , mut outbound_writer ) = outbound . into_split ( ) ;
215+ let io = TokioIo :: new ( tls_stream) ;
216+ http . serve_connection ( io , service ) . await . unwrap ( ) ;
211217
212- let client_to_server = tokio:: io:: copy ( & mut inbound_reader, & mut outbound_writer) ;
213- let server_to_client = tokio:: io:: copy ( & mut outbound_reader, & mut inbound_writer) ;
214- tokio:: try_join!( client_to_server, server_to_client) ?;
218+ // let (mut inbound_reader, mut inbound_writer) = tokio::io::split(tls_stream);
219+ // let (mut outbound_reader, mut outbound_writer) = outbound.into_split();
220+ //
221+ // let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
222+ // let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
223+ // tokio::try_join!(client_to_server, server_to_client)?;
215224 Ok ( ( ) )
216225 }
226+
227+ // Handle a request from the proxy client to the target server
228+ async fn handle_http_request (
229+ req : hyper:: Request < hyper:: body:: Incoming > ,
230+ target : SocketAddr ,
231+ ) -> Result < Response < hyper:: body:: Incoming > , hyper:: Error > {
232+ let outbound = TcpStream :: connect ( target) . await . unwrap ( ) ;
233+ let outbound_io = TokioIo :: new ( outbound) ;
234+ let ( mut sender, conn) = hyper:: client:: conn:: http1:: Builder :: new ( )
235+ . handshake :: < _ , hyper:: body:: Incoming > ( outbound_io)
236+ . await
237+ . unwrap ( ) ;
238+
239+ // Drive the connection
240+ tokio:: spawn ( async move {
241+ if let Err ( e) = conn. await {
242+ eprintln ! ( "client conn error: {e}" ) ;
243+ }
244+ } ) ;
245+
246+ match sender. send_request ( req) . await {
247+ Ok ( resp) => Ok ( resp) ,
248+ Err ( e) => {
249+ eprintln ! ( "send_request error: {e}" ) ;
250+ // let mut resp = Response::new(hyper::body::Incoming::empty());
251+ // *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
252+ // Ok(resp)
253+ panic ! ( "todo" ) ;
254+ }
255+ }
256+ }
217257}
218258
219259pub struct ProxyClient < L , R >
@@ -337,58 +377,129 @@ impl<L: QuoteGenerator, R: QuoteVerifier> ProxyClient<L, R> {
337377 local_attestation_platform : L ,
338378 remote_attestation_platform : R ,
339379 ) -> Result < ( ) , ProxyError > {
340- let out = TcpStream :: connect ( & target) . await ?;
380+ let http = Builder :: new ( ) ;
381+ let service = service_fn ( move |req| {
382+ let connector = connector. clone ( ) ;
383+ let target = target. clone ( ) ;
384+ let cert_chain = cert_chain. clone ( ) ;
385+ let local_attestation_platform = local_attestation_platform. clone ( ) ;
386+ let remote_attestation_platform = remote_attestation_platform. clone ( ) ;
387+ async move {
388+ Self :: handle_http_request (
389+ req,
390+ connector,
391+ target,
392+ cert_chain,
393+ local_attestation_platform,
394+ remote_attestation_platform,
395+ )
396+ . await
397+ }
398+ } ) ;
399+
400+ let io = TokioIo :: new ( inbound) ;
401+ http. serve_connection ( io, service) . await . unwrap ( ) ;
402+
403+ // let (mut inbound_reader, mut inbound_writer) = inbound.into_split();
404+ // let (mut outbound_reader, mut outbound_writer) = tokio::io::split(tls_stream);
405+ //
406+ // let client_to_server = tokio::io::copy(&mut inbound_reader, &mut outbound_writer);
407+ // let server_to_client = tokio::io::copy(&mut outbound_reader, &mut inbound_writer);
408+ // tokio::try_join!(client_to_server, server_to_client)?;
409+ Ok ( ( ) )
410+ }
411+
412+ // Handle a request from the source client to the proxy server
413+ async fn handle_http_request (
414+ req : hyper:: Request < hyper:: body:: Incoming > ,
415+ connector : TlsConnector ,
416+ target : String ,
417+ cert_chain : Option < Vec < CertificateDer < ' static > > > ,
418+ local_attestation_platform : L ,
419+ remote_attestation_platform : R ,
420+ ) -> Result < Response < hyper:: body:: Incoming > , hyper:: Error > {
421+ let out = TcpStream :: connect ( & target) . await . unwrap ( ) ;
341422 let mut tls_stream = connector
342- . connect ( server_name_from_host ( & target) ?, out)
343- . await ?;
423+ . connect ( server_name_from_host ( & target) . unwrap ( ) , out)
424+ . await
425+ . unwrap ( ) ;
344426
345427 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
346428
347429 let mut exporter = [ 0u8 ; 32 ] ;
348- server_connection. export_keying_material (
349- & mut exporter,
350- EXPORTER_LABEL ,
351- None , // context
352- ) ?;
430+ server_connection
431+ . export_keying_material (
432+ & mut exporter,
433+ EXPORTER_LABEL ,
434+ None , // context
435+ )
436+ . unwrap ( ) ;
353437
354438 let remote_cert_chain = server_connection
355439 . peer_certificates ( )
356- . ok_or ( ProxyError :: NoCertificate ) ?
440+ . ok_or ( ProxyError :: NoCertificate )
441+ . unwrap ( )
357442 . to_owned ( ) ;
358443
359444 let mut length_bytes = [ 0 ; 4 ] ;
360- tls_stream. read_exact ( & mut length_bytes) . await ? ;
361- let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) ? ;
445+ tls_stream. read_exact ( & mut length_bytes) . await . unwrap ( ) ;
446+ let length: usize = u32:: from_be_bytes ( length_bytes) . try_into ( ) . unwrap ( ) ;
362447
363448 let mut buf = vec ! [ 0 ; length] ;
364- tls_stream. read_exact ( & mut buf) . await ? ;
449+ tls_stream. read_exact ( & mut buf) . await . unwrap ( ) ;
365450
366451 if remote_attestation_platform. is_cvm ( ) {
367452 remote_attestation_platform
368453 . verify_attestation ( buf, & remote_cert_chain, exporter)
369- . await ?;
454+ . await
455+ . unwrap ( ) ;
370456 }
371457
372458 let attestation = if local_attestation_platform. is_cvm ( ) {
373459 local_attestation_platform
374- . create_attestation ( & cert_chain. ok_or ( ProxyError :: NoClientAuth ) ?, exporter) ?
460+ . create_attestation (
461+ & cert_chain. ok_or ( ProxyError :: NoClientAuth ) . unwrap ( ) ,
462+ exporter,
463+ )
464+ . unwrap ( )
375465 } else {
376466 Vec :: new ( )
377467 } ;
378468
379469 let attestation_length_prefix = length_prefix ( & attestation) ;
380470
381- tls_stream. write_all ( & attestation_length_prefix) . await ?;
471+ tls_stream
472+ . write_all ( & attestation_length_prefix)
473+ . await
474+ . unwrap ( ) ;
382475
383- tls_stream. write_all ( & attestation) . await ? ;
476+ tls_stream. write_all ( & attestation) . await . unwrap ( ) ;
384477
385- let ( mut inbound_reader, mut inbound_writer) = inbound. into_split ( ) ;
386- let ( mut outbound_reader, mut outbound_writer) = tokio:: io:: split ( tls_stream) ;
478+ // Now the attestation is done, forward the connection to the proxy server
479+ // let outbound = TcpStream::connect(target).await.unwrap();
480+ let outbound_io = TokioIo :: new ( tls_stream) ;
481+ let ( mut sender, conn) = hyper:: client:: conn:: http1:: Builder :: new ( )
482+ . handshake :: < _ , hyper:: body:: Incoming > ( outbound_io)
483+ . await
484+ . unwrap ( ) ;
387485
388- let client_to_server = tokio:: io:: copy ( & mut inbound_reader, & mut outbound_writer) ;
389- let server_to_client = tokio:: io:: copy ( & mut outbound_reader, & mut inbound_writer) ;
390- tokio:: try_join!( client_to_server, server_to_client) ?;
391- Ok ( ( ) )
486+ // Drive the connection
487+ tokio:: spawn ( async move {
488+ if let Err ( e) = conn. await {
489+ eprintln ! ( "client conn error: {e}" ) ;
490+ }
491+ } ) ;
492+
493+ match sender. send_request ( req) . await {
494+ Ok ( resp) => Ok ( resp) ,
495+ Err ( e) => {
496+ eprintln ! ( "send_request error: {e}" ) ;
497+ // let mut resp = Response::new(hyper::body::Incoming::empty());
498+ // *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
499+ // Ok(resp)
500+ panic ! ( "todo" ) ;
501+ }
502+ }
392503 }
393504}
394505
@@ -643,65 +754,6 @@ mod tests {
643754 assert_eq ! ( res, "foobar" ) ;
644755 }
645756
646- #[ tokio:: test]
647- async fn raw_tcp_proxy ( ) {
648- let target_addr = example_service ( ) . await ;
649-
650- let ( cert_chain, private_key) = generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
651- let ( server_config, client_config) = generate_tls_config ( cert_chain. clone ( ) , private_key) ;
652-
653- let proxy_server = ProxyServer :: new_with_tls_config (
654- cert_chain,
655- server_config,
656- "127.0.0.1:0" ,
657- target_addr,
658- DcapTdxQuoteGenerator ,
659- NoQuoteVerifier ,
660- )
661- . await
662- . unwrap ( ) ;
663-
664- let proxy_server_addr = proxy_server. local_addr ( ) . unwrap ( ) ;
665-
666- tokio:: spawn ( async move {
667- proxy_server. accept ( ) . await . unwrap ( ) ;
668- } ) ;
669-
670- let quote_verifier = DcapTdxQuoteVerifier {
671- accepted_platform_measurements : None ,
672- accepted_cvm_image_measurements : vec ! [ CvmImageMeasurements {
673- rtmr1: [ 0u8 ; 48 ] ,
674- rtmr2: [ 0u8 ; 48 ] ,
675- rtmr3: [ 0u8 ; 48 ] ,
676- } ] ,
677- pccs_url : None ,
678- } ;
679-
680- let proxy_client = ProxyClient :: new_with_tls_config (
681- client_config,
682- "127.0.0.1:0" ,
683- proxy_server_addr. to_string ( ) ,
684- NoQuoteGenerator ,
685- quote_verifier,
686- None ,
687- )
688- . await
689- . unwrap ( ) ;
690-
691- let proxy_client_addr = proxy_client. local_addr ( ) . unwrap ( ) ;
692-
693- tokio:: spawn ( async move {
694- proxy_client. accept ( ) . await . unwrap ( ) ;
695- } ) ;
696-
697- let mut out = TcpStream :: connect ( proxy_client_addr) . await . unwrap ( ) ;
698-
699- let mut buf = [ 0 ; 9 ] ;
700- out. read ( & mut buf) . await . unwrap ( ) ;
701-
702- assert_eq ! ( buf[ ..] , b"some data" [ ..] ) ;
703- }
704-
705757 #[ tokio:: test]
706758 async fn test_get_tls_cert ( ) {
707759 let target_addr = example_service ( ) . await ;
0 commit comments