2121use crate :: { error:: Error , quicksink, tls} ;
2222use either:: Either ;
2323use futures:: { future:: BoxFuture , prelude:: * , ready, stream:: BoxStream } ;
24- use futures_rustls:: { client, rustls, server} ;
24+ use futures_rustls:: rustls:: pki_types:: ServerName ;
25+ use futures_rustls:: { client, server} ;
2526use libp2p_core:: {
2627 multiaddr:: { Multiaddr , Protocol } ,
2728 transport:: { DialOpts , ListenerId , TransportError , TransportEvent } ,
@@ -32,6 +33,7 @@ use soketto::{
3233 connection:: { self , CloseReason } ,
3334 handshake,
3435} ;
36+ use std:: net:: IpAddr ;
3537use std:: { collections:: HashMap , ops:: DerefMut , sync:: Arc } ;
3638use std:: { fmt, io, mem, pin:: Pin , task:: Context , task:: Poll } ;
3739use url:: Url ;
@@ -315,15 +317,12 @@ where
315317
316318 let stream = if addr. use_tls {
317319 // begin TLS session
318- let dns_name = addr
319- . dns_name
320- . expect ( "for use_tls we have checked that dns_name is some" ) ;
321- tracing:: trace!( ?dns_name, "Starting TLS handshake" ) ;
320+ tracing:: trace!( ?addr. server_name, "Starting TLS handshake" ) ;
322321 let stream = tls_config
323322 . client
324- . connect ( dns_name . clone ( ) , stream)
323+ . connect ( addr . server_name . clone ( ) , stream)
325324 . map_err ( |e| {
326- tracing:: debug!( ?dns_name , "TLS handshake failed: {}" , e) ;
325+ tracing:: debug!( ?addr . server_name , "TLS handshake failed: {}" , e) ;
327326 Error :: Tls ( tls:: Error :: from ( e) )
328327 } )
329328 . await ?;
@@ -451,7 +450,7 @@ where
451450struct WsAddress {
452451 host_port : String ,
453452 path : String ,
454- dns_name : Option < rustls :: pki_types :: ServerName < ' static > > ,
453+ server_name : ServerName < ' static > ,
455454 use_tls : bool ,
456455 tcp_addr : Multiaddr ,
457456}
@@ -468,19 +467,21 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
468467 let mut protocols = addr. iter ( ) ;
469468 let mut ip = protocols. next ( ) ;
470469 let mut tcp = protocols. next ( ) ;
471- let ( host_port, dns_name ) = loop {
470+ let ( host_port, server_name ) = loop {
472471 match ( ip, tcp) {
473472 ( Some ( Protocol :: Ip4 ( ip) ) , Some ( Protocol :: Tcp ( port) ) ) => {
474- break ( format ! ( "{ip}:{port}" ) , None )
473+ let server_name = ServerName :: IpAddress ( IpAddr :: V4 ( ip) . into ( ) ) ;
474+ break ( format ! ( "{ip}:{port}" ) , server_name) ;
475475 }
476476 ( Some ( Protocol :: Ip6 ( ip) ) , Some ( Protocol :: Tcp ( port) ) ) => {
477- break ( format ! ( "{ip}:{port}" ) , None )
477+ let server_name = ServerName :: IpAddress ( IpAddr :: V6 ( ip) . into ( ) ) ;
478+ break ( format ! ( "[{ip}]:{port}" ) , server_name) ;
478479 }
479480 ( Some ( Protocol :: Dns ( h) ) , Some ( Protocol :: Tcp ( port) ) )
480481 | ( Some ( Protocol :: Dns4 ( h) ) , Some ( Protocol :: Tcp ( port) ) )
481482 | ( Some ( Protocol :: Dns6 ( h) ) , Some ( Protocol :: Tcp ( port) ) )
482483 | ( Some ( Protocol :: Dnsaddr ( h) ) , Some ( Protocol :: Tcp ( port) ) ) => {
483- break ( format ! ( "{}:{}" , & h , port ) , Some ( tls:: dns_name_ref ( & h) ?) )
484+ break ( format ! ( "{h }:{port}" ) , tls:: dns_name_ref ( & h) ?)
484485 }
485486 ( Some ( _) , Some ( p) ) => {
486487 ip = Some ( p) ;
@@ -499,13 +500,7 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
499500 match protocols. pop ( ) {
500501 p @ Some ( Protocol :: P2p ( _) ) => p2p = p,
501502 Some ( Protocol :: Ws ( path) ) => break ( false , path. into_owned ( ) ) ,
502- Some ( Protocol :: Wss ( path) ) => {
503- if dns_name. is_none ( ) {
504- tracing:: debug!( address=%addr, "Missing DNS name in WSS address" ) ;
505- return Err ( Error :: InvalidMultiaddr ( addr) ) ;
506- }
507- break ( true , path. into_owned ( ) ) ;
508- }
503+ Some ( Protocol :: Wss ( path) ) => break ( true , path. into_owned ( ) ) ,
509504 _ => return Err ( Error :: InvalidMultiaddr ( addr) ) ,
510505 }
511506 } ;
@@ -519,7 +514,7 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
519514
520515 Ok ( WsAddress {
521516 host_port,
522- dns_name ,
517+ server_name ,
523518 path,
524519 use_tls,
525520 tcp_addr,
@@ -757,3 +752,109 @@ where
757752 . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: Other , e) )
758753 }
759754}
755+
756+ #[ cfg( test) ]
757+ mod tests {
758+ use super :: * ;
759+ use libp2p_identity:: PeerId ;
760+ use std:: io;
761+
762+ #[ test]
763+ fn dial_addr ( ) {
764+ let peer_id = PeerId :: random ( ) ;
765+
766+ // Check `/wss`
767+ let addr = "/dns4/example.com/tcp/2222/wss"
768+ . parse :: < Multiaddr > ( )
769+ . unwrap ( ) ;
770+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
771+ assert_eq ! ( info. host_port, "example.com:2222" ) ;
772+ assert_eq ! ( info. path, "/" ) ;
773+ assert ! ( info. use_tls) ;
774+ assert_eq ! ( info. server_name, "example.com" . try_into( ) . unwrap( ) ) ;
775+ assert_eq ! ( info. tcp_addr, "/dns4/example.com/tcp/2222" . parse( ) . unwrap( ) ) ;
776+
777+ // Check `/wss` with `/p2p`
778+ let addr = format ! ( "/dns4/example.com/tcp/2222/wss/p2p/{peer_id}" )
779+ . parse ( )
780+ . unwrap ( ) ;
781+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
782+ assert_eq ! ( info. host_port, "example.com:2222" ) ;
783+ assert_eq ! ( info. path, "/" ) ;
784+ assert ! ( info. use_tls) ;
785+ assert_eq ! ( info. server_name, "example.com" . try_into( ) . unwrap( ) ) ;
786+ assert_eq ! (
787+ info. tcp_addr,
788+ format!( "/dns4/example.com/tcp/2222/p2p/{peer_id}" )
789+ . parse( )
790+ . unwrap( )
791+ ) ;
792+
793+ // Check `/wss` with `/ip4`
794+ let addr = "/ip4/127.0.0.1/tcp/2222/wss" . parse :: < Multiaddr > ( ) . unwrap ( ) ;
795+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
796+ assert_eq ! ( info. host_port, "127.0.0.1:2222" ) ;
797+ assert_eq ! ( info. path, "/" ) ;
798+ assert ! ( info. use_tls) ;
799+ assert_eq ! ( info. server_name, "127.0.0.1" . try_into( ) . unwrap( ) ) ;
800+ assert_eq ! ( info. tcp_addr, "/ip4/127.0.0.1/tcp/2222" . parse( ) . unwrap( ) ) ;
801+
802+ // Check `/wss` with `/ip6`
803+ let addr = "/ip6/::1/tcp/2222/wss" . parse :: < Multiaddr > ( ) . unwrap ( ) ;
804+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
805+ assert_eq ! ( info. host_port, "[::1]:2222" ) ;
806+ assert_eq ! ( info. path, "/" ) ;
807+ assert ! ( info. use_tls) ;
808+ assert_eq ! ( info. server_name, "::1" . try_into( ) . unwrap( ) ) ;
809+ assert_eq ! ( info. tcp_addr, "/ip6/::1/tcp/2222" . parse( ) . unwrap( ) ) ;
810+
811+ // Check `/ws`
812+ let addr = "/dns4/example.com/tcp/2222/ws"
813+ . parse :: < Multiaddr > ( )
814+ . unwrap ( ) ;
815+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
816+ assert_eq ! ( info. host_port, "example.com:2222" ) ;
817+ assert_eq ! ( info. path, "/" ) ;
818+ assert ! ( !info. use_tls) ;
819+ assert_eq ! ( info. server_name, "example.com" . try_into( ) . unwrap( ) ) ;
820+ assert_eq ! ( info. tcp_addr, "/dns4/example.com/tcp/2222" . parse( ) . unwrap( ) ) ;
821+
822+ // Check `/ws` with `/p2p`
823+ let addr = format ! ( "/dns4/example.com/tcp/2222/ws/p2p/{peer_id}" )
824+ . parse ( )
825+ . unwrap ( ) ;
826+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
827+ assert_eq ! ( info. host_port, "example.com:2222" ) ;
828+ assert_eq ! ( info. path, "/" ) ;
829+ assert ! ( !info. use_tls) ;
830+ assert_eq ! ( info. server_name, "example.com" . try_into( ) . unwrap( ) ) ;
831+ assert_eq ! (
832+ info. tcp_addr,
833+ format!( "/dns4/example.com/tcp/2222/p2p/{peer_id}" )
834+ . parse( )
835+ . unwrap( )
836+ ) ;
837+
838+ // Check `/ws` with `/ip4`
839+ let addr = "/ip4/127.0.0.1/tcp/2222/ws" . parse :: < Multiaddr > ( ) . unwrap ( ) ;
840+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
841+ assert_eq ! ( info. host_port, "127.0.0.1:2222" ) ;
842+ assert_eq ! ( info. path, "/" ) ;
843+ assert ! ( !info. use_tls) ;
844+ assert_eq ! ( info. server_name, "127.0.0.1" . try_into( ) . unwrap( ) ) ;
845+ assert_eq ! ( info. tcp_addr, "/ip4/127.0.0.1/tcp/2222" . parse( ) . unwrap( ) ) ;
846+
847+ // Check `/ws` with `/ip6`
848+ let addr = "/ip6/::1/tcp/2222/ws" . parse :: < Multiaddr > ( ) . unwrap ( ) ;
849+ let info = parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap ( ) ;
850+ assert_eq ! ( info. host_port, "[::1]:2222" ) ;
851+ assert_eq ! ( info. path, "/" ) ;
852+ assert ! ( !info. use_tls) ;
853+ assert_eq ! ( info. server_name, "::1" . try_into( ) . unwrap( ) ) ;
854+ assert_eq ! ( info. tcp_addr, "/ip6/::1/tcp/2222" . parse( ) . unwrap( ) ) ;
855+
856+ // Check non-ws address
857+ let addr = "/ip4/127.0.0.1/tcp/2222" . parse :: < Multiaddr > ( ) . unwrap ( ) ;
858+ parse_ws_dial_addr :: < io:: Error > ( addr) . unwrap_err ( ) ;
859+ }
860+ }
0 commit comments