Skip to content

Commit 3db8cfa

Browse files
committed
Added a resettable signal
Signed-off-by: Jakob Haahr Taankvist <jht@uber.com>
1 parent 7e02811 commit 3db8cfa

File tree

2 files changed

+240
-0
lines changed

2 files changed

+240
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package sync
2+
3+
import (
4+
"context"
5+
"errors"
6+
"sync"
7+
)
8+
9+
// ErrReset is returned by Wait() when Reset() is called while goroutines are waiting
10+
var ErrReset = errors.New("signal was reset")
11+
12+
// resettableSignal is a synchronization primitive that allows waiting for a one-time
13+
// signal with context support, and can be reset to wait for a new signal.
14+
// Similar to sync.WaitGroup but for single completion events with reset capability.
15+
type resettableSignal struct {
16+
mu sync.Mutex
17+
doneCh chan struct{}
18+
resetCh chan struct{}
19+
done bool
20+
}
21+
22+
// ResettableSignal must be created via NewResettableSignal(). Zero value is invalid.
23+
type ResettableSignal = *resettableSignal
24+
25+
// NewResettableSignal creates a new resettable signal in waiting state
26+
func NewResettableSignal() *resettableSignal {
27+
return &resettableSignal{
28+
doneCh: make(chan struct{}),
29+
resetCh: make(chan struct{}),
30+
}
31+
}
32+
33+
// Done signals that the event has completed. Safe to call multiple times (idempotent).
34+
func (s *resettableSignal) Done() {
35+
s.mu.Lock()
36+
defer s.mu.Unlock()
37+
38+
if !s.done {
39+
s.done = true
40+
close(s.doneCh)
41+
}
42+
}
43+
44+
// Wait blocks until either Done() is called, the context is cancelled, or Reset() is called.
45+
// Returns:
46+
// - nil if Done() was called
47+
// - ctx.Err() if context was cancelled
48+
// - ErrReset if Reset() was called while waiting
49+
func (s *resettableSignal) Wait(ctx context.Context) error {
50+
s.mu.Lock()
51+
doneCh := s.doneCh
52+
resetCh := s.resetCh
53+
done := s.done
54+
s.mu.Unlock()
55+
56+
// Fast path: already done
57+
if done {
58+
return nil
59+
}
60+
61+
select {
62+
case <-doneCh:
63+
return nil
64+
case <-resetCh:
65+
return ErrReset
66+
case <-ctx.Done():
67+
return ctx.Err()
68+
}
69+
}
70+
71+
// Reset resets the signal to waiting state. Any goroutines currently blocked in Wait()
72+
// will immediately be unblocked with ErrReset.
73+
func (s *resettableSignal) Reset() {
74+
s.mu.Lock()
75+
defer s.mu.Unlock()
76+
77+
// Close reset channel to unblock any waiters (they'll get ErrReset)
78+
// Only close if not already done (to avoid closing a closed channel)
79+
if !s.done {
80+
close(s.resetCh)
81+
}
82+
83+
// Create new channels and reset done flag
84+
s.doneCh = make(chan struct{})
85+
s.resetCh = make(chan struct{})
86+
s.done = false
87+
}
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package sync
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestResettableSignal_DoneWait(t *testing.T) {
13+
signal := NewResettableSignal()
14+
15+
signal.Done()
16+
17+
ctx := context.Background()
18+
err := signal.Wait(ctx)
19+
assert.NoError(t, err)
20+
}
21+
22+
func TestResettableSignal_WaitDone(t *testing.T) {
23+
signal := NewResettableSignal()
24+
25+
done := make(chan error)
26+
go func() {
27+
done <- signal.Wait(context.Background())
28+
}()
29+
30+
// Give goroutine time to start waiting
31+
time.Sleep(10 * time.Millisecond)
32+
33+
signal.Done()
34+
35+
select {
36+
case err := <-done:
37+
assert.NoError(t, err)
38+
case <-time.After(1 * time.Second):
39+
t.Fatal("Wait did not complete after Done")
40+
}
41+
}
42+
43+
func TestResettableSignal_ContextCancellation(t *testing.T) {
44+
signal := NewResettableSignal()
45+
46+
ctx, cancel := context.WithCancel(context.Background())
47+
48+
done := make(chan error)
49+
go func() {
50+
done <- signal.Wait(ctx)
51+
}()
52+
53+
time.Sleep(10 * time.Millisecond)
54+
55+
cancel()
56+
57+
select {
58+
case err := <-done:
59+
assert.Error(t, err)
60+
assert.Equal(t, context.Canceled, err)
61+
case <-time.After(1 * time.Second):
62+
t.Fatal("Wait did not complete after context cancellation")
63+
}
64+
}
65+
66+
func TestResettableSignal_ResetWhileWaiting(t *testing.T) {
67+
signal := NewResettableSignal()
68+
69+
done := make(chan error)
70+
go func() {
71+
done <- signal.Wait(context.Background())
72+
}()
73+
74+
time.Sleep(10 * time.Millisecond)
75+
76+
signal.Reset()
77+
78+
select {
79+
case err := <-done:
80+
assert.Error(t, err)
81+
assert.True(t, errors.Is(err, ErrReset), "expected ErrReset, got %v", err)
82+
case <-time.After(1 * time.Second):
83+
t.Fatal("Wait did not complete after Reset")
84+
}
85+
}
86+
87+
func TestResettableSignal_MultipleWaitersReset(t *testing.T) {
88+
signal := NewResettableSignal()
89+
90+
const numWaiters = 5
91+
results := make([]chan error, numWaiters)
92+
for i := 0; i < numWaiters; i++ {
93+
results[i] = make(chan error)
94+
go func() {
95+
results[i] <- signal.Wait(context.Background())
96+
}()
97+
}
98+
99+
// Give goroutines time to start waiting
100+
time.Sleep(10 * time.Millisecond)
101+
102+
signal.Reset()
103+
104+
for i, ch := range results {
105+
select {
106+
case err := <-ch:
107+
assert.Error(t, err, "waiter %d should get error", i)
108+
assert.True(t, errors.Is(err, ErrReset), "waiter %d: expected ErrReset, got %v", i, err)
109+
case <-time.After(1 * time.Second):
110+
t.Fatalf("Waiter %d did not complete after Reset", i)
111+
}
112+
}
113+
}
114+
115+
func TestResettableSignal_ResetAfterDone(t *testing.T) {
116+
signal := NewResettableSignal()
117+
118+
// Terminate and reset the signal
119+
signal.Done()
120+
signal.Reset()
121+
122+
// Now signal should be waiting again
123+
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
124+
defer cancel()
125+
126+
err := signal.Wait(ctx)
127+
assert.Error(t, err)
128+
assert.ErrorIs(t, err, context.DeadlineExceeded)
129+
}
130+
131+
func TestResettableSignal_ResetThenDoneThenWait(t *testing.T) {
132+
signal := NewResettableSignal()
133+
134+
// Do a full cycle
135+
signal.Done()
136+
signal.Reset()
137+
signal.Done()
138+
139+
err := signal.Wait(context.Background())
140+
assert.NoError(t, err)
141+
}
142+
143+
func TestResettableSignal_IdempotentDone(t *testing.T) {
144+
signal := NewResettableSignal()
145+
146+
// Call Done multiple times
147+
signal.Done()
148+
signal.Done()
149+
signal.Done()
150+
151+
err := signal.Wait(context.Background())
152+
assert.NoError(t, err)
153+
}

0 commit comments

Comments
 (0)