Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions p2p/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ var (
errSelf = errors.New("is self")
errAlreadyDialing = errors.New("already dialing")
errAlreadyConnected = errors.New("already connected")
errPendingInbound = errors.New("peer has pending inbound connection")
errRecentlyDialed = errors.New("recently dialed")
errNetRestrict = errors.New("not contained in netrestrict list")
errNoPort = errors.New("node does not provide TCP port")
Expand Down Expand Up @@ -104,12 +105,15 @@ type dialScheduler struct {
remStaticCh chan *enode.Node
addPeerCh chan *conn
remPeerCh chan *conn
addPendingCh chan enode.ID
remPendingCh chan enode.ID

// Everything below here belongs to loop and
// should only be accessed by code on the loop goroutine.
dialing map[enode.ID]*dialTask // active tasks
peers map[enode.ID]struct{} // all connected peers
dialPeers int // current number of dialed peers
dialing map[enode.ID]*dialTask // active tasks
peers map[enode.ID]struct{} // all connected peers
pendingInbound map[enode.ID]struct{} // in-progress inbound connections
dialPeers int // current number of dialed peers

// The static map tracks all static dial tasks. The subset of usable static dial tasks
// (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers
Expand Down Expand Up @@ -163,19 +167,22 @@ func (cfg dialConfig) withDefaults() dialConfig {
func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler {
cfg := config.withDefaults()
d := &dialScheduler{
dialConfig: cfg,
historyTimer: mclock.NewAlarm(cfg.clock),
setupFunc: setupFunc,
dnsLookupFunc: net.DefaultResolver.LookupNetIP,
dialing: make(map[enode.ID]*dialTask),
static: make(map[enode.ID]*dialTask),
peers: make(map[enode.ID]struct{}),
doneCh: make(chan *dialTask),
nodesIn: make(chan *enode.Node),
addStaticCh: make(chan *enode.Node),
remStaticCh: make(chan *enode.Node),
addPeerCh: make(chan *conn),
remPeerCh: make(chan *conn),
dialConfig: cfg,
historyTimer: mclock.NewAlarm(cfg.clock),
setupFunc: setupFunc,
dnsLookupFunc: net.DefaultResolver.LookupNetIP,
dialing: make(map[enode.ID]*dialTask),
static: make(map[enode.ID]*dialTask),
peers: make(map[enode.ID]struct{}),
pendingInbound: make(map[enode.ID]struct{}),
doneCh: make(chan *dialTask),
nodesIn: make(chan *enode.Node),
addStaticCh: make(chan *enode.Node),
remStaticCh: make(chan *enode.Node),
addPeerCh: make(chan *conn),
remPeerCh: make(chan *conn),
addPendingCh: make(chan enode.ID),
remPendingCh: make(chan enode.ID),
}
d.lastStatsLog = d.clock.Now()
d.ctx, d.cancel = context.WithCancel(context.Background())
Expand Down Expand Up @@ -223,6 +230,22 @@ func (d *dialScheduler) peerRemoved(c *conn) {
}
}

// inboundPending reports to the scheduler that an inbound connection from the
// node with the given ID is in progress. The ID is handed to the loop goroutine
// over addPendingCh, where it is recorded in pendingInbound so that checkDial
// rejects the node with errPendingInbound.
//
// The send is on an unbuffered channel, so this call blocks until loop receives
// the ID or d.ctx is done, whichever happens first.
func (d *dialScheduler) inboundPending(id enode.ID) {
	select {
	case <-d.ctx.Done():
		// Scheduler context is finished; drop the notification.
	case d.addPendingCh <- id:
	}
}

// inboundCompleted reports to the scheduler that the inbound connection from
// the node with the given ID is no longer pending, either because it completed
// or because it failed. The ID is handed to the loop goroutine over
// remPendingCh, where it is removed from pendingInbound and the static pool
// entry for the node is refreshed.
//
// The send is on an unbuffered channel, so this call blocks until loop receives
// the ID or d.ctx is done, whichever happens first.
func (d *dialScheduler) inboundCompleted(id enode.ID) {
	select {
	case <-d.ctx.Done():
		// Scheduler context is finished; drop the notification.
	case d.remPendingCh <- id:
	}
}

// loop is the main loop of the dialer.
func (d *dialScheduler) loop(it enode.Iterator) {
var (
Expand Down Expand Up @@ -276,6 +299,15 @@ loop:
delete(d.peers, c.node.ID())
d.updateStaticPool(c.node.ID())

case id := <-d.addPendingCh:
d.pendingInbound[id] = struct{}{}
d.log.Trace("Marked node as pending inbound", "id", id)

case id := <-d.remPendingCh:
delete(d.pendingInbound, id)
d.updateStaticPool(id)
d.log.Trace("Unmarked node as pending inbound", "id", id)

case node := <-d.addStaticCh:
id := node.ID()
_, exists := d.static[id]
Expand Down Expand Up @@ -390,6 +422,9 @@ func (d *dialScheduler) checkDial(n *enode.Node) error {
if _, ok := d.peers[n.ID()]; ok {
return errAlreadyConnected
}
if _, ok := d.pendingInbound[n.ID()]; ok {
return errPendingInbound
}
if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) {
return errNetRestrict
}
Expand Down
76 changes: 76 additions & 0 deletions p2p/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,82 @@ func TestDialSchedDNSHostname(t *testing.T) {
})
}

// This test checks that nodes with pending inbound connections are not dialed.
func TestDialSchedPendingInbound(t *testing.T) {
t.Parallel()

config := dialConfig{
maxActiveDials: 5,
maxDialPeers: 4,
}
runDialTest(t, config, []dialTestRound{
// 2 peers are connected, leaving 2 dial slots.
// Node 0x03 has a pending inbound connection.
// Discovered nodes 0x03, 0x04, 0x05 but only 0x04 and 0x05 should be dialed.
{
peersAdded: []*conn{
{flags: dynDialedConn, node: newNode(uintID(0x01), "127.0.0.1:30303")},
{flags: dynDialedConn, node: newNode(uintID(0x02), "127.0.0.2:30303")},
},
update: func(d *dialScheduler) {
d.inboundPending(uintID(0x03))
},
discovered: []*enode.Node{
newNode(uintID(0x03), "127.0.0.3:30303"), // not dialed because pending inbound
newNode(uintID(0x04), "127.0.0.4:30303"),
newNode(uintID(0x05), "127.0.0.5:30303"),
},
wantNewDials: []*enode.Node{
newNode(uintID(0x04), "127.0.0.4:30303"),
newNode(uintID(0x05), "127.0.0.5:30303"),
},
},
// Pending inbound connection for 0x03 completes successfully.
// Node 0x03 becomes a connected peer.
// One dial slot remains, node 0x06 is dialed.
{
update: func(d *dialScheduler) {
// Pending inbound completes
d.inboundCompleted(uintID(0x03))
},
peersAdded: []*conn{
{flags: inboundConn, node: newNode(uintID(0x03), "127.0.0.3:30303")},
},
succeeded: []enode.ID{
uintID(0x04),
},
failed: []enode.ID{
uintID(0x05),
},
discovered: []*enode.Node{
newNode(uintID(0x03), "127.0.0.3:30303"), // not dialed, now connected
newNode(uintID(0x06), "127.0.0.6:30303"),
},
wantNewDials: []*enode.Node{
newNode(uintID(0x06), "127.0.0.6:30303"),
},
},
// Inbound peer 0x03 disconnects.
// Another pending inbound starts for 0x07.
// Only 0x03 should be dialed, not 0x07.
{
peersRemoved: []enode.ID{
uintID(0x03),
},
update: func(d *dialScheduler) {
d.inboundPending(uintID(0x07))
},
discovered: []*enode.Node{
newNode(uintID(0x03), "127.0.0.3:30303"),
newNode(uintID(0x07), "127.0.0.7:30303"), // not dialed because pending inbound
},
wantNewDials: []*enode.Node{
newNode(uintID(0x03), "127.0.0.3:30303"),
},
},
})
}

// -------
// Code below here is the framework for the tests above.

Expand Down
86 changes: 76 additions & 10 deletions p2p/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -639,10 +639,12 @@
defer srv.dialsched.stop()

var (
peers = make(map[enode.ID]*Peer)
inboundCount = 0
trusted = make(map[enode.ID]bool, len(srv.TrustedNodes))
peers = make(map[enode.ID]*Peer)
inboundCount = 0
trusted = make(map[enode.ID]bool, len(srv.TrustedNodes))
pendingInbound = make(map[enode.ID]time.Time) // Track in-progress inbound connections
)

// Put trusted nodes into a map to speed up checks.
// Trusted peers are loaded on startup or added via AddTrustedPeer RPC.
for _, n := range srv.TrustedNodes {
Expand Down Expand Up @@ -682,22 +684,65 @@
case c := <-srv.checkpointPostHandshake:
// A connection has passed the encryption handshake so
// the remote identity is known (but hasn't been verified yet).
if trusted[c.node.ID()] {
nodeID := c.node.ID()
if trusted[nodeID] {
// Ensure that the trusted flag is set before checking against MaxPeers.
c.flags |= trustedConn
}
// TODO: track in-progress inbound node IDs (pre-Peer) to avoid dialing them.
c.cont <- srv.postHandshakeChecks(peers, inboundCount, c)

// Check for duplicate connections: both active peers and in-progress inbound connections.
if c.flags&inboundConn != 0 && c.flags&trustedConn == 0 {
// Check if we already have this peer or if there's already an in-progress inbound connection from them.
if _, exists := peers[nodeID]; exists {
srv.log.Debug("Rejecting duplicate inbound connection (peer exists)", "id", nodeID)
c.cont <- DiscAlreadyConnected
continue running
}

if startTime, exists := pendingInbound[nodeID]; exists {
srv.log.Debug("Rejecting duplicate inbound connection (already pending)",
"id", nodeID, "pending_duration", time.Since(startTime))
c.cont <- DiscAlreadyConnected
continue running

}

Check failure on line 708 in p2p/server.go

View workflow job for this annotation

GitHub Actions / Lint

unnecessary trailing newline (whitespace)

pendingInbound[nodeID] = time.Now()
srv.dialsched.inboundPending(nodeID)
srv.log.Trace("Tracking pending inbound connection", "id", nodeID, "pending_count", len(pendingInbound))
}

err := srv.postHandshakeChecks(peers, inboundCount, c)
if err != nil && c.flags&inboundConn != 0 && c.flags&trustedConn == 0 {
delete(pendingInbound, nodeID)
srv.dialsched.inboundCompleted(nodeID)
srv.log.Trace("Removed failed pending inbound connection", "id", nodeID, "err", err)
}
c.cont <- err

case c := <-srv.checkpointAddPeer:
// At this point the connection is past the protocol handshake.
// Its capabilities are known and the remote identity is verified.
nodeID := c.node.ID()
err := srv.addPeerChecks(peers, inboundCount, c)
if err == nil {
// The handshakes are done and it passed all checks.
p := srv.launchPeer(c)
peers[c.node.ID()] = p
srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(), "conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name())
peers[nodeID] = p
// Remove from pending tracker as it became promoted to proper peer
if c.flags&inboundConn != 0 {
if startTime, exists := pendingInbound[nodeID]; exists {
duration := time.Since(startTime)
delete(pendingInbound, nodeID)
srv.dialsched.inboundCompleted(nodeID)
srv.log.Trace("Promoted pending inbound to peer", "id", nodeID,
"handshake_duration", duration, "pending_count", len(pendingInbound))
}

}

Check failure on line 742 in p2p/server.go

View workflow job for this annotation

GitHub Actions / Lint

unnecessary trailing newline (whitespace)

srv.log.Debug("Adding p2p peer", "peercount", len(peers), "id", p.ID(),
"conn", c.flags, "addr", p.RemoteAddr(), "name", p.Name())
srv.dialsched.peerAdded(c)
if p.Inbound() {
inboundCount++
Expand All @@ -708,14 +753,34 @@
activeOutboundPeerGauge.Inc(1)
}
activePeerGauge.Inc(1)

} else {

Check failure on line 757 in p2p/server.go

View workflow job for this annotation

GitHub Actions / Lint

unnecessary trailing newline (whitespace)
// Failed to add peer. Clean up pending tracking if it was inbound.
if c.flags&inboundConn != 0 {
delete(pendingInbound, nodeID)
srv.dialsched.inboundCompleted(nodeID)
srv.log.Trace("Removed failed pending inbound at add peer stage",
"id", nodeID, "err", err)
}

}

c.cont <- err

case pd := <-srv.delpeer:
// A peer disconnected.
d := common.PrettyDuration(mclock.Now() - pd.created)
delete(peers, pd.ID())
srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", pd.ID(), "duration", d, "req", pd.requested, "err", pd.err)
nodeID := pd.ID()
delete(peers, nodeID)
// Remove from pending tracking if present (defensive cleanup).
if _, exists := pendingInbound[nodeID]; exists {
delete(pendingInbound, nodeID)
srv.dialsched.inboundCompleted(nodeID)
srv.log.Trace("Cleaned up pending entry on peer deletion", "id", nodeID)
}

srv.log.Debug("Removing p2p peer", "peercount", len(peers), "id", nodeID,
"duration", d, "req", pd.requested, "err", pd.err)
srv.dialsched.peerRemoved(pd.rw)
if pd.Inbound() {
inboundCount--
Expand Down Expand Up @@ -747,6 +812,7 @@
p := <-srv.delpeer
p.log.Trace("<-delpeer (spindown)")
delete(peers, p.ID())
delete(pendingInbound, p.ID())
}
}

Expand Down
Loading
Loading