Skip to content

Commit c88c554

Browse files
committed
fix(p2p/session): return err if peer tracker is empty
1 parent b80ef73 commit c88c554

File tree

6 files changed

+84
-59
lines changed

6 files changed

+84
-59
lines changed

p2p/exchange.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ func (ex *Exchange[H]) Head(ctx context.Context, opts ...header.HeadOption[H]) (
150150
// their Head and verify against the given trusted header.
151151
useTrackedPeers := !reqParams.TrustedHead.IsZero()
152152
if useTrackedPeers {
153-
trackedPeers := ex.peerTracker.getPeers(maxUntrustedHeadRequests)
153+
trackedPeers := ex.peerTracker.peers(maxUntrustedHeadRequests)
154154
if len(trackedPeers) > 0 {
155-
peers = trackedPeers
155+
peers = transform(trackedPeers, func(p *peerStat) peer.ID {
156+
return p.peerID
157+
})
156158
log.Debugw("requesting head from tracked peers", "amount", len(peers))
157159
}
158160
}
@@ -292,9 +294,13 @@ func (ex *Exchange[H]) GetRangeByHeight(
292294
attribute.Int64("to", int64(to)),
293295
))
294296
defer span.End()
295-
session := newSession[H](
297+
session, err := newSession[H](
296298
ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from),
297299
)
300+
// TODO(@vgonkivs): decide what to do with this error. Maybe we should fall into "discovery mode" and try to collect peers???
301+
if err != nil {
302+
return nil, err
303+
}
298304
defer session.close()
299305
// we request the next header height that we don't have: `fromHead`+1
300306
amount := to - (from.Height() + 1)

p2p/helpers.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,13 @@ func convertStatusCodeToError(code p2p_pb.StatusCode) error {
124124
return fmt.Errorf("unknown status code %d", code)
125125
}
126126
}
127+
128+
// transform applies a provided function to each element of the input slice,
129+
// producing a new slice with the results of the function.
130+
func transform[T, U any](ts []T, f func(T) U) []U {
131+
us := make([]U, len(ts))
132+
for i := range ts {
133+
us[i] = f(ts[i])
134+
}
135+
return us
136+
}

p2p/peer_tracker.go

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@ import (
1616
)
1717

1818
type peerTracker struct {
19-
host host.Host
20-
connGater *conngater.BasicConnectionGater
21-
metrics *exchangeMetrics
2219
protocolID protocol.ID
23-
peerLk sync.RWMutex
20+
21+
host host.Host
22+
connGater *conngater.BasicConnectionGater
23+
24+
peerLk sync.RWMutex
2425
// trackedPeers contains active peers that we can request to.
25-
// we cache the peer once they disconnect,
2626
// so we can guarantee that peerQueue will only contain active peers
2727
trackedPeers map[libpeer.ID]struct{}
2828

2929
// an optional interface used to periodically dump
3030
// good peers during garbage collection
3131
pidstore PeerIDStore
3232

33+
metrics *exchangeMetrics
34+
3335
ctx context.Context
3436
cancel context.CancelFunc
3537
// done is used to gracefully stop the peerTracker.
@@ -103,19 +105,20 @@ func (p *peerTracker) track() {
103105
p.done <- struct{}{}
104106
}()
105107

106-
connSubs, err := p.host.EventBus().Subscribe(&event.EvtPeerConnectednessChanged{})
108+
evtBus := p.host.EventBus()
109+
connSubs, err := evtBus.Subscribe(&event.EvtPeerConnectednessChanged{})
107110
if err != nil {
108111
log.Errorw("subscribing to EvtPeerConnectednessChanged", "err", err)
109112
return
110113
}
111114

112-
identifySub, err := p.host.EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{})
115+
identifySub, err := evtBus.Subscribe(&event.EvtPeerIdentificationCompleted{})
113116
if err != nil {
114117
log.Errorw("subscribing to EvtPeerIdentificationCompleted", "err", err)
115118
return
116119
}
117120

118-
protocolSub, err := p.host.EventBus().Subscribe(&event.EvtPeerProtocolsUpdated{})
121+
protocolSub, err := evtBus.Subscribe(&event.EvtPeerProtocolsUpdated{})
119122
if err != nil {
120123
log.Errorw("subscribing to EvtPeerProtocolsUpdated", "err", err)
121124
return
@@ -124,9 +127,7 @@ func (p *peerTracker) track() {
124127
for {
125128
select {
126129
case <-p.ctx.Done():
127-
err = connSubs.Close()
128-
errors.Join(err, identifySub.Close(), protocolSub.Close())
129-
if err != nil {
130+
if err := closeSubscriptions(connSubs, identifySub, protocolSub); err != nil {
130131
log.Errorw("closing subscriptions", "err", err)
131132
}
132133
return
@@ -135,35 +136,23 @@ func (p *peerTracker) track() {
135136
if network.NotConnected == ev.Connectedness {
136137
p.disconnected(ev.Peer)
137138
}
138-
case subscription := <-identifySub.Out():
139-
ev := subscription.(event.EvtPeerIdentificationCompleted)
140-
p.connected(ev.Peer)
141-
case subscription := <-protocolSub.Out():
142-
ev := subscription.(event.EvtPeerProtocolsUpdated)
139+
case identSubscription := <-identifySub.Out():
140+
ev := identSubscription.(event.EvtPeerIdentificationCompleted)
141+
if slices.Contains(ev.Protocols, p.protocolID) {
142+
p.connected(ev.Peer)
143+
}
144+
case protocolSubscription := <-protocolSub.Out():
145+
ev := protocolSubscription.(event.EvtPeerProtocolsUpdated)
143146
if slices.Contains(ev.Removed, p.protocolID) {
144147
p.disconnected(ev.Peer)
145-
break
146148
}
147-
p.connected(ev.Peer)
149+
if slices.Contains(ev.Added, p.protocolID) {
150+
p.connected(ev.Peer)
151+
}
148152
}
149153
}
150154
}
151155

152-
// getPeers returns the tracker's currently tracked peers up to the `max`.
153-
func (p *peerTracker) getPeers(max int) []libpeer.ID {
154-
p.peerLk.RLock()
155-
defer p.peerLk.RUnlock()
156-
157-
peers := make([]libpeer.ID, 0, max)
158-
for peer := range p.trackedPeers {
159-
peers = append(peers, peer)
160-
if len(peers) == max {
161-
break
162-
}
163-
}
164-
return peers
165-
}
166-
167156
func (p *peerTracker) connected(pID libpeer.ID) {
168157
if err := pID.Validate(); err != nil {
169158
return
@@ -173,15 +162,6 @@ func (p *peerTracker) connected(pID libpeer.ID) {
173162
return
174163
}
175164

176-
// check that peer supports our protocol id.
177-
protocol, err := p.host.Peerstore().SupportsProtocols(pID, p.protocolID)
178-
if err != nil {
179-
return
180-
}
181-
if !slices.Contains(protocol, p.protocolID) {
182-
return
183-
}
184-
185165
for _, c := range p.host.Network().ConnsToPeer(pID) {
186166
// check if connection is short-termed and skip this peer
187167
if c.Stat().Limited {
@@ -219,17 +199,21 @@ func (p *peerTracker) disconnected(pID libpeer.ID) {
219199
p.metrics.peersDisconnected(1)
220200
}
221201

222-
func (p *peerTracker) peers() []*peerStat {
202+
// peers returns the tracker's currently tracked peers up to the `max`.
203+
func (p *peerTracker) peers(max int) []*peerStat {
223204
p.peerLk.RLock()
224205
defer p.peerLk.RUnlock()
225206

226-
peers := make([]*peerStat, 0)
207+
peers := make([]*peerStat, 0, max)
227208
for peerID := range p.trackedPeers {
228209
score := 0
229210
if info := p.host.ConnManager().GetTagInfo(peerID); info != nil {
230211
score = info.Tags[string(p.protocolID)]
231212
}
232213
peers = append(peers, &peerStat{peerID: peerID, peerScore: score})
214+
if len(peers) == max {
215+
break
216+
}
233217
}
234218
return peers
235219
}
@@ -300,3 +284,11 @@ func (p *peerTracker) updateScore(stats *peerStat, size uint64, duration time.Du
300284
score := stats.updateStats(size, duration)
301285
p.host.ConnManager().TagPeer(stats.peerID, string(p.protocolID), score)
302286
}
287+
288+
func closeSubscriptions(subs ...event.Subscription) error {
289+
var err error
290+
for _, sub := range subs {
291+
err = errors.Join(err, sub.Close())
292+
}
293+
return err
294+
}

p2p/peer_tracker_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func TestPeerTracker_Bootstrap(t *testing.T) {
6262
require.NoError(t, err)
6363

6464
assert.Eventually(t, func() bool {
65-
return len(tracker.getPeers(7)) > 0
65+
return len(tracker.peers(7)) > 0
6666
}, time.Millisecond*500, time.Millisecond*100)
6767
}
6868

p2p/session.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,28 @@ func newSession[H header.Header[H]](
6060
requestTimeout time.Duration,
6161
metrics *exchangeMetrics,
6262
options ...option[H],
63-
) *session[H] {
63+
) (*session[H], error) {
6464
ctx, cancel := context.WithCancel(ctx)
6565
ses := &session[H]{
6666
ctx: ctx,
6767
cancel: cancel,
6868
protocolID: protocolID,
6969
host: h,
70-
queue: newPeerQueue(ctx, peerTracker.peers()),
7170
peerTracker: peerTracker,
7271
requestTimeout: requestTimeout,
7372
metrics: metrics,
7473
}
7574

75+
peers := peerTracker.peers(len(peerTracker.trackedPeers))
76+
if len(peers) == 0 {
77+
return nil, errors.New("empty peer tracker")
78+
}
79+
ses.queue = newPeerQueue(ctx, peers)
80+
7681
for _, opt := range options {
7782
opt(ses)
7883
}
79-
return ses
84+
return ses, nil
8085
}
8186

8287
// getRangeByHeight requests headers from different peers.

p2p/session_test.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"time"
77

88
"github.com/libp2p/go-libp2p/core/peer"
9+
blankhost "github.com/libp2p/go-libp2p/p2p/host/blank"
10+
swarm "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
911
"github.com/stretchr/testify/assert"
1012
"github.com/stretchr/testify/require"
1113

@@ -25,34 +27,44 @@ func Test_PrepareRequests(t *testing.T) {
2527
func Test_Validate(t *testing.T) {
2628
suite := headertest.NewTestSuite(t)
2729
head := suite.Head()
28-
ses := newSession(
30+
peerId := peer.ID("test")
31+
pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})}
32+
pT.trackedPeers[peerId] = struct{}{}
33+
pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t))
34+
ses, err := newSession(
2935
context.Background(),
3036
nil,
31-
&peerTracker{trackedPeers: make(map[peer.ID]struct{})},
37+
pT,
3238
"", time.Second, nil,
3339
withValidation(head),
3440
)
3541

42+
require.NoError(t, err)
3643
headers := suite.GenDummyHeaders(5)
37-
err := ses.verify(headers)
44+
err = ses.verify(headers)
3845
assert.NoError(t, err)
3946
}
4047

4148
// Test_ValidateFails ensures that non-adjacent range will return an error.
4249
func Test_ValidateFails(t *testing.T) {
4350
suite := headertest.NewTestSuite(t)
4451
head := suite.Head()
45-
ses := newSession(
52+
53+
peerId := peer.ID("test")
54+
pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})}
55+
pT.trackedPeers[peerId] = struct{}{}
56+
pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t))
57+
ses, err := newSession(
4658
context.Background(),
47-
nil,
48-
&peerTracker{trackedPeers: make(map[peer.ID]struct{})},
59+
blankhost.NewBlankHost(swarm.GenSwarm(t)),
60+
pT,
4961
"", time.Second, nil,
5062
withValidation(head),
5163
)
52-
64+
require.NoError(t, err)
5365
headers := suite.GenDummyHeaders(5)
5466
// break adjacency
5567
headers[2] = headers[4]
56-
err := ses.verify(headers)
68+
err = ses.verify(headers)
5769
assert.Error(t, err)
5870
}

0 commit comments

Comments
 (0)