Skip to content

Commit bf7fe68

Browse files
authored
refactor(tcp): reducing branching in Transport::create_socket
Following #4289 (comment), hereby is the PR to also improve the `create_socket` using [`for_addr`](https://docs.rs/socket2/latest/socket2/struct.Domain.html#method.for_address). We also add a test for listening on IPv4 and IPv6 separately. Pull-Request: #4328.
1 parent 08292c5 commit bf7fe68

File tree

2 files changed

+52
-11
lines changed

2 files changed

+52
-11
lines changed

transports/quic/src/transport.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,17 +926,21 @@ mod tests {
926926
let keypair = libp2p_identity::Keypair::generate_ed25519();
927927
let config = Config::new(&keypair);
928928
let mut transport = crate::tokio::Transport::new(config);
929+
let port = {
930+
let socket = UdpSocket::bind("127.0.0.1:0").unwrap();
931+
socket.local_addr().unwrap().port()
932+
};
929933

930934
transport
931935
.listen_on(
932936
ListenerId::next(),
933-
"/ip4/0.0.0.0/udp/4001/quic-v1".parse().unwrap(),
937+
format!("/ip4/0.0.0.0/udp/{port}/quic-v1").parse().unwrap(),
934938
)
935939
.unwrap();
936940
transport
937941
.listen_on(
938942
ListenerId::next(),
939-
"/ip6/::/udp/4001/quic-v1".parse().unwrap(),
943+
format!("/ip6/::/udp/{port}/quic-v1").parse().unwrap(),
940944
)
941945
.unwrap();
942946
}

transports/tcp/src/lib.rs

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,12 @@ where
346346
}
347347
}
348348

349-
fn create_socket(&self, socket_addr: &SocketAddr) -> io::Result<Socket> {
350-
let domain = if socket_addr.is_ipv4() {
351-
Domain::IPV4
352-
} else {
353-
Domain::IPV6
354-
};
355-
let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
349+
fn create_socket(&self, socket_addr: SocketAddr) -> io::Result<Socket> {
350+
let socket = Socket::new(
351+
Domain::for_address(socket_addr),
352+
Type::STREAM,
353+
Some(socket2::Protocol::TCP),
354+
)?;
356355
if socket_addr.is_ipv6() {
357356
socket.set_only_v6(true)?;
358357
}
@@ -375,7 +374,7 @@ where
375374
id: ListenerId,
376375
socket_addr: SocketAddr,
377376
) -> io::Result<ListenStream<T>> {
378-
let socket = self.create_socket(&socket_addr)?;
377+
let socket = self.create_socket(socket_addr)?;
379378
socket.bind(&socket_addr.into())?;
380379
socket.listen(self.config.backlog as _)?;
381380
socket.set_nonblocking(true)?;
@@ -476,7 +475,7 @@ where
476475
log::debug!("dialing {}", socket_addr);
477476

478477
let socket = self
479-
.create_socket(&socket_addr)
478+
.create_socket(socket_addr)
480479
.map_err(TransportError::Other)?;
481480

482481
if let Some(addr) = self.port_reuse.local_dial_addr(&socket_addr.ip()) {
@@ -1329,4 +1328,42 @@ mod tests {
13291328
assert!(rt.block_on(cycle_listeners::<tokio::Tcp>()));
13301329
}
13311330
}
1331+
1332+
#[test]
1333+
fn test_listens_ipv4_ipv6_separately() {
1334+
fn test<T: Provider>() {
1335+
let port = {
1336+
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1337+
listener.local_addr().unwrap().port()
1338+
};
1339+
let mut tcp = Transport::<T>::default().boxed();
1340+
let listener_id = ListenerId::next();
1341+
tcp.listen_on(
1342+
listener_id,
1343+
format!("/ip4/0.0.0.0/tcp/{port}").parse().unwrap(),
1344+
)
1345+
.unwrap();
1346+
tcp.listen_on(
1347+
ListenerId::next(),
1348+
format!("/ip6/::/tcp/{port}").parse().unwrap(),
1349+
)
1350+
.unwrap();
1351+
}
1352+
#[cfg(feature = "async-io")]
1353+
{
1354+
async_std::task::block_on(async {
1355+
test::<async_io::Tcp>();
1356+
})
1357+
}
1358+
#[cfg(feature = "tokio")]
1359+
{
1360+
let rt = ::tokio::runtime::Builder::new_current_thread()
1361+
.enable_io()
1362+
.build()
1363+
.unwrap();
1364+
rt.block_on(async {
1365+
test::<async_io::Tcp>();
1366+
});
1367+
}
1368+
}
13321369
}

0 commit comments

Comments
 (0)