Skip to content

Commit f1b5ffb

Browse files
authored
Add custom WaitGroup
1 parent 6cd60d1 commit f1b5ffb

File tree

4 files changed

+106
-0
lines changed

4 files changed

+106
-0
lines changed

internal/sync/waitgroup.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package sync
2+
3+
type WaitGroup interface {
4+
Add(delta int)
5+
Done()
6+
Wait(ctx Context)
7+
}
8+
9+
type waitGroup struct {
10+
n int
11+
f Future
12+
waiting bool
13+
}
14+
15+
func NewWaitGroup() WaitGroup {
16+
return &waitGroup{
17+
f: NewFuture(),
18+
}
19+
}
20+
21+
func (wg *waitGroup) Wait(ctx Context) {
22+
wg.waiting = true
23+
24+
if err := wg.f.Get(ctx, nil); err != nil {
25+
panic(err)
26+
}
27+
}
28+
29+
func (wg *waitGroup) Add(delta int) {
30+
wg.n += delta
31+
32+
if wg.n < 0 {
33+
panic("negative WaitGroup counter")
34+
}
35+
36+
if wg.waiting && delta > 0 && wg.n == delta {
37+
panic("WaitGroup misuse: Add called concurrently with Wait")
38+
}
39+
40+
if wg.n > 0 || !wg.waiting {
41+
return
42+
}
43+
44+
if wg.n == 0 {
45+
wg.f.Set(nil, nil)
46+
}
47+
}
48+
49+
func (wg *waitGroup) Done() {
50+
wg.Add(-1)
51+
}

internal/sync/waitgroup_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package sync
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func Test_WaitGroup_PanicsForInvalidCounters(t *testing.T) {
10+
wg := NewWaitGroup()
11+
12+
require.PanicsWithValue(t, "negative WaitGroup counter", func() {
13+
wg.Add(-2)
14+
})
15+
}
16+
17+
func Test_WaitGroup_Blocks(t *testing.T) {
18+
s := NewScheduler()
19+
ctx := Background()
20+
21+
wg := NewWaitGroup()
22+
wg.Add(2)
23+
24+
s.NewCoroutine(ctx, func(ctx Context) error {
25+
wg.Wait(ctx)
26+
27+
return nil
28+
})
29+
30+
s.Execute(ctx)
31+
require.Equal(t, 1, s.RunningCoroutines())
32+
33+
s.NewCoroutine(ctx, func(ctx Context) error {
34+
wg.Done()
35+
36+
return nil
37+
})
38+
39+
s.Execute(ctx)
40+
require.Equal(t, 1, s.RunningCoroutines())
41+
42+
s.NewCoroutine(ctx, func(ctx Context) error {
43+
wg.Done()
44+
45+
return nil
46+
})
47+
48+
s.Execute(ctx)
49+
require.Equal(t, 0, s.RunningCoroutines())
50+
}

workflow/sync.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ import (
77
type Future = sync.Future
88
type Channel = sync.Channel
99
type Context = sync.Context
10+
type WaitGroup = sync.WaitGroup
1011

1112
var Canceled = sync.Canceled

workflow/workflow.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,7 @@ func NewChannel() Channel {
9090
func NewBufferedChannel(size int) Channel {
9191
return sync.NewBufferedChannel(size)
9292
}
93+
94+
func NewWaitGroup() WaitGroup {
95+
return sync.NewWaitGroup()
96+
}

0 commit comments

Comments
 (0)