Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion gossipsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,10 @@ func WithDirectPeers(pis []peer.AddrInfo) Option {
gs.direct = direct

if gs.tagTracer != nil {
gs.tagTracer.direct = direct
gs.tagTracer.isDirect = func(p peer.ID) bool {
_, ok := gs.direct[p]
return ok
}
}

return nil
Expand Down Expand Up @@ -827,6 +830,20 @@ func (gs *GossipSubRouter) EnoughPeers(topic string, suggested int) bool {
return false
}

func (gs *GossipSubRouter) AddDirectPeer(pi peer.AddrInfo) {
if gs.direct == nil {
gs.direct = make(map[peer.ID]struct{})
}
gs.direct[pi.ID] = struct{}{}
gs.p.host.Peerstore().AddAddrs(pi.ID, pi.Addrs, peerstore.PermanentAddrTTL)
gs.tagTracer.protectDirect(pi.ID)
}

func (gs *GossipSubRouter) RemoveDirectPeer(p peer.ID) {
delete(gs.direct, p)
gs.tagTracer.unprotectDirect(p)
}

func (gs *GossipSubRouter) AcceptFrom(p peer.ID) AcceptStatus {
_, direct := gs.direct[p]
if direct {
Expand Down
44 changes: 44 additions & 0 deletions gossipsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1340,6 +1340,50 @@ func TestGossipsubDirectPeers(t *testing.T) {
}
}

func TestGossipsubDynamicDirectPeers(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

h := getDefaultHosts(t, 3)
psubs := []*PubSub{
getGossipsub(ctx, h[0], WithDirectConnectTicks(2)),
getGossipsub(ctx, h[1], WithDirectConnectTicks(2)),
getGossipsub(ctx, h[2], WithDirectConnectTicks(2)),
}

listDirectPeers := func(psb *PubSub) int {
directPeers := 0
gspRt, _ := psb.rt.(*GossipSubRouter)
fn := func() {
directPeers = len(gspRt.direct)
}
psb.syncEval(fn)
return directPeers
}

// test dinamic addition of direct-peers to h[1] and h[2]
psubs[1].AddDirectPeer(peer.AddrInfo{ID: h[2].ID(), Addrs: h[2].Addrs()})
psubs[2].AddDirectPeer(peer.AddrInfo{ID: h[1].ID(), Addrs: h[1].Addrs()})

// give enough time to the state machine to process the direct additions
time.Sleep(time.Second)

if listDirectPeers(psubs[1]) < 1 || listDirectPeers(psubs[2]) < 1 {
t.Fatal("expected 1 direct peer at both gsp rts")
}

// remove peer from direct from directPeers
psubs[1].RemoveDirectPeer(h[2].ID())
psubs[2].RemoveDirectPeer(h[1].ID())

// give enough time to the state machine to process the direct additions
time.Sleep(time.Second)

if listDirectPeers(psubs[1]) > 0 || listDirectPeers(psubs[2]) > 0 {
t.Fatal("expected no direct peers both gsp rts")
}
}

func TestGossipSubPeerFilter(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down
41 changes: 41 additions & 0 deletions pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -1909,3 +1909,44 @@ type addRelayReq struct {
topic string
resp chan RelayCancelFunc
}

func (p *PubSub) syncEval(f func()) error {
done := make(chan struct{})
syncFn := func() {
defer close(done)
f()
}
select {
case p.eval <- syncFn:
select {
case <-done:
case <-p.ctx.Done():
return p.ctx.Err()
}
case <-p.ctx.Done():
return p.ctx.Err()
}
return nil
}

// AddDirectPeer tags the peer as a direct peer at the internal router
func (p *PubSub) AddDirectPeer(pInfo peer.AddrInfo) error {
gs, ok := p.rt.(*GossipSubRouter)
if !ok {
return errors.New("add direct peer only supported by gossipsub")
}
return p.syncEval(func() {
gs.AddDirectPeer(pInfo)
})
}

// RemoveDirectPeer un-tags the peer from being direct peer at the internal router
func (p *PubSub) RemoveDirectPeer(pid peer.ID) error {
gs, ok := p.rt.(*GossipSubRouter)
if !ok {
return errors.New("remove direct peer only supported by gossipsub")
}
return p.syncEval(func() {
gs.RemoveDirectPeer(pid)
})
}
31 changes: 16 additions & 15 deletions tag_tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type tagTracer struct {
idGen *msgIDGenerator
decayer connmgr.Decayer
decaying map[string]connmgr.DecayingTag
direct map[peer.ID]struct{}
isDirect func(p peer.ID) bool

// a map of message ids to the set of peers who delivered the message after the first delivery,
// but before the message was finished validating
Expand All @@ -71,7 +71,7 @@ func newTagTracer(cmgr connmgr.ConnManager) *tagTracer {
decayer: decayer,
decaying: make(map[string]connmgr.DecayingTag),
nearFirst: make(map[string]map[peer.ID]struct{}),
direct: make(map[peer.ID]struct{}),
isDirect: func(p peer.ID) bool { return false },
logger: logger, // Overridden in Start
}
}
Expand All @@ -83,18 +83,9 @@ func (t *tagTracer) Start(gs *GossipSubRouter, logger *slog.Logger) {
t.logger = logger

t.idGen = gs.p.idGen
t.direct = gs.direct
}

func (t *tagTracer) tagPeerIfDirect(p peer.ID) {
if t.direct == nil {
return
}

// tag peer if it is a direct peer
_, direct := t.direct[p]
if direct {
t.cmgr.Protect(p, "pubsub:<direct>")
t.isDirect = func(p peer.ID) bool {
_, ok := gs.direct[p]
return ok
}
}

Expand Down Expand Up @@ -181,11 +172,21 @@ func (t *tagTracer) nearFirstPeers(msg *Message) []peer.ID {
return peers
}

func (t *tagTracer) protectDirect(p peer.ID) {
t.cmgr.Protect(p, "pubsub:<direct>")
}

func (t *tagTracer) unprotectDirect(p peer.ID) {
t.cmgr.Unprotect(p, "pubsub:<direct>")
}

// -- RawTracer interface methods
var _ RawTracer = (*tagTracer)(nil)

func (t *tagTracer) AddPeer(p peer.ID, proto protocol.ID) {
t.tagPeerIfDirect(p)
if t.isDirect(p) {
t.protectDirect(p)
}
}

func (t *tagTracer) Join(topic string) {
Expand Down
9 changes: 6 additions & 3 deletions tag_tracer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ func TestTagTracerDirectPeerTags(t *testing.T) {
p2 := peer.ID("2")
p3 := peer.ID("3")

// in the real world, tagTracer.direct is set in the WithDirectPeers option function
tt.direct = make(map[peer.ID]struct{})
tt.direct[p1] = struct{}{}
tt.protectDirect(p1)

tt.AddPeer(p1, GossipSubID_v10)
tt.AddPeer(p2, GossipSubID_v10)
Expand All @@ -69,6 +67,11 @@ func TestTagTracerDirectPeerTags(t *testing.T) {
t.Fatal("expected non-direct peer to be unprotected")
}
}

tt.unprotectDirect(p1)
if cmgr.IsProtected(p1, tag) {
t.Fatal("expected direct peer to not be protected")
}
}

func TestTagTracerDeliveryTags(t *testing.T) {
Expand Down