
Commit 2ce4fe8

cancel enqueued rpcs on receiving IDontWant

1 parent: ed53c17

File tree

6 files changed: +233 -17 lines

	comm.go
	gossipsub.go
	gossipsub_spam_test.go
	pubsub.go
	rpc_queue.go
	rpc_queue_test.go


comm.go

Lines changed: 9 additions & 2 deletions

@@ -165,7 +165,7 @@ func (p *PubSub) handlePeerDead(s network.Stream) {
 }
 
 func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) {
-	writeRpc := func(rpc *RPC) error {
+	writeRpc := func(rpc *pb.RPC) error {
 		size := uint64(rpc.Size())
 
 		buf := pool.Get(varint.UvarintSize(size) + int(size))
@@ -193,8 +193,11 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
 			p.logger.Debug("error popping message from the queue to send to peer", "peer", s.Conn().RemotePeer(), "err", err)
 			return
 		}
+		if rpc.Size() == 0 {
+			continue
+		}
 
-		err = writeRpc(rpc)
+		err = writeRpc(&rpc.RPC)
 		if err != nil {
 			s.Reset()
 			p.logger.Debug("error writing message to peer", "peer", s.Conn().RemotePeer(), "err", err)
@@ -215,6 +218,10 @@ func rpcWithMessages(msgs ...*pb.Message) *RPC {
 	return &RPC{RPC: pb.RPC{Publish: msgs}}
 }
 
+func rpcWithMessageAndMsgID(msg *pb.Message, msgID string) *RPC {
+	return &RPC{RPC: pb.RPC{Publish: []*pb.Message{msg}}, MsgIDs: []string{msgID}}
+}
+
 func rpcWithControl(msgs []*pb.Message,
 	ihave []*pb.ControlIHave,
 	iwant []*pb.ControlIWant,
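Two details worth noting here: writeRpc now takes the wire-level *pb.RPC (the wrapper's MsgIDs never leave the process), and the send loop skips RPCs whose contents were entirely cancelled while queued. A small sketch of the constructor's invariant (the payload is illustrative, not from the commit):

	msg := &pb.Message{Data: []byte("hello")}
	rpc := rpcWithMessageAndMsgID(msg, "msg-id-1")
	// MsgIDs stays parallel to Publish (MsgIDs[i] identifies Publish[i]),
	// which is what lets the queue match IDONTWANT ids against queued
	// messages. If every message is later cancelled, rpc.Size() drops to 0
	// and the send loop above skips it instead of writing an empty frame.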

gossipsub.go

Lines changed: 10 additions & 1 deletion

@@ -1163,6 +1163,9 @@ func (gs *GossipSubRouter) handleIDontWant(p peer.ID, ctl *pb.ControlMessage) {
 	gs.peerdontwant[p]++
 
 	totalUnwantedIds := 0
+	// Collect message IDs for cancellation
+	var msgIDsToCancel []string
+
 	// Remember all the unwanted message ids
 mainIDWLoop:
 	for _, idontwant := range ctl.GetIdontwant() {
@@ -1175,8 +1178,14 @@ mainIDWLoop:
 
 			totalUnwantedIds++
 			gs.unwanted[p][computeChecksum(mid)] = gs.params.IDontWantMessageTTL
+			msgIDsToCancel = append(msgIDsToCancel, mid)
 		}
 	}
+
+	// Cancel these messages in the RPC queue if it exists
+	if queue, ok := gs.p.peers[p]; ok && len(msgIDsToCancel) > 0 {
+		queue.CancelMessages(msgIDsToCancel)
+	}
 }
 
 func (gs *GossipSubRouter) addBackoff(p peer.ID, topic string, isUnsubscribe bool) {
@@ -1370,7 +1379,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
 		}
 	}
 
-	out := rpcWithMessages(msg.Message)
+	out := rpcWithMessageAndMsgID(msg.Message, gs.p.idGen.ID(msg))
 	for pid := range tosend {
 		if pid == from || pid == peer.ID(msg.GetFrom()) {
 			continue
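Taken together, the two gossipsub.go changes close the loop: rpcs() now tags each outgoing publish RPC with its message id, and handleIDontWant uses that id to cancel queued copies. A hedged sketch of the resulting path (assuming, as the lookup above suggests, that gs.p.peers maps a peer to its outgoing rpcQueue):

	// 1. gs.rpcs(msg) builds rpcWithMessageAndMsgID(msg.Message, id) for mesh peers
	// 2. the RPC waits in gs.p.peers[pid], the peer's outgoing rpcQueue
	// 3. the peer sends IDONTWANT{mid}; handleIDontWant records the checksum
	//    and calls gs.p.peers[p].CancelMessages(msgIDsToCancel)
	// 4. the queue strips the cancelled message on Pop, before it hits the wire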

gossipsub_spam_test.go

Lines changed: 16 additions & 11 deletions

@@ -930,20 +930,25 @@ func TestGossipsubHandleIDontwantSpam(t *testing.T) {
 	rPid := hosts[1].ID()
 	ctrlMessage := &pb.ControlMessage{Idontwant: []*pb.ControlIDontWant{{MessageIDs: idwIds}}}
 	grt := psubs[0].rt.(*GossipSubRouter)
-	grt.handleIDontWant(rPid, ctrlMessage)
+	completed := make(chan struct{})
+	psubs[0].eval <- func() {
+		grt.handleIDontWant(rPid, ctrlMessage)
 
-	if grt.peerdontwant[rPid] != 1 {
-		t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid])
-	}
-	mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1)
-	if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok {
-		t.Errorf("Desired message id was not stored in the unwanted map: %s", mid)
-	}
+		if grt.peerdontwant[rPid] != 1 {
+			t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid])
+		}
+		mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1)
+		if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok {
+			t.Errorf("Desired message id was not stored in the unwanted map: %s", mid)
+		}
 
-	mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength)
-	if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok {
-		t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid)
+		mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength)
+		if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok {
+			t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid)
+		}
+		close(completed)
 	}
+	<-completed
 }
 
 type mockGSOnRead func(writeMsg func(*pb.RPC), irpc *pb.RPC)
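Wrapping the call and assertions in psubs[0].eval is what makes this test safe now: handleIDontWant reaches into a peer's rpcQueue via gs.p.peers, so it must run on the pubsub event loop rather than the test goroutine. The general shape of that pattern, as a sketch (psub stands in for any running PubSub):

	done := make(chan struct{})
	psub.eval <- func() {
		// executed by the pubsub event loop; router and peer state are
		// safe to touch here without racing the loop itself
		close(done)
	}
	<-done // park the test until the evaluated function has run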

pubsub.go

Lines changed: 16 additions & 0 deletions

@@ -257,6 +257,8 @@ func (m *Message) GetFrom() peer.ID {
 
 type RPC struct {
 	pb.RPC
+	// MsgIDs are the ids of the messages in the rpc. MsgID[i] = id(rpc.Publish[i])
+	MsgIDs []string
 
 	// unexported on purpose, not sending this over the wire
 	from peer.ID
@@ -274,6 +276,16 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 
 	messagesInNextRPC := 0
 	messageSlice := rpc.Publish
+	var msgIDSlice []string
+	if len(rpc.MsgIDs) != len(rpc.Publish) {
+		if len(rpc.MsgIDs) == 0 {
+			msgIDSlice = make([]string, len(rpc.Publish))
+		} else {
+			panic("MsgIDs and Publish have different lengths")
+		}
+	} else {
+		msgIDSlice = rpc.MsgIDs
+	}
 
 	// Merge/Append publish messages. This pattern is optimized compared the
 	// the patterns for other fields because this is the common cause for
@@ -285,7 +297,10 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 			// The message doesn't fit. Let's set the messages that did fit
 			// into this RPC, yield it, then make a new one
 			nextRPC.Publish = messageSlice[:messagesInNextRPC]
+			nextRPC.MsgIDs = msgIDSlice[:messagesInNextRPC]
 			messageSlice = messageSlice[messagesInNextRPC:]
+			msgIDSlice = msgIDSlice[messagesInNextRPC:]
+
 			if !yield(nextRPC) {
 				return
 			}
@@ -303,6 +318,7 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 	// packing this RPC, but we avoid successively calling .Size()
 	// on the messages for the next parts.
 	nextRPC.Publish = messageSlice[:messagesInNextRPC]
+	nextRPC.MsgIDs = msgIDSlice[:messagesInNextRPC]
 	if !yield(nextRPC) {
 		return
 	}
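The split path keeps the two parallel slices aligned across every yielded part. RPCs built without ids (control-only traffic, or callers that predate MsgIDs) get a zero-value id slice of matching length, while a non-empty length mismatch panics as a programmer error. A sketch of the invariant (m1, m2, m3 and limit are placeholders):

	rpc := &RPC{
		RPC:    pb.RPC{Publish: []*pb.Message{m1, m2, m3}},
		MsgIDs: []string{"id1", "id2", "id3"},
	}
	for part := range rpc.split(limit) {
		// Each yielded part keeps the slices aligned:
		// len(part.Publish) == len(part.MsgIDs)
	}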

rpc_queue.go

Lines changed: 69 additions & 3 deletions

@@ -3,6 +3,7 @@ package pubsub
 import (
 	"context"
 	"errors"
+	"slices"
 	"sync"
 )
 
@@ -50,20 +51,41 @@ type rpcQueue struct {
 	dataAvailable  sync.Cond
 	spaceAvailable sync.Cond
 	// Mutex used to access queue
-	queueMu sync.Mutex
-	queue   priorityQueue
+	queueMu      sync.Mutex
+	queue        priorityQueue
+	queuedMsgIDs map[string]int      // message ids in queue
+	cancelledIDs map[string]struct{} // messages ids that'll be dropped before sending
 
 	closed  bool
 	maxSize int
 }
 
 func newRpcQueue(maxSize int) *rpcQueue {
-	q := &rpcQueue{maxSize: maxSize}
+	q := &rpcQueue{
+		maxSize:      maxSize,
+		queuedMsgIDs: make(map[string]int),
+		cancelledIDs: make(map[string]struct{}),
+	}
 	q.dataAvailable.L = &q.queueMu
 	q.spaceAvailable.L = &q.queueMu
 	return q
 }
 
+// CancelMessages marks the given message IDs for cancellation only if they are already in queue.
+func (q *rpcQueue) CancelMessages(msgIDs []string) {
+	q.queueMu.Lock()
+	defer q.queueMu.Unlock()
+
+	for _, id := range msgIDs {
+		if id != "" {
+			// Only cancel messages that are actually in the queue
+			if count := q.queuedMsgIDs[id]; count > 0 {
+				q.cancelledIDs[id] = struct{}{}
+			}
+		}
+	}
+}
+
 func (q *rpcQueue) Push(rpc *RPC, block bool) error {
 	return q.push(rpc, false, block)
 }
@@ -91,11 +113,17 @@ func (q *rpcQueue) push(rpc *RPC, urgent bool, block bool) error {
 			return ErrQueueFull
 		}
 	}
+
 	if urgent {
 		q.queue.PriorityPush(rpc)
 	} else {
 		q.queue.NormalPush(rpc)
 	}
+	for _, id := range rpc.MsgIDs {
+		if id != "" {
+			q.queuedMsgIDs[id]++
+		}
+	}
 
 	q.dataAvailable.Signal()
 	return nil
@@ -133,10 +161,48 @@ func (q *rpcQueue) Pop(ctx context.Context) (*RPC, error) {
 		}
 	}
 	rpc := q.queue.Pop()
+	rpc = q.handleCancellations(rpc)
 	q.spaceAvailable.Signal()
 	return rpc, nil
 }
 
+func (q *rpcQueue) handleCancellations(rpc *RPC) *RPC {
+	hasCancellations := false
+	for _, msgID := range rpc.MsgIDs {
+		q.queuedMsgIDs[msgID]--
+		if q.queuedMsgIDs[msgID] <= 0 {
+			delete(q.queuedMsgIDs, msgID)
+		}
+		if _, ok := q.cancelledIDs[msgID]; ok {
+			hasCancellations = true
+		}
+	}
+	if hasCancellations {
+		// clone the RPC parts that we'll modify. It may be shared with other queues.
+		newRPC := *rpc
+		newRPC.RPC.Publish = slices.Clone(rpc.RPC.Publish)
+		newRPC.MsgIDs = slices.Clone(rpc.MsgIDs)
+		rpc = &newRPC
+		for i, msgID := range slices.Backward(newRPC.MsgIDs) {
+			if msgID == "" {
+				continue
+			}
+			_, ok := q.cancelledIDs[msgID]
+			if !ok {
+				continue
+			}
+			rpc.RPC.Publish[i] = nil
+			rpc.MsgIDs[i] = ""
+			if q.queuedMsgIDs[msgID] <= 0 {
+				delete(q.cancelledIDs, msgID)
+			}
+			rpc.RPC.Publish = slices.Delete(rpc.RPC.Publish, i, i+1) // this slice is small
+			rpc.MsgIDs = slices.Delete(rpc.MsgIDs, i, i+1)
+		}
+	}
+	return rpc
+}
+
 func (q *rpcQueue) Close() {
 	q.queueMu.Lock()
 	defer q.queueMu.Unlock()
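Taken together, push and pop keep queuedMsgIDs as a reference count of how many queued RPCs carry each id, and handleCancellations strips cancelled messages lazily at Pop time instead of searching the queue when the cancellation arrives. A minimal usage sketch (ids and the queue size are illustrative):

	q := newRpcQueue(32)
	msg := &pb.Message{Data: []byte("hello")}
	q.Push(rpcWithMessageAndMsgID(msg, "mid-1"), false)

	// An IDONTWANT for "mid-1" arrives while the RPC is still queued:
	q.CancelMessages([]string{"mid-1"})

	out, err := q.Pop(context.Background())
	// err == nil; out.Publish and out.MsgIDs are both empty, and the send
	// loop in comm.go will skip this RPC via its Size() == 0 check.
	_, _ = out, err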

rpc_queue_test.go

Lines changed: 113 additions & 0 deletions

@@ -2,8 +2,12 @@ package pubsub
 
 import (
 	"context"
+	"fmt"
+	"slices"
 	"testing"
 	"time"
+
+	pb "github.com/libp2p/go-libp2p-pubsub/pb"
 )
 
 func TestNewRpcQueue(t *testing.T) {
@@ -227,3 +231,112 @@ func TestRpcQueueCancelPop(t *testing.T) {
 		t.Fatalf("rpc queue Pop returns wrong error when it's cancelled")
 	}
 }
+
+func TestRPCQueueCancellations(t *testing.T) {
+	maxSize := 32
+	q := newRpcQueue(maxSize)
+
+	getMesssages := func(n int) ([]*pb.Message, []string) {
+		msgs := make([]*pb.Message, n)
+		msgIDs := make([]string, n)
+		for i := range msgs {
+			msgs[i] = &pb.Message{Data: []byte(fmt.Sprintf("message%d", i+1))}
+			msgIDs[i] = fmt.Sprintf("msg%d", i+1)
+		}
+		return msgs, msgIDs
+	}
+
+	t.Run("cancel all", func(t *testing.T) {
+		msgs, msgIDs := getMesssages(10)
+		rpc := &RPC{
+			RPC:    pb.RPC{Publish: slices.Clone(msgs)},
+			MsgIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc, true)
+		q.CancelMessages(msgIDs)
+		popped, err := q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if len(popped.Publish) != 0 {
+			t.Fatalf("expected popped.Publish to be empty, got %v", popped.Publish)
+		}
+		if len(popped.MsgIDs) != 0 {
+			t.Fatalf("expected popped.MsgIDs to be empty, got %v", popped.MsgIDs)
+		}
+		if len(q.queuedMsgIDs) != 0 {
+			t.Fatalf("expected q.queuedMsgIDs to be empty, got %v", q.queuedMsgIDs)
+		}
+		if len(q.cancelledIDs) != 0 {
+			t.Fatalf("expected q.cancelledIDs to be empty, got %v", q.cancelledIDs)
+		}
+	})
+
+	t.Run("cancel some", func(t *testing.T) {
+		msgs, msgIDs := getMesssages(10)
+		rpc := &RPC{
+			RPC:    pb.RPC{Publish: slices.Clone(msgs)},
+			MsgIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc, true)
+		q.CancelMessages(msgIDs[:3])
+		popped, err := q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if !slices.Equal(msgs[3:], popped.Publish) {
+			t.Fatalf("expected popped.Publish to be %v, got %v", msgs[3:], popped.Publish)
+		}
+		if !slices.Equal(msgIDs[3:], popped.MsgIDs) {
+			t.Fatalf("expected popped.MsgIDs to be %v, got %v", msgIDs[3:], popped.MsgIDs)
+		}
+		if len(q.queuedMsgIDs) != 0 {
+			t.Fatalf("expected q.queuedMsgIDs to be empty, got %v", q.queuedMsgIDs)
+		}
+		if len(q.cancelledIDs) != 0 {
+			t.Fatalf("expected q.cancelledIDs to be empty, got %v", q.cancelledIDs)
+		}
+	})
+
+	t.Run("cancel duplicate", func(t *testing.T) {
+		msgs, msgIDs := getMesssages(10)
+		rpc := &RPC{
+			RPC:    pb.RPC{Publish: slices.Clone(msgs)},
+			MsgIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc, true)
+		rpc2 := &RPC{
+			RPC:    pb.RPC{Publish: slices.Clone(msgs)},
+			MsgIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc2, true)
+		q.CancelMessages(msgIDs[:3])
+		popped, err := q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if !slices.Equal(msgs[3:], popped.Publish) {
+			t.Fatalf("expected popped.Publish to be %v, got %v", msgs[3:], popped.Publish)
+		}
+		if !slices.Equal(msgIDs[3:], popped.MsgIDs) {
+			t.Fatalf("expected popped.MsgIDs to be %v, got %v", msgIDs[3:], popped.MsgIDs)
+		}
+
+		popped, err = q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if !slices.Equal(msgs[3:], popped.Publish) {
+			t.Fatalf("expected popped.Publish to be %v, got %v", msgs[3:], popped.Publish)
+		}
+		if !slices.Equal(msgIDs[3:], popped.MsgIDs) {
+			t.Fatalf("expected popped.MsgIDs to be %v, got %v", msgIDs[3:], popped.MsgIDs)
+		}
+		if len(q.queuedMsgIDs) != 0 {
+			t.Fatalf("expected q.queuedMsgIDs to be empty, got %v", q.queuedMsgIDs)
+		}
+		if len(q.cancelledIDs) != 0 {
+			t.Fatalf("expected q.cancelledIDs to be empty, got %v", q.cancelledIDs)
+		}
+	})
+}
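The "cancel duplicate" case pins down the reference-count semantics: a cancelled id stays in cancelledIDs until its count in queuedMsgIDs drains to zero, so every queued copy of the message is stripped, not just the first. In sketch form (the payload is a placeholder):

	msg := &pb.Message{Data: []byte("x")}
	q := newRpcQueue(32)
	q.Push(rpcWithMessageAndMsgID(msg, "dup"), false)
	q.Push(rpcWithMessageAndMsgID(msg, "dup"), false) // queuedMsgIDs["dup"] == 2
	q.CancelMessages([]string{"dup"})
	first, _ := q.Pop(context.Background())  // stripped; "dup" still cancelled
	second, _ := q.Pop(context.Background()) // stripped; both maps now empty
	_, _ = first, second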
