diff --git a/protocols/autonat/tests/autonatv2.rs b/protocols/autonat/tests/autonatv2.rs index 5834420b431..53cee508dd5 100644 --- a/protocols/autonat/tests/autonatv2.rs +++ b/protocols/autonat/tests/autonatv2.rs @@ -324,7 +324,6 @@ async fn dial_back_to_not_supporting() { let (bob_done_tx, bob_done_rx) = oneshot::channel(); let hannes = new_dummy().await; - let hannes_peer_id = *hannes.local_peer_id(); let unreachable_address = hannes.external_addresses().next().unwrap().clone(); let bob_unreachable_address = unreachable_address.clone(); bob.behaviour_mut() @@ -338,7 +337,7 @@ async fn dial_back_to_not_supporting() { let handler = tokio::spawn(async { hannes.loop_on_next().await }); let alice_task = async { - let (alice_dialing_peer, alice_conn_id) = alice + let (alice_dialing_peer, _) = alice .wait(|event| match event { SwarmEvent::Dialing { peer_id, @@ -350,15 +349,9 @@ async fn dial_back_to_not_supporting() { alice .wait(|event| match event { SwarmEvent::OutgoingConnectionError { - connection_id, - peer_id: Some(peer_id), - error: DialError::WrongPeerId { obtained, .. }, - } if connection_id == alice_conn_id - && peer_id == alice_dialing_peer - && obtained == hannes_peer_id => - { - Some(()) - } + error: DialError::Transport(_), + .. + } => Some(()), _ => None, }) .await; diff --git a/swarm/src/connection/pool.rs b/swarm/src/connection/pool.rs index 37ae63af033..2f533040566 100644 --- a/swarm/src/connection/pool.rs +++ b/swarm/src/connection/pool.rs @@ -438,7 +438,7 @@ where self.executor.spawn( task::new_for_pending_outgoing_connection( connection_id, - ConcurrentDial::new(dials, concurrency_factor), + ConcurrentDial::new(peer, dials, concurrency_factor), abort_receiver, self.pending_connection_events_tx.clone(), ) diff --git a/swarm/src/connection/pool/concurrent_dial.rs b/swarm/src/connection/pool/concurrent_dial.rs index 99f0b385884..0f6450cd820 100644 --- a/swarm/src/connection/pool/concurrent_dial.rs +++ b/swarm/src/connection/pool/concurrent_dial.rs @@ -43,6 +43,7 @@ type Dial = BoxFuture< >; pub(crate) struct ConcurrentDial { + peer_id: Option, dials: FuturesUnordered, pending_dials: Box + Send>, errors: Vec<(Multiaddr, TransportError)>, @@ -51,7 +52,11 @@ pub(crate) struct ConcurrentDial { impl Unpin for ConcurrentDial {} impl ConcurrentDial { - pub(crate) fn new(pending_dials: Vec, concurrency_factor: NonZeroU8) -> Self { + pub(crate) fn new( + peer_id: Option, + pending_dials: Vec, + concurrency_factor: NonZeroU8, + ) -> Self { let mut pending_dials = pending_dials.into_iter(); let dials = FuturesUnordered::new(); @@ -63,6 +68,7 @@ impl ConcurrentDial { } Self { + peer_id, dials, errors: Default::default(), pending_dials: Box::new(pending_dials), @@ -86,10 +92,14 @@ impl Future for ConcurrentDial { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { loop { match ready!(self.dials.poll_next_unpin(cx)) { - Some((addr, Ok(output))) => { + Some((addr, Ok(output))) if self.peer_id.is_none_or(|id| output.0 == id) => { let errors = std::mem::take(&mut self.errors); return Poll::Ready(Ok((addr, output, errors))); } + Some((addr, Ok(_))) => { + let e = TransportError::Other(std::io::ErrorKind::PermissionDenied.into()); + self.errors.push((addr, e)); + } Some((addr, Err(e))) => { self.errors.push((addr, e)); if let Some(dial) = self.pending_dials.next() { diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index d0ae6118190..9baa1504e4b 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -1758,6 +1758,8 @@ impl NetworkInfo { #[cfg(test)] mod tests { + use std::io::ErrorKind; + use libp2p_core::{ multiaddr, multiaddr::multiaddr, @@ -2153,10 +2155,13 @@ mod tests { .await; assert_eq!(peer_id.unwrap(), other_id); match error { - DialError::WrongPeerId { obtained, address } => { - assert_eq!(obtained, *swarm1.local_peer_id()); - assert_eq!(address, other_addr); - } + DialError::Transport(e) => match &e.get(0).unwrap().1 { + TransportError::Other(e) => match e.kind() { + ErrorKind::PermissionDenied => {} + _ => panic!("wrong error {e:?}"), + }, + _ => panic!("wrong error {e:?}"), + }, x => panic!("wrong error {x:?}"), } }