Skip to content

Commit e20a16e

Browse files
authored
Fix race in port-forward (#418)
1 parent 5691835 commit e20a16e

File tree

4 files changed

+92
-5
lines changed

4 files changed

+92
-5
lines changed

internal/concurrentmap/concurrentmap.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,20 @@ func (cmap *ConcurrentMap[T]) Delete(key string) {
3737

3838
delete(cmap.nonConcurrentMap, key)
3939
}
40+
41+
func (cmap *ConcurrentMap[T]) DeleteIf(key string, predicate func(T) bool) bool {
42+
cmap.mtx.Lock()
43+
defer cmap.mtx.Unlock()
44+
45+
value, ok := cmap.nonConcurrentMap[key]
46+
if !ok {
47+
return false
48+
}
49+
if !predicate(value) {
50+
return false
51+
}
52+
53+
delete(cmap.nonConcurrentMap, key)
54+
55+
return true
56+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package concurrentmap
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestDeleteIf(t *testing.T) {
10+
cmap := NewConcurrentMap[int]()
11+
cmap.Store("a", 1)
12+
13+
deleted := cmap.DeleteIf("a", func(value int) bool {
14+
return value == 1
15+
})
16+
require.True(t, deleted)
17+
18+
_, ok := cmap.Load("a")
19+
require.False(t, ok)
20+
}
21+
22+
func TestDeleteIfPredicateFalse(t *testing.T) {
23+
cmap := NewConcurrentMap[int]()
24+
cmap.Store("a", 1)
25+
26+
deleted := cmap.DeleteIf("a", func(value int) bool {
27+
return value == 2
28+
})
29+
require.False(t, deleted)
30+
31+
value, ok := cmap.Load("a")
32+
require.True(t, ok)
33+
require.Equal(t, 1, value)
34+
}

internal/controller/notifier/notifier.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,19 @@ func NewNotifier(logger *zap.SugaredLogger) *Notifier {
3131
func (watcher *Notifier) Register(ctx context.Context, worker string) (chan *rpc.WatchInstruction, func()) {
3232
subCtx, cancel := context.WithCancel(ctx)
3333
workerCh := make(chan *rpc.WatchInstruction)
34-
35-
watcher.logger.Debugf("registering worker %s", worker)
36-
watcher.workers.Store(worker, &WorkerSlot{
34+
slot := &WorkerSlot{
3735
ctx: subCtx,
3836
ch: workerCh,
39-
})
37+
}
38+
39+
watcher.logger.Debugf("registering worker %s", worker)
40+
watcher.workers.Store(worker, slot)
4041

4142
return workerCh, func() {
4243
watcher.logger.Debugf("deleting worker %s", worker)
43-
watcher.workers.Delete(worker)
44+
watcher.workers.DeleteIf(worker, func(current *WorkerSlot) bool {
45+
return current == slot
46+
})
4447
cancel()
4548
}
4649
}

internal/controller/notifier/notifier_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,36 @@ func TestNotifier(t *testing.T) {
4747

4848
wg.Wait()
4949
}
50+
51+
func TestNotifierReRegisterKeepsNewestSlot(t *testing.T) {
52+
ctx := context.Background()
53+
watcher := notifier.NewNotifier(zap.NewNop().Sugar())
54+
55+
const worker = "worker-a"
56+
57+
_, staleCancel := watcher.Register(ctx, worker)
58+
newestCh, newestCancel := watcher.Register(ctx, worker)
59+
defer newestCancel()
60+
61+
// Simulate stale connection cleanup arriving after the worker has already re-registered.
62+
staleCancel()
63+
64+
notifyCtx, notifyCancel := context.WithTimeout(ctx, 300*time.Millisecond)
65+
defer notifyCancel()
66+
67+
notifyErrCh := make(chan error, 1)
68+
go func() {
69+
notifyErrCh <- watcher.Notify(notifyCtx, worker, nil)
70+
}()
71+
72+
select {
73+
case <-newestCh:
74+
case err := <-notifyErrCh:
75+
require.NoError(t, err)
76+
t.Fatal("notify returned before delivering message to newest registration")
77+
case <-time.After(time.Second):
78+
t.Fatal("timed out waiting for notify delivery")
79+
}
80+
81+
require.NoError(t, <-notifyErrCh)
82+
}

0 commit comments

Comments
 (0)