Skip to content

Commit 98da34a

Browse files
authored
feat(websocket): Allow wss connections on IP addresses
Pull-Request: #5525.
1 parent 823acd6 commit 98da34a

File tree

2 files changed

+124
-20
lines changed

2 files changed

+124
-20
lines changed

transports/websocket/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
- Implement refactored `Transport`.
44
See [PR 4568](https://github.com/libp2p/rust-libp2p/pull/4568)
5+
- Allow wss connections on IP addresses.
6+
See [PR 5525](https://github.com/libp2p/rust-libp2p/pull/5525).
7+
58
## 0.43.2
69

710
- fix: Avoid websocket panic on polling after errors. See [PR 5482].

transports/websocket/src/framed.rs

Lines changed: 121 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
use crate::{error::Error, quicksink, tls};
2222
use either::Either;
2323
use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
24-
use futures_rustls::{client, rustls, server};
24+
use futures_rustls::rustls::pki_types::ServerName;
25+
use futures_rustls::{client, server};
2526
use libp2p_core::{
2627
multiaddr::{Multiaddr, Protocol},
2728
transport::{DialOpts, ListenerId, TransportError, TransportEvent},
@@ -32,6 +33,7 @@ use soketto::{
3233
connection::{self, CloseReason},
3334
handshake,
3435
};
36+
use std::net::IpAddr;
3537
use std::{collections::HashMap, ops::DerefMut, sync::Arc};
3638
use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
3739
use url::Url;
@@ -315,15 +317,12 @@ where
315317

316318
let stream = if addr.use_tls {
317319
// begin TLS session
318-
let dns_name = addr
319-
.dns_name
320-
.expect("for use_tls we have checked that dns_name is some");
321-
tracing::trace!(?dns_name, "Starting TLS handshake");
320+
tracing::trace!(?addr.server_name, "Starting TLS handshake");
322321
let stream = tls_config
323322
.client
324-
.connect(dns_name.clone(), stream)
323+
.connect(addr.server_name.clone(), stream)
325324
.map_err(|e| {
326-
tracing::debug!(?dns_name, "TLS handshake failed: {}", e);
325+
tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
327326
Error::Tls(tls::Error::from(e))
328327
})
329328
.await?;
@@ -451,7 +450,7 @@ where
451450
struct WsAddress {
452451
host_port: String,
453452
path: String,
454-
dns_name: Option<rustls::pki_types::ServerName<'static>>,
453+
server_name: ServerName<'static>,
455454
use_tls: bool,
456455
tcp_addr: Multiaddr,
457456
}
@@ -468,19 +467,21 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
468467
let mut protocols = addr.iter();
469468
let mut ip = protocols.next();
470469
let mut tcp = protocols.next();
471-
let (host_port, dns_name) = loop {
470+
let (host_port, server_name) = loop {
472471
match (ip, tcp) {
473472
(Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
474-
break (format!("{ip}:{port}"), None)
473+
let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
474+
break (format!("{ip}:{port}"), server_name);
475475
}
476476
(Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
477-
break (format!("{ip}:{port}"), None)
477+
let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
478+
break (format!("[{ip}]:{port}"), server_name);
478479
}
479480
(Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
480481
| (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
481482
| (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port)))
482483
| (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => {
483-
break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?))
484+
break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
484485
}
485486
(Some(_), Some(p)) => {
486487
ip = Some(p);
@@ -499,13 +500,7 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
499500
match protocols.pop() {
500501
p @ Some(Protocol::P2p(_)) => p2p = p,
501502
Some(Protocol::Ws(path)) => break (false, path.into_owned()),
502-
Some(Protocol::Wss(path)) => {
503-
if dns_name.is_none() {
504-
tracing::debug!(address=%addr, "Missing DNS name in WSS address");
505-
return Err(Error::InvalidMultiaddr(addr));
506-
}
507-
break (true, path.into_owned());
508-
}
503+
Some(Protocol::Wss(path)) => break (true, path.into_owned()),
509504
_ => return Err(Error::InvalidMultiaddr(addr)),
510505
}
511506
};
@@ -519,7 +514,7 @@ fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
519514

520515
Ok(WsAddress {
521516
host_port,
522-
dns_name,
517+
server_name,
523518
path,
524519
use_tls,
525520
tcp_addr,
@@ -757,3 +752,109 @@ where
757752
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
758753
}
759754
}
755+
756+
#[cfg(test)]
757+
mod tests {
758+
use super::*;
759+
use libp2p_identity::PeerId;
760+
use std::io;
761+
762+
#[test]
763+
fn dial_addr() {
764+
let peer_id = PeerId::random();
765+
766+
// Check `/wss`
767+
let addr = "/dns4/example.com/tcp/2222/wss"
768+
.parse::<Multiaddr>()
769+
.unwrap();
770+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
771+
assert_eq!(info.host_port, "example.com:2222");
772+
assert_eq!(info.path, "/");
773+
assert!(info.use_tls);
774+
assert_eq!(info.server_name, "example.com".try_into().unwrap());
775+
assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());
776+
777+
// Check `/wss` with `/p2p`
778+
let addr = format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}")
779+
.parse()
780+
.unwrap();
781+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
782+
assert_eq!(info.host_port, "example.com:2222");
783+
assert_eq!(info.path, "/");
784+
assert!(info.use_tls);
785+
assert_eq!(info.server_name, "example.com".try_into().unwrap());
786+
assert_eq!(
787+
info.tcp_addr,
788+
format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
789+
.parse()
790+
.unwrap()
791+
);
792+
793+
// Check `/wss` with `/ip4`
794+
let addr = "/ip4/127.0.0.1/tcp/2222/wss".parse::<Multiaddr>().unwrap();
795+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
796+
assert_eq!(info.host_port, "127.0.0.1:2222");
797+
assert_eq!(info.path, "/");
798+
assert!(info.use_tls);
799+
assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
800+
assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());
801+
802+
// Check `/wss` with `/ip6`
803+
let addr = "/ip6/::1/tcp/2222/wss".parse::<Multiaddr>().unwrap();
804+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
805+
assert_eq!(info.host_port, "[::1]:2222");
806+
assert_eq!(info.path, "/");
807+
assert!(info.use_tls);
808+
assert_eq!(info.server_name, "::1".try_into().unwrap());
809+
assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());
810+
811+
// Check `/ws`
812+
let addr = "/dns4/example.com/tcp/2222/ws"
813+
.parse::<Multiaddr>()
814+
.unwrap();
815+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
816+
assert_eq!(info.host_port, "example.com:2222");
817+
assert_eq!(info.path, "/");
818+
assert!(!info.use_tls);
819+
assert_eq!(info.server_name, "example.com".try_into().unwrap());
820+
assert_eq!(info.tcp_addr, "/dns4/example.com/tcp/2222".parse().unwrap());
821+
822+
// Check `/ws` with `/p2p`
823+
let addr = format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}")
824+
.parse()
825+
.unwrap();
826+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
827+
assert_eq!(info.host_port, "example.com:2222");
828+
assert_eq!(info.path, "/");
829+
assert!(!info.use_tls);
830+
assert_eq!(info.server_name, "example.com".try_into().unwrap());
831+
assert_eq!(
832+
info.tcp_addr,
833+
format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
834+
.parse()
835+
.unwrap()
836+
);
837+
838+
// Check `/ws` with `/ip4`
839+
let addr = "/ip4/127.0.0.1/tcp/2222/ws".parse::<Multiaddr>().unwrap();
840+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
841+
assert_eq!(info.host_port, "127.0.0.1:2222");
842+
assert_eq!(info.path, "/");
843+
assert!(!info.use_tls);
844+
assert_eq!(info.server_name, "127.0.0.1".try_into().unwrap());
845+
assert_eq!(info.tcp_addr, "/ip4/127.0.0.1/tcp/2222".parse().unwrap());
846+
847+
// Check `/ws` with `/ip6`
848+
let addr = "/ip6/::1/tcp/2222/ws".parse::<Multiaddr>().unwrap();
849+
let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
850+
assert_eq!(info.host_port, "[::1]:2222");
851+
assert_eq!(info.path, "/");
852+
assert!(!info.use_tls);
853+
assert_eq!(info.server_name, "::1".try_into().unwrap());
854+
assert_eq!(info.tcp_addr, "/ip6/::1/tcp/2222".parse().unwrap());
855+
856+
// Check non-ws address
857+
let addr = "/ip4/127.0.0.1/tcp/2222".parse::<Multiaddr>().unwrap();
858+
parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
859+
}
860+
}

0 commit comments

Comments
 (0)