Skip to content

Commit 0ff4334

Browse files
committed
update with NewBatchFuture api
1 parent cdd3d2e commit 0ff4334

File tree

8 files changed

+390
-263
lines changed

8 files changed

+390
-263
lines changed
24.6 MB
Binary file not shown.

internal/batch/batch_future.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package batch
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
7+
"go.uber.org/multierr"
8+
9+
"go.uber.org/cadence/internal"
10+
)
11+
12+
type BatchFuture interface {
13+
internal.Future
14+
GetFutures() []internal.Future
15+
}
16+
17+
type batchFutureImpl struct {
18+
futures []internal.Future
19+
settables []internal.Settable
20+
factories []func(ctx internal.Context) internal.Future
21+
batchSize int
22+
23+
// state
24+
wg internal.WaitGroup
25+
}
26+
27+
func NewBatchFuture(ctx internal.Context, batchSize int, factories []func(ctx internal.Context) internal.Future) (BatchFuture, error) {
28+
var futures []internal.Future
29+
var settables []internal.Settable
30+
for range factories {
31+
future, settable := internal.NewFuture(ctx)
32+
futures = append(futures, future)
33+
settables = append(settables, settable)
34+
}
35+
36+
batchFuture := &batchFutureImpl{
37+
futures: futures,
38+
settables: settables,
39+
factories: factories,
40+
batchSize: batchSize,
41+
42+
wg: internal.NewWaitGroup(ctx),
43+
}
44+
batchFuture.start(ctx)
45+
return batchFuture, nil
46+
}
47+
48+
func (b *batchFutureImpl) GetFutures() []internal.Future {
49+
return b.futures
50+
}
51+
52+
func (b *batchFutureImpl) start(ctx internal.Context) {
53+
54+
buffered := internal.NewBufferedChannel(ctx, b.batchSize) // buffered channel to limit the number of concurrent futures
55+
channel := internal.NewNamedChannel(ctx, "batch-future-channel")
56+
b.wg.Add(1)
57+
internal.GoNamed(ctx, "batch-future-submitter", func(ctx internal.Context) {
58+
defer b.wg.Done()
59+
60+
for i := range b.factories {
61+
buffered.Send(ctx, nil)
62+
channel.Send(ctx, i)
63+
}
64+
channel.Close()
65+
})
66+
67+
b.wg.Add(1)
68+
internal.GoNamed(ctx, "batch-future-processor", func(ctx internal.Context) {
69+
defer b.wg.Done()
70+
71+
wgForFutures := internal.NewWaitGroup(ctx)
72+
73+
var idx int
74+
for channel.Receive(ctx, &idx) {
75+
idx := idx
76+
77+
wgForFutures.Add(1)
78+
internal.GoNamed(ctx, "batch-future-processor-one-future", func(ctx internal.Context) {
79+
defer wgForFutures.Done()
80+
81+
// fork a future and chain it to the processed future for user to get the result
82+
f := b.factories[idx](ctx)
83+
b.settables[idx].Chain(f)
84+
85+
// error handling is not needed here because the result is chained to the settable
86+
f.Get(ctx, nil)
87+
buffered.Receive(ctx, nil)
88+
})
89+
}
90+
wgForFutures.Wait(ctx)
91+
})
92+
}
93+
94+
func (b *batchFutureImpl) IsReady() bool {
95+
for _, future := range b.futures {
96+
if !future.IsReady() {
97+
return false
98+
}
99+
}
100+
return true
101+
}
102+
103+
func (b *batchFutureImpl) Get(ctx internal.Context, valuePtr interface{}) error {
104+
// ensure valuePtr is a slice
105+
var sliceValue reflect.Value
106+
if valuePtr != nil {
107+
108+
switch v := reflect.ValueOf(valuePtr); v.Kind() {
109+
case reflect.Ptr:
110+
if v.Elem().Kind() != reflect.Slice {
111+
return fmt.Errorf("valuePtr must be a pointer to a slice, got %v", v)
112+
}
113+
sliceValue = v.Elem()
114+
case reflect.Slice:
115+
sliceValue = v
116+
default:
117+
return fmt.Errorf("valuePtr must be a slice or a pointer to a slice, got %v", v.Kind())
118+
}
119+
// ensure slice size is the same as the number of futures
120+
if sliceValue.Len() != len(b.futures) {
121+
return fmt.Errorf("slice size must be the same as the number of futures, got %d, expected %d", sliceValue.Len(), len(b.futures))
122+
}
123+
}
124+
125+
// wait for all futures to be ready
126+
b.wg.Wait(ctx)
127+
128+
// loop through all elements of valuePtr
129+
var errs error
130+
for i := range b.futures {
131+
if valuePtr == nil {
132+
errs = multierr.Append(errs, b.futures[i].Get(ctx, nil))
133+
} else {
134+
value := sliceValue.Index(i)
135+
if value.Kind() != reflect.Ptr {
136+
value = value.Addr()
137+
}
138+
// if value is nil, initialize it
139+
if value.IsNil() {
140+
value.Set(reflect.New(value.Type().Elem()))
141+
}
142+
143+
e := b.futures[i].Get(ctx, value.Interface())
144+
errs = multierr.Append(errs, e)
145+
}
146+
}
147+
148+
return errs
149+
}
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
package batch
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"math/rand"
8+
"reflect"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
14+
"go.uber.org/cadence/internal"
15+
"go.uber.org/cadence/testsuite"
16+
"go.uber.org/multierr"
17+
)
18+
19+
type batchWorkflowInput struct {
20+
Concurrency int
21+
TotalSize int
22+
}
23+
24+
func batchWorkflow(ctx internal.Context, input batchWorkflowInput) ([]int, error) {
25+
factories := make([]func(ctx internal.Context) internal.Future, input.TotalSize)
26+
for i := 0; i < input.TotalSize; i++ {
27+
i := i
28+
factories[i] = func(ctx internal.Context) internal.Future {
29+
aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
30+
ScheduleToStartTimeout: time.Second * 10,
31+
StartToCloseTimeout: time.Second * 10,
32+
})
33+
return internal.ExecuteActivity(aCtx, batchActivity, i)
34+
}
35+
}
36+
37+
batchFuture, err := NewBatchFuture(ctx, input.Concurrency, factories)
38+
if err != nil {
39+
return nil, err
40+
}
41+
42+
result := make([]int, input.TotalSize)
43+
err = batchFuture.Get(ctx, &result)
44+
return result, err
45+
}
46+
47+
func batchWorkflowUsingFutures(ctx internal.Context, input batchWorkflowInput) ([]int, error) {
48+
factories := make([]func(ctx internal.Context) internal.Future, input.TotalSize)
49+
for i := 0; i < input.TotalSize; i++ {
50+
i := i
51+
factories[i] = func(ctx internal.Context) internal.Future {
52+
aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
53+
ScheduleToStartTimeout: time.Second * 10,
54+
StartToCloseTimeout: time.Second * 10,
55+
})
56+
return internal.ExecuteActivity(aCtx, batchActivity, i)
57+
}
58+
}
59+
60+
batchFuture, err := NewBatchFuture(ctx, input.Concurrency, factories)
61+
if err != nil {
62+
return nil, err
63+
}
64+
result := make([]int, input.TotalSize)
65+
66+
for i, f := range batchFuture.GetFutures() {
67+
err = f.Get(ctx, &result[i])
68+
if err != nil {
69+
return nil, err
70+
}
71+
}
72+
73+
return result, err
74+
}
75+
76+
func batchActivity(ctx context.Context, taskID int) (int, error) {
77+
select {
78+
case <-ctx.Done():
79+
return taskID, fmt.Errorf("batch activity %d failed: %w", taskID, ctx.Err())
80+
case <-time.After(time.Duration(rand.Int63n(100))*time.Millisecond + 900*time.Millisecond):
81+
return taskID, nil
82+
}
83+
}
84+
85+
func Test_BatchWorkflow(t *testing.T) {
86+
testSuite := &testsuite.WorkflowTestSuite{}
87+
env := testSuite.NewTestWorkflowEnvironment()
88+
89+
env.RegisterWorkflow(batchWorkflow)
90+
env.RegisterActivity(batchActivity)
91+
92+
totalSize := 5
93+
concurrency := 2
94+
go func() {
95+
env.ExecuteWorkflow(batchWorkflow, batchWorkflowInput{
96+
Concurrency: concurrency,
97+
TotalSize: totalSize,
98+
})
99+
}()
100+
101+
// wait for maximum time it takes to complete the workflow (totalSize/concurrency) + 1 second
102+
assert.Eventually(t, func() bool {
103+
return env.IsWorkflowCompleted()
104+
}, time.Second*time.Duration(1+float64(totalSize)/float64(concurrency)), time.Millisecond*100)
105+
106+
assert.Nil(t, env.GetWorkflowError())
107+
var result []int
108+
assert.Nil(t, env.GetWorkflowResult(&result))
109+
var expected []int
110+
for i := 0; i < totalSize; i++ {
111+
expected = append(expected, i)
112+
}
113+
assert.Equal(t, expected, result)
114+
}
115+
116+
func Test_BatchWorkflow_Cancel(t *testing.T) {
117+
testSuite := &testsuite.WorkflowTestSuite{}
118+
env := testSuite.NewTestWorkflowEnvironment()
119+
env.RegisterWorkflow(batchWorkflow)
120+
env.RegisterActivity(batchActivity)
121+
122+
totalSize := 100
123+
concurrency := 10
124+
go func() {
125+
env.ExecuteWorkflow(batchWorkflow, batchWorkflowInput{
126+
Concurrency: concurrency,
127+
TotalSize: totalSize,
128+
})
129+
}()
130+
131+
time.Sleep(time.Second*2)
132+
env.CancelWorkflow()
133+
134+
assert.Eventually(t, func() bool {
135+
return env.IsWorkflowCompleted()
136+
}, time.Second*time.Duration(1+float64(totalSize)/float64(concurrency)), time.Millisecond*100)
137+
138+
err := env.GetWorkflowError()
139+
errs := multierr.Errors(errors.Unwrap(err))
140+
assert.Less(t, len(errs), totalSize, "expect at least some to succeed")
141+
for _, e := range errs {
142+
assert.Contains(t, e.Error(), "Canceled")
143+
}
144+
}
145+
146+
func Test_BatchWorkflowUsingFutures(t *testing.T) {
147+
testSuite := &testsuite.WorkflowTestSuite{}
148+
env := testSuite.NewTestWorkflowEnvironment()
149+
150+
env.RegisterWorkflow(batchWorkflowUsingFutures)
151+
env.RegisterActivity(batchActivity)
152+
153+
totalSize := 100
154+
concurrency := 20
155+
go func() {
156+
env.ExecuteWorkflow(batchWorkflowUsingFutures, batchWorkflowInput{
157+
Concurrency: concurrency,
158+
TotalSize: totalSize,
159+
})
160+
}()
161+
162+
// wait for maximum time it takes to complete the workflow (totalSize/concurrency) + 1 second
163+
assert.Eventually(t, func() bool {
164+
return env.IsWorkflowCompleted()
165+
}, time.Second*time.Duration(1+float64(totalSize)/float64(concurrency)), time.Millisecond*100)
166+
167+
assert.Nil(t, env.GetWorkflowError())
168+
var result []int
169+
assert.Nil(t, env.GetWorkflowResult(&result))
170+
var expected []int
171+
for i := 0; i < totalSize; i++ {
172+
expected = append(expected, i)
173+
}
174+
assert.Equal(t, expected, result)
175+
}
176+
177+
func futureTest(ctx internal.Context) error {
178+
f, s := internal.NewFuture(ctx)
179+
f2, s2 := internal.NewFuture(ctx)
180+
s2.Chain(f)
181+
182+
wg := internal.NewWaitGroup(ctx)
183+
wg.Add(1)
184+
internal.GoNamed(ctx, "future-test", func(ctx internal.Context) {
185+
defer wg.Done()
186+
internal.Sleep(ctx, time.Second*10)
187+
s.Set(1, nil)
188+
})
189+
190+
err := f2.Get(ctx, nil)
191+
if err != nil {
192+
return err
193+
}
194+
195+
err = f.Get(ctx, nil)
196+
if err != nil {
197+
return err
198+
}
199+
200+
wg.Wait(ctx)
201+
return err
202+
}
203+
204+
func Test_Futures(t *testing.T) {
205+
testSuite := &testsuite.WorkflowTestSuite{}
206+
env := testSuite.NewTestWorkflowEnvironment()
207+
208+
env.RegisterWorkflow(futureTest)
209+
210+
env.ExecuteWorkflow(futureTest)
211+
}
212+
213+
func Test_valuePtr(t *testing.T) {
214+
slices := make([]int, 10)
215+
slicePtr := &slices
216+
217+
fmt.Println(reflect.ValueOf(slicePtr).Elem().Len())
218+
}

0 commit comments

Comments
 (0)