Skip to content

Commit c5112c3

Browse files
committed
fix(p2p/session): return err if peer tracker is empty
1 parent b80ef73 commit c5112c3

File tree

4 files changed

+47
-30
lines changed

4 files changed

+47
-30
lines changed

p2p/exchange.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,13 @@ func (ex *Exchange[H]) GetRangeByHeight(
292292
attribute.Int64("to", int64(to)),
293293
))
294294
defer span.End()
295-
session := newSession[H](
295+
session, err := newSession[H](
296296
ex.ctx, ex.host, ex.peerTracker, ex.protocolID, ex.Params.RequestTimeout, ex.metrics, withValidation(from),
297297
)
298+
// TODO(@vgonkivs): decide what to do with this error. Maybe we should fall into "discovery mode" and try to collect peers???
299+
if err != nil {
300+
return nil, err
301+
}
298302
defer session.close()
299303
// we request the next header height that we don't have: `fromHead`+1
300304
amount := to - (from.Height() + 1)

p2p/peer_tracker.go

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ import (
1616
)
1717

1818
type peerTracker struct {
19-
host host.Host
20-
connGater *conngater.BasicConnectionGater
21-
metrics *exchangeMetrics
2219
protocolID protocol.ID
23-
peerLk sync.RWMutex
20+
21+
host host.Host
22+
connGater *conngater.BasicConnectionGater
23+
24+
peerLk sync.RWMutex
2425
// trackedPeers contains active peers that we can request to.
2526
// we cache the peer once they disconnect,
2627
// so we can guarantee that peerQueue will only contain active peers
@@ -30,6 +31,8 @@ type peerTracker struct {
3031
// good peers during garbage collection
3132
pidstore PeerIDStore
3233

34+
metrics *exchangeMetrics
35+
3336
ctx context.Context
3437
cancel context.CancelFunc
3538
// done is used to gracefully stop the peerTracker.
@@ -135,11 +138,13 @@ func (p *peerTracker) track() {
135138
if network.NotConnected == ev.Connectedness {
136139
p.disconnected(ev.Peer)
137140
}
138-
case subscription := <-identifySub.Out():
139-
ev := subscription.(event.EvtPeerIdentificationCompleted)
140-
p.connected(ev.Peer)
141-
case subscription := <-protocolSub.Out():
142-
ev := subscription.(event.EvtPeerProtocolsUpdated)
141+
case identSubscription := <-identifySub.Out():
142+
ev := identSubscription.(event.EvtPeerIdentificationCompleted)
143+
if slices.Contains(ev.Protocols, p.protocolID) {
144+
p.connected(ev.Peer)
145+
}
146+
case protocolSubscription := <-protocolSub.Out():
147+
ev := protocolSubscription.(event.EvtPeerProtocolsUpdated)
143148
if slices.Contains(ev.Removed, p.protocolID) {
144149
p.disconnected(ev.Peer)
145150
break
@@ -173,15 +178,6 @@ func (p *peerTracker) connected(pID libpeer.ID) {
173178
return
174179
}
175180

176-
// check that peer supports our protocol id.
177-
protocol, err := p.host.Peerstore().SupportsProtocols(pID, p.protocolID)
178-
if err != nil {
179-
return
180-
}
181-
if !slices.Contains(protocol, p.protocolID) {
182-
return
183-
}
184-
185181
for _, c := range p.host.Network().ConnsToPeer(pID) {
186182
// check if connection is short-termed and skip this peer
187183
if c.Stat().Limited {

p2p/session.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,23 +60,28 @@ func newSession[H header.Header[H]](
6060
requestTimeout time.Duration,
6161
metrics *exchangeMetrics,
6262
options ...option[H],
63-
) *session[H] {
63+
) (*session[H], error) {
6464
ctx, cancel := context.WithCancel(ctx)
6565
ses := &session[H]{
6666
ctx: ctx,
6767
cancel: cancel,
6868
protocolID: protocolID,
6969
host: h,
70-
queue: newPeerQueue(ctx, peerTracker.peers()),
7170
peerTracker: peerTracker,
7271
requestTimeout: requestTimeout,
7372
metrics: metrics,
7473
}
7574

75+
peers := peerTracker.peers()
76+
if len(peers) == 0 {
77+
return nil, errors.New("empty peer tracker")
78+
}
79+
ses.queue = newPeerQueue(ctx, peers)
80+
7681
for _, opt := range options {
7782
opt(ses)
7883
}
79-
return ses
84+
return ses, nil
8085
}
8186

8287
// getRangeByHeight requests headers from different peers.

p2p/session_test.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"time"
77

88
"github.com/libp2p/go-libp2p/core/peer"
9+
blankhost "github.com/libp2p/go-libp2p/p2p/host/blank"
10+
swarm "github.com/libp2p/go-libp2p/p2p/net/swarm/testing"
911
"github.com/stretchr/testify/assert"
1012
"github.com/stretchr/testify/require"
1113

@@ -25,34 +27,44 @@ func Test_PrepareRequests(t *testing.T) {
2527
func Test_Validate(t *testing.T) {
2628
suite := headertest.NewTestSuite(t)
2729
head := suite.Head()
28-
ses := newSession(
30+
peerId := peer.ID("test")
31+
pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})}
32+
pT.trackedPeers[peerId] = struct{}{}
33+
pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t))
34+
ses, err := newSession(
2935
context.Background(),
3036
nil,
31-
&peerTracker{trackedPeers: make(map[peer.ID]struct{})},
37+
pT,
3238
"", time.Second, nil,
3339
withValidation(head),
3440
)
3541

42+
require.NoError(t, err)
3643
headers := suite.GenDummyHeaders(5)
37-
err := ses.verify(headers)
44+
err = ses.verify(headers)
3845
assert.NoError(t, err)
3946
}
4047

4148
// Test_ValidateFails ensures that non-adjacent range will return an error.
4249
func Test_ValidateFails(t *testing.T) {
4350
suite := headertest.NewTestSuite(t)
4451
head := suite.Head()
45-
ses := newSession(
52+
53+
peerId := peer.ID("test")
54+
pT := &peerTracker{trackedPeers: make(map[peer.ID]struct{})}
55+
pT.trackedPeers[peerId] = struct{}{}
56+
pT.host = blankhost.NewBlankHost(swarm.GenSwarm(t))
57+
ses, err := newSession(
4658
context.Background(),
47-
nil,
48-
&peerTracker{trackedPeers: make(map[peer.ID]struct{})},
59+
blankhost.NewBlankHost(swarm.GenSwarm(t)),
60+
pT,
4961
"", time.Second, nil,
5062
withValidation(head),
5163
)
52-
64+
require.NoError(t, err)
5365
headers := suite.GenDummyHeaders(5)
5466
// break adjacency
5567
headers[2] = headers[4]
56-
err := ses.verify(headers)
68+
err = ses.verify(headers)
5769
assert.Error(t, err)
5870
}

0 commit comments

Comments
 (0)