Skip to content

Commit 32d7292

Browse files
authored
fix selector (#416)
* fix selector * add unit test for selector * gofmt
1 parent 70d1c0c commit 32d7292

File tree

3 files changed

+126
-0
lines changed

3 files changed

+126
-0
lines changed

internal/internal_coroutines_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,51 @@ func TestBlockingSelectAsyncSend(t *testing.T) {
353353
require.EqualValues(t, expected, history)
354354
}
355355

356+
func TestBlockingSelectAsyncSend2(t *testing.T) {
357+
var history []string
358+
d := newDispatcher(background, func(ctx Context) {
359+
c1 := NewBufferedChannel(ctx, 100)
360+
c2 := NewBufferedChannel(ctx, 100)
361+
s := NewSelector(ctx)
362+
s.
363+
AddReceive(c1, func(c Channel, more bool) {
364+
assert.True(t, more)
365+
var v string
366+
c.Receive(ctx, &v)
367+
history = append(history, fmt.Sprintf("c1-%v", v))
368+
}).
369+
AddReceive(c2, func(c Channel, more bool) {
370+
assert.True(t, more)
371+
var v string
372+
c.Receive(ctx, &v)
373+
history = append(history, fmt.Sprintf("c2-%v", v))
374+
})
375+
376+
history = append(history, "send-s2")
377+
c2.SendAsync("s2")
378+
history = append(history, "select-0")
379+
s.Select(ctx)
380+
history = append(history, "send-s1")
381+
c1.SendAsync("s1")
382+
history = append(history, "select-1")
383+
s.Select(ctx)
384+
history = append(history, "done")
385+
})
386+
d.ExecuteUntilAllBlocked()
387+
require.True(t, d.IsDone(), strings.Join(history, "\n"))
388+
389+
expected := []string{
390+
"send-s2",
391+
"select-0",
392+
"c2-s2",
393+
"send-s1",
394+
"select-1",
395+
"c1-s1",
396+
"done",
397+
}
398+
require.EqualValues(t, expected, history)
399+
}
400+
356401
func TestSendSelect(t *testing.T) {
357402
var history []string
358403
d := newDispatcher(background, func(ctx Context) {

internal/internal_workflow.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,11 @@ func (s *selectorImpl) Select(ctx Context) {
880880
}
881881
v, ok, more := pair.channel.receiveAsyncImpl(callback)
882882
if ok || !more {
883+
// Select() returns in this case/branch. The callback won't be called for this case. However, callback
884+
// will be called for previous cases/branches. We should set readyBranch so that when other case/branch
885+
// become ready they won't consume the value for this Select() call.
886+
readyBranch = func() {
887+
}
883888
c.recValue = &v
884889
f(c, more)
885890
return
@@ -900,6 +905,11 @@ func (s *selectorImpl) Select(ctx Context) {
900905
}
901906
ok := pair.channel.sendAsyncImpl(*pair.sendValue, p)
902907
if ok {
908+
// Select() returns in this case/branch. The callback won't be called for this case. However, callback
909+
// will be called for previous cases/branches. We should set readyBranch so that when other case/branch
910+
// become ready they won't consume the value for this Select() call.
911+
readyBranch = func() {
912+
}
903913
f()
904914
return
905915
}
@@ -918,6 +928,11 @@ func (s *selectorImpl) Select(ctx Context) {
918928
}
919929
_, ok, _ := p.future.GetAsync(callback)
920930
if ok {
931+
// Select() returns in this case/branch. The callback won't be called for this case. However, callback
932+
// will be called for previous cases/branches. We should set readyBranch so that when other case/branch
933+
// become ready they won't consume the value for this Select() call.
934+
readyBranch = func() {
935+
}
921936
p.futureFunc = nil
922937
f(p.future)
923938
return

internal/internal_workflow_testsuite_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,3 +1536,69 @@ func (s *WorkflowTestSuiteUnitTest) Test_WorkflowIDReusePolicy() {
15361536
s.NoError(env.GetWorkflowResult(&actualResult))
15371537
s.Equal("hello_world", actualResult)
15381538
}
1539+
1540+
func (s *WorkflowTestSuiteUnitTest) Test_Channel() {
1541+
workflowFn := func(ctx Context) error {
1542+
1543+
signalCh := GetSignalChannel(ctx, "test-signal")
1544+
doneCh := NewBufferedChannel(ctx, 100)
1545+
selector := NewSelector(ctx)
1546+
1547+
selector.AddReceive(signalCh, func(c Channel, more bool) {
1548+
}).AddReceive(doneCh, func(c Channel, more bool) {
1549+
var doneSignal string
1550+
c.Receive(ctx, &doneSignal)
1551+
})
1552+
1553+
fanoutChs := []Channel{NewBufferedChannel(ctx, 100), NewBufferedChannel(ctx, 100)}
1554+
1555+
processedCount := 0
1556+
runningCount := 0
1557+
1558+
mainLoop:
1559+
for {
1560+
selector.Select(ctx)
1561+
var signal string
1562+
if !signalCh.ReceiveAsync(&signal) {
1563+
if runningCount > 0 {
1564+
continue mainLoop
1565+
}
1566+
1567+
if processedCount < 4 {
1568+
continue mainLoop
1569+
}
1570+
1571+
// continue as new
1572+
return NewContinueAsNewError(ctx, "this-workflow-fn")
1573+
}
1574+
1575+
for i := range fanoutChs {
1576+
ch := fanoutChs[i]
1577+
ch.SendAsync(signal)
1578+
processedCount++
1579+
runningCount++
1580+
Go(ctx, func(ctx Context) {
1581+
doneCh.SendAsync("done")
1582+
runningCount--
1583+
})
1584+
}
1585+
}
1586+
1587+
return nil
1588+
}
1589+
1590+
RegisterWorkflow(workflowFn)
1591+
env := s.NewTestWorkflowEnvironment()
1592+
1593+
env.RegisterDelayedCallback(func() {
1594+
env.SignalWorkflow("test-signal", "s1")
1595+
env.SignalWorkflow("test-signal", "s2")
1596+
}, time.Minute)
1597+
1598+
env.ExecuteWorkflow(workflowFn)
1599+
1600+
s.True(env.IsWorkflowCompleted())
1601+
s.Error(env.GetWorkflowError())
1602+
_, ok := env.GetWorkflowError().(*ContinueAsNewError)
1603+
s.True(ok)
1604+
}

0 commit comments

Comments
 (0)