Skip to content

Commit 3a3eb4f

Browse files
committed
watchset: Fix WatchSet.Wait() to not return empty closed channels
[ upstream commit da3849c ] As the comment on Wait() implied the intended semantics for Wait() was to wait for [settleTime] after the first channel has closed, but instead it returned with empty closed channels when [settleTime] expired. Fix this by first watching for 'ctx.Done()' and when a first closed channel is encountered switch to a context with a [settleTime] timeout to gather more channels. Fixes: f0c8822 ("Use reflect.Select for WatchSet") Signed-off-by: Jussi Maki <jussi.maki@isovalent.com>
1 parent 4accd18 commit 3a3eb4f

File tree

2 files changed

+65
-16
lines changed

2 files changed

+65
-16
lines changed

watchset.go

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,9 @@ func (ws *WatchSet) Merge(other *WatchSet) {
7777
// Wait for channels in the watch set to close or the context is cancelled.
7878
// After the first closed channel is seen Wait will wait [settleTime] for
7979
// more closed channels.
80+
// If [settleTime] is 0 waits until [ctx] cancelled or any channel closes.
8081
// Returns the closed channels and removes them from the set.
8182
func (ws *WatchSet) Wait(ctx context.Context, settleTime time.Duration) ([]<-chan struct{}, error) {
82-
innerCtx, cancel := context.WithTimeout(ctx, settleTime)
83-
defer cancel()
84-
8583
ws.mu.Lock()
8684
defer ws.mu.Unlock()
8785

@@ -94,10 +92,14 @@ func (ws *WatchSet) Wait(ctx context.Context, settleTime time.Duration) ([]<-cha
9492
// Construct []SelectCase slice. Reuse the previous allocation.
9593
ws.cases = slices.Grow(ws.cases, 1+len(ws.chans))
9694
cases := ws.cases[:1+len(ws.chans)]
95+
96+
// Add [ctx.Done()] to stop when [ctx] is cancelled.
9797
cases[0] = reflect.SelectCase{
9898
Dir: reflect.SelectRecv,
99-
Chan: reflect.ValueOf(innerCtx.Done()),
99+
Chan: reflect.ValueOf(ctx.Done()),
100100
}
101+
102+
// Add the cases from the watch set.
101103
casesIndex := 1
102104
for ch := range ws.chans {
103105
cases[casesIndex] = reflect.SelectCase{
@@ -109,21 +111,42 @@ func (ws *WatchSet) Wait(ctx context.Context, settleTime time.Duration) ([]<-cha
109111

110112
var closedChannels []<-chan struct{}
111113

112-
// Collect closed channels until [innerCtx] is cancelled.
113-
for {
114+
// At the end remove the closed channels from the watch set.
115+
defer func() {
116+
for _, ch := range closedChannels {
117+
delete(ws.chans, ch)
118+
}
119+
}()
120+
121+
// Wait for the first channel to close and shift it out from [cases]
122+
chosen, _, _ := reflect.Select(cases)
123+
if chosen == 0 {
124+
return nil, ctx.Err()
125+
}
126+
closedChannels = append(closedChannels, cases[chosen].Chan.Interface().(<-chan struct{}))
127+
cases[chosen] = cases[len(cases)-1]
128+
cases = cases[:len(cases)-1]
129+
130+
// If nothing else than context channel remains or we don't want to wait for further channels
131+
// to close then we're done.
132+
if len(cases) == 1 || settleTime == 0 {
133+
return closedChannels, nil
134+
}
135+
136+
// Swap out the 'ctx.Done()' to a context that times out when [settleTime] expires.
137+
settleCtx, cancel := context.WithTimeout(ctx, settleTime)
138+
defer cancel()
139+
cases[0].Chan = reflect.ValueOf(settleCtx.Done())
140+
141+
for len(cases) > 1 {
114142
chosen, _, _ := reflect.Select(cases)
115-
if chosen == 0 /* == innerCtx.Done() */ {
143+
if chosen == 0 /* settleCtx.Done() */ {
116144
break
117145
}
118146
closedChannels = append(closedChannels, cases[chosen].Chan.Interface().(<-chan struct{}))
119147
cases[chosen] = cases[len(cases)-1]
120148
cases = cases[:len(cases)-1]
121149
}
122150

123-
// Remove the closed channels from the watch set.
124-
for _, ch := range closedChannels {
125-
delete(ws.chans, ch)
126-
}
127-
128151
return closedChannels, ctx.Err()
129152
}

watchset_test.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ func TestWatchSet(t *testing.T) {
2222
// Empty watch set, cancelled context.
2323
ctx, cancel := context.WithCancel(context.Background())
2424
cancel()
25-
ch, err := ws.Wait(ctx, time.Second)
25+
chs, err := ws.Wait(ctx, time.Second)
2626
require.ErrorIs(t, err, context.Canceled)
27-
require.Nil(t, ch)
27+
require.Empty(t, chs)
2828

2929
// Few channels, cancelled context.
3030
ch1 := make(chan struct{})
@@ -33,9 +33,35 @@ func TestWatchSet(t *testing.T) {
3333
ws.Add(ch1, ch2, ch3)
3434
ctx, cancel = context.WithCancel(context.Background())
3535
cancel()
36-
ch, err = ws.Wait(ctx, time.Second)
36+
chs, err = ws.Wait(ctx, time.Second)
3737
require.ErrorIs(t, err, context.Canceled)
38-
require.Nil(t, ch)
38+
require.Empty(t, chs)
39+
40+
// Few channels, timed out context. With tiny 'settleTime' we wait for the context to cancel.
41+
duration := 10 * time.Millisecond
42+
ctx, cancel = context.WithTimeout(context.Background(), duration)
43+
t0 := time.Now()
44+
chs, err = ws.Wait(ctx, time.Nanosecond)
45+
require.ErrorIs(t, err, context.DeadlineExceeded)
46+
require.Empty(t, chs)
47+
require.True(t, time.Since(t0) > duration, "expected to wait until context cancels")
48+
cancel()
49+
50+
// One closed channel. Should wait until 'settleTime' expires.
51+
close(ch1)
52+
t0 = time.Now()
53+
chs, err = ws.Wait(context.Background(), duration)
54+
require.NoError(t, err)
55+
require.ElementsMatch(t, chs, []<-chan struct{}{ch1})
56+
require.True(t, time.Since(t0) > duration, "expected to wait until settle time expires")
57+
58+
// One closed channel, 0 wait time.
59+
ws = NewWatchSet()
60+
ws.Add(ch2)
61+
close(ch2)
62+
chs, err = ws.Wait(context.Background(), 0)
63+
require.NoError(t, err)
64+
require.ElementsMatch(t, chs, []<-chan struct{}{ch2})
3965

4066
// Many channels
4167
for _, numChans := range []int{2, 16, 31, 1024} {

0 commit comments

Comments
 (0)