diff --git a/p2p/dial.go b/p2p/dial.go index 225709427c2..f9463d6d890 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -76,6 +76,7 @@ var ( errSelf = errors.New("is self") errAlreadyDialing = errors.New("already dialing") errAlreadyConnected = errors.New("already connected") + errPendingInbound = errors.New("peer has pending inbound connection") errRecentlyDialed = errors.New("recently dialed") errNetRestrict = errors.New("not contained in netrestrict list") errNoPort = errors.New("node does not provide TCP port") @@ -104,12 +105,15 @@ type dialScheduler struct { remStaticCh chan *enode.Node addPeerCh chan *conn remPeerCh chan *conn + addPendingCh chan enode.ID + remPendingCh chan enode.ID // Everything below here belongs to loop and // should only be accessed by code on the loop goroutine. - dialing map[enode.ID]*dialTask // active tasks - peers map[enode.ID]struct{} // all connected peers - dialPeers int // current number of dialed peers + dialing map[enode.ID]*dialTask // active tasks + peers map[enode.ID]struct{} // all connected peers + pendingInbound map[enode.ID]struct{} // in-progress inbound connections + dialPeers int // current number of dialed peers // The static map tracks all static dial tasks. The subset of usable static dial tasks // (i.e. those passing checkDial) is kept in staticPool. 
The scheduler prefers @@ -163,19 +167,22 @@ func (cfg dialConfig) withDefaults() dialConfig { func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { cfg := config.withDefaults() d := &dialScheduler{ - dialConfig: cfg, - historyTimer: mclock.NewAlarm(cfg.clock), - setupFunc: setupFunc, - dnsLookupFunc: net.DefaultResolver.LookupNetIP, - dialing: make(map[enode.ID]*dialTask), - static: make(map[enode.ID]*dialTask), - peers: make(map[enode.ID]struct{}), - doneCh: make(chan *dialTask), - nodesIn: make(chan *enode.Node), - addStaticCh: make(chan *enode.Node), - remStaticCh: make(chan *enode.Node), - addPeerCh: make(chan *conn), - remPeerCh: make(chan *conn), + dialConfig: cfg, + historyTimer: mclock.NewAlarm(cfg.clock), + setupFunc: setupFunc, + dnsLookupFunc: net.DefaultResolver.LookupNetIP, + dialing: make(map[enode.ID]*dialTask), + static: make(map[enode.ID]*dialTask), + peers: make(map[enode.ID]struct{}), + pendingInbound: make(map[enode.ID]struct{}), + doneCh: make(chan *dialTask), + nodesIn: make(chan *enode.Node), + addStaticCh: make(chan *enode.Node), + remStaticCh: make(chan *enode.Node), + addPeerCh: make(chan *conn), + remPeerCh: make(chan *conn), + addPendingCh: make(chan enode.ID), + remPendingCh: make(chan enode.ID), } d.lastStatsLog = d.clock.Now() d.ctx, d.cancel = context.WithCancel(context.Background()) @@ -223,6 +230,22 @@ func (d *dialScheduler) peerRemoved(c *conn) { } } +// inboundPending notifies the scheduler about a pending inbound connection. +func (d *dialScheduler) inboundPending(id enode.ID) { + select { + case d.addPendingCh <- id: + case <-d.ctx.Done(): + } +} + +// inboundCompleted notifies the scheduler that an inbound connection completed or failed. +func (d *dialScheduler) inboundCompleted(id enode.ID) { + select { + case d.remPendingCh <- id: + case <-d.ctx.Done(): + } +} + // loop is the main loop of the dialer. 
func (d *dialScheduler) loop(it enode.Iterator) { var ( @@ -276,6 +299,15 @@ loop: delete(d.peers, c.node.ID()) d.updateStaticPool(c.node.ID()) + case id := <-d.addPendingCh: + d.pendingInbound[id] = struct{}{} + d.log.Trace("Marked node as pending inbound", "id", id) + + case id := <-d.remPendingCh: + delete(d.pendingInbound, id) + d.updateStaticPool(id) + d.log.Trace("Unmarked node as pending inbound", "id", id) + case node := <-d.addStaticCh: id := node.ID() _, exists := d.static[id] @@ -390,6 +422,9 @@ func (d *dialScheduler) checkDial(n *enode.Node) error { if _, ok := d.peers[n.ID()]; ok { return errAlreadyConnected } + if _, ok := d.pendingInbound[n.ID()]; ok { + return errPendingInbound + } if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) { return errNetRestrict } diff --git a/p2p/dial_test.go b/p2p/dial_test.go index f18dacce2ab..9684aa6e91f 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -423,6 +423,82 @@ func TestDialSchedDNSHostname(t *testing.T) { }) } +// This test checks that nodes with pending inbound connections are not dialed. +func TestDialSchedPendingInbound(t *testing.T) { + t.Parallel() + + config := dialConfig{ + maxActiveDials: 5, + maxDialPeers: 4, + } + runDialTest(t, config, []dialTestRound{ + // 2 peers are connected, leaving 2 dial slots. + // Node 0x03 has a pending inbound connection. + // Discovered nodes 0x03, 0x04, 0x05 but only 0x04 and 0x05 should be dialed. 
+ { + peersAdded: []*conn{ + {flags: dynDialedConn, node: newNode(uintID(0x01), "127.0.0.1:30303")}, + {flags: dynDialedConn, node: newNode(uintID(0x02), "127.0.0.2:30303")}, + }, + update: func(d *dialScheduler) { + d.inboundPending(uintID(0x03)) + }, + discovered: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), // not dialed because pending inbound + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.0.5:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x04), "127.0.0.4:30303"), + newNode(uintID(0x05), "127.0.0.5:30303"), + }, + }, + // Pending inbound connection for 0x03 completes successfully. + // Node 0x03 becomes a connected peer. + // One dial slot remains, node 0x06 is dialed. + { + update: func(d *dialScheduler) { + // Pending inbound completes + d.inboundCompleted(uintID(0x03)) + }, + peersAdded: []*conn{ + {flags: inboundConn, node: newNode(uintID(0x03), "127.0.0.3:30303")}, + }, + succeeded: []enode.ID{ + uintID(0x04), + }, + failed: []enode.ID{ + uintID(0x05), + }, + discovered: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), // not dialed, now connected + newNode(uintID(0x06), "127.0.0.6:30303"), + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x06), "127.0.0.6:30303"), + }, + }, + // Inbound peer 0x03 disconnects. + // Another pending inbound starts for 0x07. + // Only 0x03 should be dialed, not 0x07. + { + peersRemoved: []enode.ID{ + uintID(0x03), + }, + update: func(d *dialScheduler) { + d.inboundPending(uintID(0x07)) + }, + discovered: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), + newNode(uintID(0x07), "127.0.0.7:30303"), // not dialed because pending inbound + }, + wantNewDials: []*enode.Node{ + newNode(uintID(0x03), "127.0.0.3:30303"), + }, + }, + }) +} + // ------- // Code below here is the framework for the tests above. 
diff --git a/p2p/server.go b/p2p/server.go index 10c855f1c46..397fea07097 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -639,10 +639,12 @@ func (srv *Server) run() { defer srv.dialsched.stop() var ( - peers = make(map[enode.ID]*Peer) - inboundCount = 0 - trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) + peers = make(map[enode.ID]*Peer) + inboundCount = 0 + trusted = make(map[enode.ID]bool, len(srv.TrustedNodes)) + pendingInbound = make(map[enode.ID]time.Time) // Track in-progress inbound connections ) + // Put trusted nodes into a map to speed up checks. // Trusted peers are loaded on startup or added via AddTrustedPeer RPC. for _, n := range srv.TrustedNodes { @@ -682,22 +684,65 @@ running: case c := <-srv.checkpointPostHandshake: // A connection has passed the encryption handshake so // the remote identity is known (but hasn't been verified yet). - if trusted[c.node.ID()] { + nodeID := c.node.ID() + if trusted[nodeID] { // Ensure that the trusted flag is set before checking against MaxPeers. c.flags |= trustedConn } - // TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them. - c.cont <- srv.postHandshakeChecks(peers, inboundCount, c) + + // Check for duplicate connections: both active peers and in-progress inbound connections. + if c.flags&inboundConn != 0 && c.flags&trustedConn == 0 { + // Check if we already have this peer or if there's already an in-progress inbound connection from them. 
+ if _, exists := peers[nodeID]; exists { + srv.log.Debug("Rejecting duplicate inbound connection (peer exists)", "id", nodeID) + c.cont <- DiscAlreadyConnected + continue running + } + + if startTime, exists := pendingInbound[nodeID]; exists { + srv.log.Debug("Rejecting duplicate inbound connection (already pending)", + "id", nodeID, "pending_duration", time.Since(startTime)) + c.cont <- DiscAlreadyConnected + continue running + + } + + pendingInbound[nodeID] = time.Now() + srv.dialsched.inboundPending(nodeID) + srv.log.Trace("Tracking pending inbound connection", "id", nodeID, "pending_count", len(pendingInbound)) + } + + err := srv.postHandshakeChecks(peers, inboundCount, c) + if err != nil && c.flags&inboundConn != 0 && c.flags&trustedConn == 0 { + delete(pendingInbound, nodeID) + srv.dialsched.inboundCompleted(nodeID) + srv.log.Trace("Removed failed pending inbound connection", "id", nodeID, "err", err) + } + c.cont <- err case c := <-srv.checkpointAddPeer: // At this point the connection is past the protocol handshake. // Its capabilities are known and the remote identity is verified. + nodeID := c.node.ID() err := srv.addPeerChecks(peers, inboundCount, c) if err == nil { // The handshakes are done and it passed all checks. 
p := srv.launchPeer(c) - peers[c.node.ID()] = p - srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name()) + peers[nodeID] = p + // Remove from the pending tracker, as the connection has been promoted to a proper peer. + if c.flags&inboundConn != 0 { + if startTime, exists := pendingInbound[nodeID]; exists { + duration := time.Since(startTime) + delete(pendingInbound, nodeID) + srv.dialsched.inboundCompleted(nodeID) + srv.log.Trace("Promoted pending inbound to peer", "id", nodeID, + "handshake_duration", duration, "pending_count", len(pendingInbound)) + } + + } + + srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), + "conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name()) srv.dialsched.peerAdded(c) if p.Inbound() { inboundCount++ @@ -708,14 +753,34 @@ running: activeOutboundPeerGauge.Inc(1) } activePeerGauge.Inc(1) + + } else { + // Failed to add peer. Clean up pending tracking if it was inbound. + if c.flags&inboundConn != 0 { + delete(pendingInbound, nodeID) + srv.dialsched.inboundCompleted(nodeID) + srv.log.Trace("Removed failed pending inbound at add peer stage", + "id", nodeID, "err", err) + } + } + + c.cont <- err case pd := <-srv.delpeer: // A peer disconnected. d := common.PrettyDuration(mclock.Now() - pd.created) - delete(peers, pd.ID()) - srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", pd.ID(), "duration", d, "req", pd.requested, "err", pd.err) + nodeID := pd.ID() + delete(peers, nodeID) + // Remove from pending tracking if present (defensive cleanup).
+ if _, exists := pendingInbound[nodeID]; exists { + delete(pendingInbound, nodeID) + srv.dialsched.inboundCompleted(nodeID) + srv.log.Trace("Cleaned up pending entry on peer deletion", "id", nodeID) + } + + srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", nodeID, + "duration", d, "req", pd.requested, "err", pd.err) srv.dialsched.peerRemoved(pd.rw) if pd.Inbound() { inboundCount-- @@ -747,6 +812,7 @@ running: p := <-srv.delpeer p.log.Trace("<-delpeer (spindown)") delete(peers, p.ID()) + delete(pendingInbound, p.ID()) } } diff --git a/p2p/server_test.go b/p2p/server_test.go index 7bc7379099d..ff8a7e202b2 100644 --- a/p2p/server_test.go +++ b/p2p/server_test.go @@ -463,6 +463,109 @@ func TestServerSetupConn(t *testing.T) { } } +// TestServerPendingInboundRejection checks that duplicate inbound connections +// are rejected when there's already a pending inbound connection from the same peer. +func TestServerPendingInboundRejection(t *testing.T) { + trustedNode := newkey() + srv := &Server{ + Config: Config{ + PrivateKey: newkey(), + MaxPeers: 10, + NoDial: true, + NoDiscovery: true, + Logger: testlog.Logger(t, log.LvlTrace), + }, + } + if err := srv.Start(); err != nil { + t.Fatalf("could not start: %v", err) + } + defer srv.Stop() + + newconn := func(id enode.ID) *conn { + fd, _ := net.Pipe() + tx := newTestTransport(&trustedNode.PublicKey, fd, nil) + node := enode.SignNull(new(enr.Record), id) + return &conn{fd: fd, transport: tx, flags: inboundConn, node: node, cont: make(chan error)} + } + + // Create two connections from the same peer + peerID := randomID() + c1 := newconn(peerID) + c2 := newconn(peerID) + + // First connection enters pendingInbound + err1 := srv.checkpoint(c1, srv.checkpointPostHandshake) + if err1 != nil { + t.Fatalf("first connection failed unexpectedly: %v", err1) + } + + // Second connection should be rejected (duplicate pending inbound) + err2 := srv.checkpoint(c2, srv.checkpointPostHandshake) + if err2 != 
DiscAlreadyConnected { + t.Errorf("expected DiscAlreadyConnected for duplicate pending inbound, got: %v", err2) + } + + t.Logf("✅ First connection accepted, second rejected with: %v", err2) +} + +// TestServerPendingInboundCleanup checks that pending inbound state is properly +// cleaned up when a connection fails or completes. +func TestServerPendingInboundCleanup(t *testing.T) { + // Track when peers connect + connected := make(chan *Peer, 2) + remid := &newkey().PublicKey + + srv := startTestServer(t, remid, func(p *Peer) { + connected <- p + }) + defer close(connected) + defer srv.Stop() + + // First connection attempt - will succeed + conn1, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial: %v", err) + } + defer conn1.Close() + + // Wait for peer to connect + select { + case peer := <-connected: + t.Logf("First connection succeeded, peer: %s", peer.ID()) + + // Disconnect the peer to clean up pendingInbound + peer.Disconnect(DiscRequested) + time.Sleep(100 * time.Millisecond) + + // Verify peer is gone + if srv.PeerCount() != 0 { + t.Errorf("expected 0 peers after disconnect, got %d", srv.PeerCount()) + } + + case <-time.After(2 * time.Second): + t.Fatal("first connection did not complete within timeout") + } + + // Second connection attempt from same peer - should succeed now + // because pendingInbound was cleaned up + conn2, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) + if err != nil { + t.Fatalf("could not dial second time: %v", err) + } + defer conn2.Close() + + select { + case peer := <-connected: + t.Logf("Second connection succeeded after cleanup, peer: %s", peer.ID()) + if peer.ID() != enode.PubkeyToIDV4(remid) { + t.Errorf("peer has wrong id") + } + + case <-time.After(2 * time.Second): + t.Error("second connection did not complete within timeout - pendingInbound may not have been cleaned up") + } +} + type setupTransport struct { pubkey *ecdsa.PublicKey encHandshakeErr error