Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions internal/batch/batch_future.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package batch

import (
"fmt"
"reflect"

"go.uber.org/multierr"

"go.uber.org/cadence/internal"
)

// BatchFuture is an implementation of public BatchFuture interface.
type BatchFuture struct {
futures []internal.Future
settables []internal.Settable
factories []func(ctx internal.Context) internal.Future
batchSize int

// state
wg internal.WaitGroup
}

func NewBatchFuture(ctx internal.Context, batchSize int, factories []func(ctx internal.Context) internal.Future) (*BatchFuture, error) {
var futures []internal.Future
var settables []internal.Settable
for range factories {
future, settable := internal.NewFuture(ctx)
futures = append(futures, future)
settables = append(settables, settable)
}

batchFuture := &BatchFuture{
futures: futures,
settables: settables,
factories: factories,
batchSize: batchSize,

wg: internal.NewWaitGroup(ctx),
}
batchFuture.start(ctx)
return batchFuture, nil
}

func (b *BatchFuture) GetFutures() []internal.Future {
return b.futures
}

func (b *BatchFuture) start(ctx internal.Context) {

buffered := internal.NewBufferedChannel(ctx, b.batchSize) // buffered channel to limit the number of concurrent futures
channel := internal.NewNamedChannel(ctx, "batch-future-channel")
b.wg.Add(1)
internal.GoNamed(ctx, "batch-future-submitter", func(ctx internal.Context) {
defer b.wg.Done()

for i := range b.factories {
buffered.Send(ctx, nil)
channel.Send(ctx, i)
}
channel.Close()
})

b.wg.Add(1)
internal.GoNamed(ctx, "batch-future-processor", func(ctx internal.Context) {
defer b.wg.Done()

wgForFutures := internal.NewWaitGroup(ctx)

var idx int
for channel.Receive(ctx, &idx) {
idx := idx

wgForFutures.Add(1)
internal.GoNamed(ctx, fmt.Sprintf("batch-future-processor-one-future-%d", idx), func(ctx internal.Context) {
defer wgForFutures.Done()

// fork a future and chain it to the processed future for user to get the result
f := b.factories[idx](ctx)
b.settables[idx].Chain(f)

// error handling is not needed here because the result is chained to the settable
f.Get(ctx, nil)
buffered.Receive(ctx, nil)
})
}
wgForFutures.Wait(ctx)
})
}

func (b *BatchFuture) IsReady() bool {
for _, future := range b.futures {
if !future.IsReady() {
return false
}

Check warning on line 94 in internal/batch/batch_future.go

View check run for this annotation

Codecov / codecov/patch

internal/batch/batch_future.go#L90-L94

Added lines #L90 - L94 were not covered by tests
}
return true

Check warning on line 96 in internal/batch/batch_future.go

View check run for this annotation

Codecov / codecov/patch

internal/batch/batch_future.go#L96

Added line #L96 was not covered by tests
}

// Get assigns the result of the futures to the valuePtr.
// NOTE: valuePtr must be a pointer to a slice, or nil.
// If valuePtr is a pointer to a slice, the slice will be resized to the length of the futures. Each element of the slice will be assigned with the underlying Future.Get() and thus behaves the same way.
// If valuePtr is nil, no assignment will be made.
// If error occurs, values will be set on successful futures and the errors of failed futures will be returned.
func (b *BatchFuture) Get(ctx internal.Context, valuePtr interface{}) error {
// No assignment if valuePtr is nil
if valuePtr == nil {
b.wg.Wait(ctx)
var errs error
for i := range b.futures {
errs = multierr.Append(errs, b.futures[i].Get(ctx, nil))
}
return errs
}

v := reflect.ValueOf(valuePtr)
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice {
return fmt.Errorf("valuePtr must be a pointer to a slice, got %v", v.Kind())
}

Check warning on line 118 in internal/batch/batch_future.go

View check run for this annotation

Codecov / codecov/patch

internal/batch/batch_future.go#L117-L118

Added lines #L117 - L118 were not covered by tests

// resize the slice to the length of the futures
slice := v.Elem()
if slice.Cap() < len(b.futures) {
slice.Grow(len(b.futures) - slice.Cap())
}
slice.SetLen(len(b.futures))

// wait for all futures to be ready
b.wg.Wait(ctx)

// loop through all elements of valuePtr
var errs error
for i := range b.futures {
e := b.futures[i].Get(ctx, slice.Index(i).Addr().Interface())
errs = multierr.Append(errs, e)
}

return errs
}
Loading