@@ -218,10 +218,8 @@ where
218218{
219219 inner : Proxy < L , R > ,
220220 connector : TlsConnector ,
221- /// The address of the proxy server
222- target : SocketAddr ,
223- /// The subject name of the proxy server
224- target_name : ServerName < ' static > ,
221+ /// The host and port of the proxy server
222+ target : String ,
225223 /// Certificate chain for client auth
226224 cert_chain : Option < Vec < CertificateDer < ' static > > > ,
227225}
@@ -230,8 +228,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
230228 pub async fn new (
231229 cert_and_key : Option < TlsCertAndKey > ,
232230 address : impl ToSocketAddrs ,
233- server_address : SocketAddr ,
234- server_name : ServerName < ' static > ,
231+ server_name : String ,
235232 local_attestation_platform : L ,
236233 remote_attestation_platform : R ,
237234 ) -> Result < Self , ProxyError > {
@@ -257,7 +254,6 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
257254 Self :: new_with_tls_config (
258255 client_config. into ( ) ,
259256 address,
260- server_address,
261257 server_name,
262258 local_attestation_platform,
263259 remote_attestation_platform,
@@ -272,8 +268,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
272268 async fn new_with_tls_config (
273269 client_config : Arc < ClientConfig > ,
274270 local : impl ToSocketAddrs ,
275- target : SocketAddr ,
276- target_name : ServerName < ' static > ,
271+ target_name : String ,
277272 local_attestation_platform : L ,
278273 remote_attestation_platform : R ,
279274 cert_chain : Option < Vec < CertificateDer < ' static > > > ,
@@ -290,8 +285,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
290285 Ok ( Self {
291286 inner,
292287 connector,
293- target,
294- target_name,
288+ target : host_to_host_with_port ( & target_name) ,
295289 cert_chain,
296290 } )
297291 }
@@ -301,8 +295,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
301295 let ( inbound, _client_addr) = self . inner . listener . accept ( ) . await ?;
302296
303297 let connector = self . connector . clone ( ) ;
304- let target_name = self . target_name . clone ( ) ;
305- let target = self . target ;
298+ let target = self . target . clone ( ) ;
306299 let local_attestation_platform = self . inner . local_attestation_platform . clone ( ) ;
307300 let remote_attestation_platform = self . inner . remote_attestation_platform . clone ( ) ;
308301 let cert_chain = self . cert_chain . clone ( ) ;
@@ -312,7 +305,6 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
312305 inbound,
313306 connector,
314307 target,
315- target_name,
316308 cert_chain,
317309 local_attestation_platform,
318310 remote_attestation_platform,
@@ -335,14 +327,15 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
335327 async fn handle_connection (
336328 inbound : TcpStream ,
337329 connector : TlsConnector ,
338- target : SocketAddr ,
339- target_name : ServerName < ' static > ,
330+ target : String ,
340331 cert_chain : Option < Vec < CertificateDer < ' static > > > ,
341332 local_attestation_platform : L ,
342333 remote_attestation_platform : R ,
343334 ) -> Result < ( ) , ProxyError > {
344- let out = TcpStream :: connect ( target) . await ?;
345- let mut tls_stream = connector. connect ( target_name, out) . await ?;
335+ let out = TcpStream :: connect ( & target) . await ?;
336+ let mut tls_stream = connector
337+ . connect ( server_name_from_host ( & target) ?, out)
338+ . await ?;
346339
347340 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
348341
@@ -394,16 +387,14 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
394387
395388/// Just get the attested remote certificate, with no client authentication
396389pub async fn get_tls_cert < R : AttestationPlatform > (
397- server_address : SocketAddr ,
398- server_name : ServerName < ' static > ,
390+ server_name : String ,
399391 remote_attestation_platform : R ,
400392) -> Result < Vec < CertificateDer < ' static > > , ProxyError > {
401393 let root_store = RootCertStore :: from_iter ( webpki_roots:: TLS_SERVER_ROOTS . iter ( ) . cloned ( ) ) ;
402394 let client_config = ClientConfig :: builder ( )
403395 . with_root_certificates ( root_store)
404396 . with_no_client_auth ( ) ;
405397 get_tls_cert_with_config (
406- server_address,
407398 server_name,
408399 remote_attestation_platform,
409400 client_config. into ( ) ,
@@ -412,15 +403,16 @@ pub async fn get_tls_cert<R: AttestationPlatform>(
412403}
413404
414405async fn get_tls_cert_with_config < R : AttestationPlatform > (
415- server_address : SocketAddr ,
416- server_name : ServerName < ' static > ,
406+ server_name : String ,
417407 remote_attestation_platform : R ,
418408 client_config : Arc < ClientConfig > ,
419409) -> Result < Vec < CertificateDer < ' static > > , ProxyError > {
420410 let connector = TlsConnector :: from ( client_config) ;
421411
422- let out = TcpStream :: connect ( server_address) . await ?;
423- let mut tls_stream = connector. connect ( server_name, out) . await ?;
412+ let out = TcpStream :: connect ( host_to_host_with_port ( & server_name) ) . await ?;
413+ let mut tls_stream = connector
414+ . connect ( server_name_from_host ( & server_name) ?, out)
415+ . await ?;
424416
425417 let ( _io, server_connection) = tls_stream. get_ref ( ) ;
426418
@@ -467,6 +459,8 @@ pub enum ProxyError {
467459 Attestation ( #[ from] AttestationError ) ,
468460 #[ error( "Integer conversion: {0}" ) ]
469461 IntConversion ( #[ from] TryFromIntError ) ,
462+ #[ error( "Bad host name: {0}" ) ]
463+ BadDnsName ( #[ from] tokio_rustls:: rustls:: pki_types:: InvalidDnsNameError ) ,
470464}
471465
472466/// Given a byte array, encode its length as a 4 byte big endian u32
@@ -475,6 +469,27 @@ fn length_prefix(input: &[u8]) -> [u8; 4] {
475469 len. to_be_bytes ( )
476470}
477471
472+ fn host_to_host_with_port ( host : & str ) -> String {
473+ if host. contains ( ':' ) {
474+ host. to_string ( )
475+ } else {
476+ format ! ( "{host}:443" )
477+ }
478+ }
479+
480+ fn server_name_from_host (
481+ host : & str ,
482+ ) -> Result < ServerName < ' static > , tokio_rustls:: rustls:: pki_types:: InvalidDnsNameError > {
483+ // If host contains ':', try to split off the port.
484+ let host_part = host. rsplit_once ( ':' ) . map ( |( h, _) | h) . unwrap_or ( host) ;
485+
486+ // If the host is an IPv6 literal in brackets like "[::1]:443",
487+ // remove the brackets for SNI (SNI allows bare IPv6 too).
488+ let host_part = host_part. trim_matches ( |c| c == '[' || c == ']' ) ;
489+
490+ ServerName :: try_from ( host_part. to_string ( ) )
491+ }
492+
478493#[ cfg( test) ]
479494mod tests {
480495 use super :: * ;
@@ -486,9 +501,8 @@ mod tests {
486501 #[ tokio:: test]
487502 async fn http_proxy ( ) {
488503 let target_addr = example_http_service ( ) . await ;
489- let target_name = "name" . to_string ( ) ;
490504
491- let ( cert_chain, private_key) = generate_certificate_chain ( target_name . clone ( ) ) ;
505+ let ( cert_chain, private_key) = generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
492506 let ( server_config, client_config) = generate_tls_config ( cert_chain. clone ( ) , private_key) ;
493507
494508 let proxy_server = ProxyServer :: new_with_tls_config (
@@ -510,9 +524,8 @@ mod tests {
510524
511525 let proxy_client = ProxyClient :: new_with_tls_config (
512526 client_config,
513- "127.0.0.1:0" ,
514- proxy_addr,
515- target_name. try_into ( ) . unwrap ( ) ,
527+ "127.0.0.1:0" . to_string ( ) ,
528+ proxy_addr. to_string ( ) ,
516529 NoAttestation ,
517530 MockAttestation ,
518531 None ,
@@ -539,12 +552,11 @@ mod tests {
539552 #[ tokio:: test]
540553 async fn http_proxy_mutual_attestation ( ) {
541554 let target_addr = example_http_service ( ) . await ;
542- let target_name = "name" . to_string ( ) ;
543555
544556 let ( server_cert_chain, server_private_key) =
545- generate_certificate_chain ( target_name . clone ( ) ) ;
557+ generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
546558 let ( client_cert_chain, client_private_key) =
547- generate_certificate_chain ( target_name . clone ( ) ) ;
559+ generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
548560
549561 let (
550562 ( _client_tls_server_config, client_tls_client_config) ,
@@ -576,8 +588,7 @@ mod tests {
576588 let proxy_client = ProxyClient :: new_with_tls_config (
577589 client_tls_client_config,
578590 "127.0.0.1:0" ,
579- proxy_addr,
580- target_name. try_into ( ) . unwrap ( ) ,
591+ proxy_addr. to_string ( ) ,
581592 MockAttestation ,
582593 MockAttestation ,
583594 Some ( client_cert_chain) ,
@@ -604,9 +615,8 @@ mod tests {
604615 #[ tokio:: test]
605616 async fn raw_tcp_proxy ( ) {
606617 let target_addr = example_service ( ) . await ;
607- let target_name = "name" . to_string ( ) ;
608618
609- let ( cert_chain, private_key) = generate_certificate_chain ( target_name . clone ( ) ) ;
619+ let ( cert_chain, private_key) = generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
610620 let ( server_config, client_config) = generate_tls_config ( cert_chain. clone ( ) , private_key) ;
611621
612622 let local_attestation_platform = MockAttestation ;
@@ -631,8 +641,7 @@ mod tests {
631641 let proxy_client = ProxyClient :: new_with_tls_config (
632642 client_config,
633643 "127.0.0.1:0" ,
634- proxy_server_addr,
635- target_name. try_into ( ) . unwrap ( ) ,
644+ proxy_server_addr. to_string ( ) ,
636645 NoAttestation ,
637646 MockAttestation ,
638647 None ,
@@ -657,9 +666,8 @@ mod tests {
657666 #[ tokio:: test]
658667 async fn test_get_tls_cert ( ) {
659668 let target_addr = example_service ( ) . await ;
660- let target_name = "name" . to_string ( ) ;
661669
662- let ( cert_chain, private_key) = generate_certificate_chain ( target_name . clone ( ) ) ;
670+ let ( cert_chain, private_key) = generate_certificate_chain ( "127.0.0.1" . parse ( ) . unwrap ( ) ) ;
663671 let ( server_config, client_config) = generate_tls_config ( cert_chain. clone ( ) , private_key) ;
664672
665673 let local_attestation_platform = MockAttestation ;
@@ -682,8 +690,7 @@ mod tests {
682690 } ) ;
683691
684692 let retrieved_chain = get_tls_cert_with_config (
685- proxy_server_addr,
686- target_name. try_into ( ) . unwrap ( ) ,
693+ proxy_server_addr. to_string ( ) ,
687694 MockAttestation ,
688695 client_config,
689696 )
0 commit comments