Skip to content

Commit dbb3d2c

Browse files
authored
Merge pull request #462 from libp2p/fix/observe-context-in-message-sender
fix: obey the context when sending messages to peers
2 parents a92f79b + 0b02938 commit dbb3d2c

File tree

4 files changed

+88
-6
lines changed

4 files changed

+88
-6
lines changed

ctx_mutex.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package dht
2+
3+
import (
4+
"context"
5+
)
6+
7+
type ctxMutex chan struct{}
8+
9+
func newCtxMutex() ctxMutex {
10+
return make(ctxMutex, 1)
11+
}
12+
13+
func (m ctxMutex) Lock(ctx context.Context) error {
14+
select {
15+
case m <- struct{}{}:
16+
return nil
17+
case <-ctx.Done():
18+
return ctx.Err()
19+
}
20+
}
21+
22+
func (m ctxMutex) Unlock() {
23+
select {
24+
case <-m:
25+
default:
26+
panic("not locked")
27+
}
28+
}

dht_net.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
246246
dht.smlk.Unlock()
247247
return ms, nil
248248
}
249-
ms = &messageSender{p: p, dht: dht}
249+
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
250250
dht.strmap[p] = ms
251251
dht.smlk.Unlock()
252252

@@ -274,7 +274,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
274274
type messageSender struct {
275275
s network.Stream
276276
r msgio.ReadCloser
277-
lk sync.Mutex
277+
lk ctxMutex
278278
p peer.ID
279279
dht *IpfsDHT
280280

@@ -294,8 +294,11 @@ func (ms *messageSender) invalidate() {
294294
}
295295

296296
func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
297-
ms.lk.Lock()
297+
if err := ms.lk.Lock(ctx); err != nil {
298+
return err
299+
}
298300
defer ms.lk.Unlock()
301+
299302
if err := ms.prep(ctx); err != nil {
300303
ms.invalidate()
301304
return err
@@ -328,8 +331,11 @@ func (ms *messageSender) prep(ctx context.Context) error {
328331
const streamReuseTries = 3
329332

330333
func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
331-
ms.lk.Lock()
334+
if err := ms.lk.Lock(ctx); err != nil {
335+
return err
336+
}
332337
defer ms.lk.Unlock()
338+
333339
retry := false
334340
for {
335341
if err := ms.prep(ctx); err != nil {
@@ -363,8 +369,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro
363369
}
364370

365371
func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
366-
ms.lk.Lock()
372+
if err := ms.lk.Lock(ctx); err != nil {
373+
return nil, err
374+
}
367375
defer ms.lk.Unlock()
376+
368377
retry := false
369378
for {
370379
if err := ms.prep(ctx); err != nil {

ext_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,49 @@ import (
1818
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
1919
)
2020

21+
func TestHang(t *testing.T) {
22+
ctx := context.Background()
23+
mn, err := mocknet.FullMeshConnected(ctx, 2)
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
hosts := mn.Hosts()
28+
29+
os := []opts.Option{opts.DisableAutoRefresh()}
30+
d, err := New(ctx, hosts[0], os...)
31+
if err != nil {
32+
t.Fatal(err)
33+
}
34+
// Hang on every request.
35+
hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) {
36+
defer s.Reset()
37+
<-ctx.Done()
38+
})
39+
d.Update(ctx, hosts[1].ID())
40+
41+
ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second)
42+
defer cancel1()
43+
44+
peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString())
45+
if err != nil {
46+
t.Fatal(err)
47+
}
48+
49+
time.Sleep(100 * time.Millisecond)
50+
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
51+
defer cancel2()
52+
_ = d.Provide(ctx2, testCaseCids[0], true)
53+
if ctx2.Err() != context.DeadlineExceeded {
54+
t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err())
55+
}
56+
select {
57+
case <-peers:
58+
t.Error("GetClosestPeers should not have returned yet")
59+
default:
60+
}
61+
62+
}
63+
2164
func TestGetFailures(t *testing.T) {
2265
if testing.Short() {
2366
t.SkipNow()

notif.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package dht
22

33
import (
4+
"context"
5+
46
"github.com/libp2p/go-libp2p-core/helpers"
57
"github.com/libp2p/go-libp2p-core/network"
68

@@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {
130132

131133
// Do this asynchronously as ms.lk can block for a while.
132134
go func() {
133-
ms.lk.Lock()
135+
ms.lk.Lock(context.Background())
134136
defer ms.lk.Unlock()
135137
ms.invalidate()
136138
}()

0 commit comments

Comments
 (0)