|
76 | 76 | errSelf = errors.New("is self") |
77 | 77 | errAlreadyDialing = errors.New("already dialing") |
78 | 78 | errAlreadyConnected = errors.New("already connected") |
| 79 | + errPendingInbound = errors.New("peer has pending inbound connection") |
79 | 80 | errRecentlyDialed = errors.New("recently dialed") |
80 | 81 | errNetRestrict = errors.New("not contained in netrestrict list") |
81 | 82 | errNoPort = errors.New("node does not provide TCP port") |
@@ -104,12 +105,15 @@ type dialScheduler struct { |
104 | 105 | remStaticCh chan *enode.Node |
105 | 106 | addPeerCh chan *conn |
106 | 107 | remPeerCh chan *conn |
| 108 | + addPendingCh chan enode.ID |
| 109 | + remPendingCh chan enode.ID |
107 | 110 |
|
108 | 111 | // Everything below here belongs to loop and |
109 | 112 | // should only be accessed by code on the loop goroutine. |
110 | | - dialing map[enode.ID]*dialTask // active tasks |
111 | | - peers map[enode.ID]struct{} // all connected peers |
112 | | - dialPeers int // current number of dialed peers |
| 113 | + dialing map[enode.ID]*dialTask // active tasks |
| 114 | + peers map[enode.ID]struct{} // all connected peers |
| 115 | + pendingInbound map[enode.ID]struct{} // in-progress inbound connections |
| 116 | + dialPeers int // current number of dialed peers |
113 | 117 |
|
114 | 118 | // The static map tracks all static dial tasks. The subset of usable static dial tasks |
115 | 119 | // (i.e. those passing checkDial) is kept in staticPool. The scheduler prefers |
@@ -163,19 +167,22 @@ func (cfg dialConfig) withDefaults() dialConfig { |
163 | 167 | func newDialScheduler(config dialConfig, it enode.Iterator, setupFunc dialSetupFunc) *dialScheduler { |
164 | 168 | cfg := config.withDefaults() |
165 | 169 | d := &dialScheduler{ |
166 | | - dialConfig: cfg, |
167 | | - historyTimer: mclock.NewAlarm(cfg.clock), |
168 | | - setupFunc: setupFunc, |
169 | | - dnsLookupFunc: net.DefaultResolver.LookupNetIP, |
170 | | - dialing: make(map[enode.ID]*dialTask), |
171 | | - static: make(map[enode.ID]*dialTask), |
172 | | - peers: make(map[enode.ID]struct{}), |
173 | | - doneCh: make(chan *dialTask), |
174 | | - nodesIn: make(chan *enode.Node), |
175 | | - addStaticCh: make(chan *enode.Node), |
176 | | - remStaticCh: make(chan *enode.Node), |
177 | | - addPeerCh: make(chan *conn), |
178 | | - remPeerCh: make(chan *conn), |
| 170 | + dialConfig: cfg, |
| 171 | + historyTimer: mclock.NewAlarm(cfg.clock), |
| 172 | + setupFunc: setupFunc, |
| 173 | + dnsLookupFunc: net.DefaultResolver.LookupNetIP, |
| 174 | + dialing: make(map[enode.ID]*dialTask), |
| 175 | + static: make(map[enode.ID]*dialTask), |
| 176 | + peers: make(map[enode.ID]struct{}), |
| 177 | + pendingInbound: make(map[enode.ID]struct{}), |
| 178 | + doneCh: make(chan *dialTask), |
| 179 | + nodesIn: make(chan *enode.Node), |
| 180 | + addStaticCh: make(chan *enode.Node), |
| 181 | + remStaticCh: make(chan *enode.Node), |
| 182 | + addPeerCh: make(chan *conn), |
| 183 | + remPeerCh: make(chan *conn), |
| 184 | + addPendingCh: make(chan enode.ID), |
| 185 | + remPendingCh: make(chan enode.ID), |
179 | 186 | } |
180 | 187 | d.lastStatsLog = d.clock.Now() |
181 | 188 | d.ctx, d.cancel = context.WithCancel(context.Background()) |
@@ -223,6 +230,22 @@ func (d *dialScheduler) peerRemoved(c *conn) { |
223 | 230 | } |
224 | 231 | } |
225 | 232 |
|
| 233 | +// inboundPending notifies the scheduler about a pending inbound connection. |
| 234 | +func (d *dialScheduler) inboundPending(id enode.ID) { |
| 235 | + select { |
| 236 | + case d.addPendingCh <- id: |
| 237 | + case <-d.ctx.Done(): |
| 238 | + } |
| 239 | +} |
| 240 | + |
| 241 | +// inboundCompleted notifies the scheduler that an inbound connection completed or failed. |
| 242 | +func (d *dialScheduler) inboundCompleted(id enode.ID) { |
| 243 | + select { |
| 244 | + case d.remPendingCh <- id: |
| 245 | + case <-d.ctx.Done(): |
| 246 | + } |
| 247 | +} |
| 248 | + |
226 | 249 | // loop is the main loop of the dialer. |
227 | 250 | func (d *dialScheduler) loop(it enode.Iterator) { |
228 | 251 | var ( |
@@ -276,6 +299,15 @@ loop: |
276 | 299 | delete(d.peers, c.node.ID()) |
277 | 300 | d.updateStaticPool(c.node.ID()) |
278 | 301 |
|
| 302 | + case id := <-d.addPendingCh: |
| 303 | + d.pendingInbound[id] = struct{}{} |
| 304 | + d.log.Trace("Marked node as pending inbound", "id", id) |
| 305 | + |
| 306 | + case id := <-d.remPendingCh: |
| 307 | + delete(d.pendingInbound, id) |
| 308 | + d.updateStaticPool(id) |
| 309 | + d.log.Trace("Unmarked node as pending inbound", "id", id) |
| 310 | + |
279 | 311 | case node := <-d.addStaticCh: |
280 | 312 | id := node.ID() |
281 | 313 | _, exists := d.static[id] |
@@ -390,6 +422,9 @@ func (d *dialScheduler) checkDial(n *enode.Node) error { |
390 | 422 | if _, ok := d.peers[n.ID()]; ok { |
391 | 423 | return errAlreadyConnected |
392 | 424 | } |
| 425 | + if _, ok := d.pendingInbound[n.ID()]; ok { |
| 426 | + return errPendingInbound |
| 427 | + } |
393 | 428 | if d.netRestrict != nil && !d.netRestrict.ContainsAddr(n.IPAddr()) { |
394 | 429 | return errNetRestrict |
395 | 430 | } |
|
0 commit comments