diff --git a/internal/batch/batch_future.go b/internal/batch/batch_future.go
new file mode 100644
index 000000000..657bc3924
--- /dev/null
+++ b/internal/batch/batch_future.go
@@ -0,0 +1,138 @@
+package batch
+
+import (
+	"fmt"
+	"reflect"
+
+	"go.uber.org/multierr"
+
+	"go.uber.org/cadence/internal"
+)
+
+// BatchFuture is an implementation of the public BatchFuture interface.
+type BatchFuture struct {
+	futures   []internal.Future
+	settables []internal.Settable
+	factories []func(ctx internal.Context) internal.Future
+	batchSize int
+
+	// state
+	wg internal.WaitGroup
+}
+
+func NewBatchFuture(ctx internal.Context, batchSize int, factories []func(ctx internal.Context) internal.Future) (*BatchFuture, error) {
+	var futures []internal.Future
+	var settables []internal.Settable
+	for range factories {
+		future, settable := internal.NewFuture(ctx)
+		futures = append(futures, future)
+		settables = append(settables, settable)
+	}
+
+	batchFuture := &BatchFuture{
+		futures:   futures,
+		settables: settables,
+		factories: factories,
+		batchSize: batchSize,
+
+		wg: internal.NewWaitGroup(ctx),
+	}
+	batchFuture.start(ctx)
+	return batchFuture, nil
+}
+
+func (b *BatchFuture) GetFutures() []internal.Future {
+	return b.futures
+}
+
+func (b *BatchFuture) start(ctx internal.Context) {
+	semaphore := internal.NewBufferedChannel(ctx, b.batchSize) // buffered channel used as a semaphore to limit the number of concurrently running futures
+	workChan := internal.NewNamedChannel(ctx, "batch-future-channel")
+	b.wg.Add(1)
+	internal.GoNamed(ctx, "batch-future-submitter", func(ctx internal.Context) {
+		defer b.wg.Done()
+
+		for i := range b.factories {
+			semaphore.Send(ctx, nil)
+			workChan.Send(ctx, i)
+		}
+		workChan.Close()
+	})
+
+	b.wg.Add(1)
+	internal.GoNamed(ctx, "batch-future-processor", func(ctx internal.Context) {
+		defer b.wg.Done()
+
+		wgForFutures := internal.NewWaitGroup(ctx)
+
+		var idx int
+		for workChan.Receive(ctx, &idx) {
+			idx := idx
+
+			wgForFutures.Add(1)
+			internal.GoNamed(ctx, fmt.Sprintf("batch-future-processor-one-future-%d", idx), func(ctx internal.Context) {
+				defer wgForFutures.Done()
+
+				// start the future and chain it to the pre-allocated settable so the caller can get the result
+				f := b.factories[idx](ctx)
+				b.settables[idx].Chain(f)
+
+				// error handling is not needed here because the result is chained to the settable
+				f.Get(ctx, nil)
+				semaphore.Receive(ctx, nil)
+			})
+		}
+		wgForFutures.Wait(ctx)
+	})
+}
+
+func (b *BatchFuture) IsReady() bool {
+	for _, future := range b.futures {
+		if !future.IsReady() {
+			return false
+		}
+	}
+	return true
+}
+
+// Get assigns the results of the futures to valuePtr.
+// NOTE: valuePtr must be a pointer to a slice, or nil.
+// If valuePtr is a pointer to a slice, the slice is resized to the length of the futures, and each element
+// is assigned via the underlying Future.Get and thus behaves the same way.
+// If valuePtr is nil, no assignment is made.
+// If errors occur, values are still set for the successful futures, and the errors of the failed futures are returned.
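+//
+// An illustrative call (hypothetical caller code, shown for clarity):
+//
+//	var results []int
+//	err := batchFuture.Get(ctx, &results) // results[i] matches factories[i]; err merges all failures via multierr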
+func (b *BatchFuture) Get(ctx internal.Context, valuePtr interface{}) error {
+	// no assignment if valuePtr is nil
+	if valuePtr == nil {
+		b.wg.Wait(ctx)
+		var errs error
+		for i := range b.futures {
+			errs = multierr.Append(errs, b.futures[i].Get(ctx, nil))
+		}
+		return errs
+	}
+
+	v := reflect.ValueOf(valuePtr)
+	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice {
+		return fmt.Errorf("valuePtr must be a pointer to a slice, got %v", v.Type())
+	}
+
+	// resize the slice to the length of the futures
+	slice := v.Elem()
+	if slice.Cap() < len(b.futures) {
+		slice.Grow(len(b.futures) - slice.Cap())
+	}
+	slice.SetLen(len(b.futures))
+
+	// wait for all futures to be ready
+	b.wg.Wait(ctx)
+
+	// collect results and errors from all futures
+	var errs error
+	for i := range b.futures {
+		e := b.futures[i].Get(ctx, slice.Index(i).Addr().Interface())
+		errs = multierr.Append(errs, e)
+	}
+
+	return errs
+}
diff --git a/internal/batch/batch_future_test.go b/internal/batch/batch_future_test.go
new file mode 100644
index 000000000..540df3c1f
--- /dev/null
+++ b/internal/batch/batch_future_test.go
@@ -0,0 +1,293 @@
+package batch
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"go.uber.org/multierr"
+
+	"go.uber.org/cadence/internal"
+	"go.uber.org/cadence/testsuite"
+)
+
+// TODO: add clock simulation to speed up the test
+
+type batchWorkflowInput struct {
+	Concurrency int
+	TotalSize   int
+}
+
+func batchWorkflow(ctx internal.Context, input batchWorkflowInput) ([]int, error) {
+	factories := make([]func(ctx internal.Context) internal.Future, input.TotalSize)
+	for i := 0; i < input.TotalSize; i++ {
+		i := i
+		factories[i] = func(ctx internal.Context) internal.Future {
+			aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
+				ScheduleToStartTimeout: time.Second * 10,
+				StartToCloseTimeout:    time.Second * 10,
+			})
+			return internal.ExecuteActivity(aCtx, batchActivity, i)
+		}
+	}
+
+	batchFuture, err := NewBatchFuture(ctx, input.Concurrency, factories)
+	if err != nil {
+		return nil, err
+	}
+
+	result := make([]int, input.TotalSize)
+	err = batchFuture.Get(ctx, &result)
+	return result, err
+}
+
+func batchWorkflowUsingFutures(ctx internal.Context, input batchWorkflowInput) ([]int, error) {
+	factories := make([]func(ctx internal.Context) internal.Future, input.TotalSize)
+	for i := 0; i < input.TotalSize; i++ {
+		i := i
+		factories[i] = func(ctx internal.Context) internal.Future {
+			aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
+				ScheduleToStartTimeout: time.Second * 10,
+				StartToCloseTimeout:    time.Second * 10,
+			})
+			return internal.ExecuteActivity(aCtx, batchActivity, i)
+		}
+	}
+
+	batchFuture, err := NewBatchFuture(ctx, input.Concurrency, factories)
+	if err != nil {
+		return nil, err
+	}
+	result := make([]int, input.TotalSize)
+
+	for i, f := range batchFuture.GetFutures() {
+		err = f.Get(ctx, &result[i])
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return result, err
+}
+
+func batchActivity(ctx context.Context, taskID int) (int, error) {
+	select {
+	case <-ctx.Done():
+		return taskID, fmt.Errorf("batch activity %d failed: %w", taskID, ctx.Err())
+	case <-time.After(time.Duration(rand.Int63n(100))*time.Millisecond + 900*time.Millisecond):
+		return taskID, nil
+	}
+}
+
+func Test_BatchWorkflow(t *testing.T) {
+	testSuite := &testsuite.WorkflowTestSuite{}
+	env := testSuite.NewTestWorkflowEnvironment()
+
+	env.RegisterWorkflow(batchWorkflow)
+	env.RegisterActivity(batchActivity)
+
+	totalSize := 100
+	concurrency := 20
+
+	startTime := time.Now()
+	env.ExecuteWorkflow(batchWorkflow, batchWorkflowInput{
+		Concurrency: concurrency,
+		TotalSize:   totalSize,
+	})
+
+	assert.Less(t, time.Since(startTime), time.Second*time.Duration(float64(totalSize)/float64(concurrency)))
+	assert.True(t, env.IsWorkflowCompleted())
+
+	assert.NoError(t, env.GetWorkflowError())
+	var result []int
+	assert.NoError(t, env.GetWorkflowResult(&result))
+	var expected []int
+	for i := 0; i < totalSize; i++ {
+		expected = append(expected, i)
+	}
+	assert.Equal(t, expected, result)
+}
+
+func Test_BatchWorkflow_Cancel(t *testing.T) {
+	testSuite := &testsuite.WorkflowTestSuite{}
+	env := testSuite.NewTestWorkflowEnvironment()
+	env.RegisterWorkflow(batchWorkflow)
+	env.RegisterActivity(batchActivity)
+
+	totalSize := 100
+	concurrency := 20
+
+	totalExpectedTime := time.Second * time.Duration(1+totalSize/concurrency)
+	// the cancellation callback must be registered before ExecuteWorkflow,
+	// as ExecuteWorkflow blocks until the workflow completes
+	env.RegisterDelayedCallback(func() {
+		env.CancelWorkflow()
+	}, totalExpectedTime/2)
+	env.ExecuteWorkflow(batchWorkflow, batchWorkflowInput{
+		Concurrency: concurrency,
+		TotalSize:   totalSize,
+	})
+
+	assert.True(t, env.IsWorkflowCompleted())
+
+	err := env.GetWorkflowError()
+	errs := multierr.Errors(errors.Unwrap(err))
+	assert.Less(t, len(errs), totalSize, "expect at least some activities to succeed before cancellation")
+	for _, e := range errs {
+		assert.Contains(t, e.Error(), "Canceled")
+	}
+}
+
+func Test_BatchWorkflowUsingFutures(t *testing.T) {
+	testSuite := &testsuite.WorkflowTestSuite{}
+	env := testSuite.NewTestWorkflowEnvironment()
+
+	env.RegisterWorkflow(batchWorkflowUsingFutures)
+	env.RegisterActivity(batchActivity)
+
+	totalSize := 100
+	concurrency := 20
+
+	startTime := time.Now()
+	env.ExecuteWorkflow(batchWorkflowUsingFutures, batchWorkflowInput{
+		Concurrency: concurrency,
+		TotalSize:   totalSize,
+	})
+	assert.Less(t, time.Since(startTime), time.Second*time.Duration(float64(totalSize)/float64(concurrency)))
+	assert.True(t, env.IsWorkflowCompleted())
+
+	assert.NoError(t, env.GetWorkflowError())
+	var result []int
+	assert.NoError(t, env.GetWorkflowResult(&result))
+	var expected []int
+	for i := 0; i < totalSize; i++ {
+		expected = append(expected, i)
+	}
+	assert.Equal(t, expected, result)
+}
+
+func batchWorkflowAssignWithSlice(ctx internal.Context) ([]int, error) {
+	totalSize := 5
+	concurrency := 2
+	factories := make([]func(ctx internal.Context) internal.Future, totalSize)
+	for i := 0; i < totalSize; i++ {
+		i := i
+		factories[i] = func(ctx internal.Context) internal.Future {
+			aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
+				ScheduleToStartTimeout: time.Second * 10,
+				StartToCloseTimeout:    time.Second * 10,
+			})
+			return internal.ExecuteActivity(aCtx, batchActivity, i)
+		}
+	}
+
+	batchFuture, err := NewBatchFuture(ctx, concurrency, factories)
+	if err != nil {
+		return nil, err
+	}
+
+	var valuePtr []int
+	if err := batchFuture.Get(ctx, &valuePtr); err != nil {
+		return nil, err
+	}
+	return valuePtr, nil
+}
+
+func batchWorkflowAssignWithSliceOfPointers(ctx internal.Context) ([]int, error) {
+	totalSize := 5
+	concurrency := 2
+	factories := make([]func(ctx internal.Context) internal.Future, totalSize)
+	for i := 0; i < totalSize; i++ {
+		i := i
+		factories[i] = func(ctx internal.Context) internal.Future {
+			aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
+				ScheduleToStartTimeout: time.Second * 10,
+				StartToCloseTimeout:    time.Second * 10,
+			})
+			return internal.ExecuteActivity(aCtx, batchActivity, i)
+		}
+	}
+
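+	// note: with a []*int destination, each element is decoded by the underlying
+	// Future.Get, which allocates the *int values - no pre-allocation is needed here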
+	batchFuture, err := NewBatchFuture(ctx, concurrency, factories)
+	if err != nil {
+		return nil, err
+	}
+	var valuePtr []*int
+	if err := batchFuture.Get(ctx, &valuePtr); err != nil {
+		return nil, err
+	}
+
+	var result []int
+	for _, v := range valuePtr {
+		result = append(result, *v)
+	}
+	return result, nil
+}
+
+func batchWorkflowAssignWithNil(ctx internal.Context) ([]int, error) {
+	totalSize := 5
+	concurrency := 2
+	factories := make([]func(ctx internal.Context) internal.Future, totalSize)
+	for i := 0; i < totalSize; i++ {
+		i := i
+		factories[i] = func(ctx internal.Context) internal.Future {
+			aCtx := internal.WithActivityOptions(ctx, internal.ActivityOptions{
+				ScheduleToStartTimeout: time.Second * 10,
+				StartToCloseTimeout:    time.Second * 10,
+			})
+			return internal.ExecuteActivity(aCtx, batchActivity, i)
+		}
+	}
+
+	batchFuture, err := NewBatchFuture(ctx, concurrency, factories)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := batchFuture.Get(ctx, nil); err != nil {
+		return nil, err
+	}
+	return nil, nil
+}
+
+func Test_BatchFuture_Get(t *testing.T) {
+	tests := []struct {
+		name     string
+		workflow func(ctx internal.Context) ([]int, error)
+		want     interface{}
+	}{
+		{
+			name:     "success with nil slice",
+			workflow: batchWorkflowAssignWithSlice,
+			want:     []int{0, 1, 2, 3, 4},
+		},
+		{
+			name:     "success with slice of pointers",
+			workflow: batchWorkflowAssignWithSliceOfPointers,
+			want:     []int{0, 1, 2, 3, 4},
+		},
+		{
+			name:     "success with nil valuePtr",
+			workflow: batchWorkflowAssignWithNil,
+			want:     []int(nil),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			testSuite := &testsuite.WorkflowTestSuite{}
+			env := testSuite.NewTestWorkflowEnvironment()
+			env.RegisterWorkflow(tt.workflow)
+			env.RegisterActivity(batchActivity)
+			env.ExecuteWorkflow(tt.workflow)
+			assert.True(t, env.IsWorkflowCompleted())
+			assert.NoError(t, env.GetWorkflowError())
+			var result []int
+			assert.NoError(t, env.GetWorkflowResult(&result))
+			assert.Equal(t, tt.want, result)
+		})
+	}
+}
diff --git a/x/batch.go b/x/batch.go
new file mode 100644
index 000000000..69d5fc7b2
--- /dev/null
+++ b/x/batch.go
@@ -0,0 +1,40 @@
+package x
+
+import (
+	"go.uber.org/cadence/internal/batch"
+	"go.uber.org/cadence/workflow"
+)
+
+var _ workflow.Future = (BatchFuture)(nil) // to ensure the interfaces stay compatible
+
+// BatchFuture wraps a collection of futures and provides convenience methods for dealing with them in bulk.
+type BatchFuture interface {
+	// IsReady returns true when all wrapped futures return true from their IsReady.
+	IsReady() bool
+	// Get acts like workflow.Future.Get, but it reads all wrapped futures into the provided slice pointer.
+	// You MUST either:
+	//  1. provide a pointer to a slice as the value-pointer here - the slice itself can be nil,
+	//     as it will be allocated and/or resized to fit if needed, or
+	//  2. provide nil to indicate that you do not want to collect the results.
+	//
+	// This call waits for all futures to resolve, and then writes all results to the output slice
+	// in the same order as the input.
+	//
+	// Any errors encountered are merged with go.uber.org/multierr, so single errors are
+	// exposed normally, but multiple ones are bundled in the same way as errors.Join.
+	// For consistency when checking individual errors, consider using multierr.Errors(err) in all cases,
+	// or GetFutures()[i].Get(ctx, nil) to get the original error at each index.
+	Get(ctx workflow.Context, valuePtr interface{}) error
+	// GetFutures returns a slice of all the wrapped futures.
+	// This slice MUST NOT be modified, but the individual futures can be used normally.
+	GetFutures() []workflow.Future
+}
+
+// NewBatchFuture creates a bounded-concurrency helper for doing bulk work in your workflow.
+// It does not reduce the amount of history your workflow stores, so any event-count
+// or history-size limits are unaffected - you must still be cautious about the total
+// amount of work you do in any workflow.
+//
+// When NewBatchFuture is called, futures created by the factories are started concurrently until the
+// concurrency limit (batchSize) is reached. The remaining factories are queued and started as earlier
+// futures complete, maintaining the specified concurrency level.
+func NewBatchFuture(ctx workflow.Context, batchSize int, factories []func(ctx workflow.Context) workflow.Future) (BatchFuture, error) {
+	return batch.NewBatchFuture(ctx, batchSize, factories)
+}
diff --git a/x/doc.go b/x/doc.go
new file mode 100644
index 000000000..c4bf75f07
--- /dev/null
+++ b/x/doc.go
@@ -0,0 +1,2 @@
+// Package x is an experimental package for early-stage features. The API here is not stable and may change.
+package x
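diff --git a/x/example_batch_test.go b/x/example_batch_test.go
new file mode 100644
--- /dev/null
+++ b/x/example_batch_test.go
@@ -0,0 +1,35 @@
+// NOTE: this file is an illustrative usage sketch for x.NewBatchFuture; the
+// activity name and timeouts below are hypothetical placeholders.
+package x_test
+
+import (
+	"time"
+
+	"go.uber.org/cadence/workflow"
+	"go.uber.org/cadence/x"
+)
+
+// sampleBatchWorkflow fans out one activity per input item, with at most 10 running concurrently.
+func sampleBatchWorkflow(ctx workflow.Context, items []string) ([]string, error) {
+	factories := make([]func(ctx workflow.Context) workflow.Future, len(items))
+	for i := range items {
+		i := i
+		factories[i] = func(ctx workflow.Context) workflow.Future {
+			aCtx := workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
+				ScheduleToStartTimeout: time.Minute,
+				StartToCloseTimeout:    time.Minute,
+			})
+			return workflow.ExecuteActivity(aCtx, "ProcessItem", items[i]) // hypothetical activity name
+		}
+	}
+
+	bf, err := x.NewBatchFuture(ctx, 10, factories) // at most 10 futures in flight
+	if err != nil {
+		return nil, err
+	}
+
+	// Get waits for all futures and writes results in input order; errors are merged via multierr.
+	results := make([]string, len(items))
+	err = bf.Get(ctx, &results)
+	return results, err
+}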