Skip to content

Commit 609f22c

Browse files
authored
Merge pull request #188 from cschleiden/fix-cmd-cancellation
Support invalid state transition for finished commands during workflow cancellation
2 parents 7be202d + a793796 commit 609f22c

File tree

8 files changed

+191
-38
lines changed

8 files changed

+191
-38
lines changed

internal/command/cancelablecommand.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ func (c *cancelableCommand) Done() {
3838
switch c.state {
3939
case CommandState_Committed, CommandState_Canceled:
4040
c.state = CommandState_Done
41+
if c.whenDone != nil {
42+
c.whenDone()
43+
}
44+
4145
default:
4246
c.invalidStateTransition(CommandState_Done)
4347
}

internal/command/command.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type Command interface {
2424
// Done marks the command as done. This transitions the state to done and indicates that the result
2525
// of this command has been applied.
2626
Done()
27+
28+
WhenDone(fn func())
2729
}
2830

2931
type CommandResult struct {
@@ -40,6 +42,8 @@ type command struct {
4042
state CommandState
4143

4244
id int64
45+
46+
whenDone func()
4347
}
4448

4549
func (c *command) ID() int64 {
@@ -67,11 +71,19 @@ func (c *command) Done() {
6771
switch c.state {
6872
case CommandState_Committed:
6973
c.state = CommandState_Done
74+
75+
if c.whenDone != nil {
76+
c.whenDone()
77+
}
7078
default:
7179
c.invalidStateTransition(CommandState_Done)
7280
}
7381
}
7482

83+
func (c *command) WhenDone(fn func()) {
84+
c.whenDone = fn
85+
}
86+
7587
func (c *command) invalidStateTransition(state CommandState) {
7688
panic(fmt.Errorf("invalid state transition for command %s: %s -> %s", c.name, c.State().String(), state.String()))
7789
}

internal/command/sideeffect.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ func (c *SideEffectCommand) Done() {
6565
switch c.state {
6666
case CommandState_Pending, CommandState_Committed:
6767
c.state = CommandState_Done
68+
if c.whenDone != nil {
69+
c.whenDone()
70+
}
6871

6972
default:
7073
c.invalidStateTransition(CommandState_Done)

internal/sync/channel.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@ type Channel[T any] interface {
1212
Close()
1313
}
1414

15+
type Receiver[T any] struct {
16+
Receive func(v T, ok bool)
17+
}
18+
1519
type ChannelInternal[T any] interface {
1620
Closed() bool
1721

1822
ReceiveNonBlocking() (v T, ok bool)
1923

2024
// AddReceiveCallback adds a callback that is called once when a value is sent to the channel. This is similar
2125
// to the blocking `Receive` method, but is not blocking a coroutine.
22-
AddReceiveCallback(cb func(v T, ok bool))
26+
AddReceiveCallback(rcb *Receiver[T])
27+
28+
RemoveReceiveCallback(rcb *Receiver[T])
2329
}
2430

2531
// Ensure channel implementation support internal interface
@@ -40,7 +46,7 @@ func NewBufferedChannel[T any](size int) Channel[T] {
4046

4147
type channel[T any] struct {
4248
c []T
43-
receivers []func(value T, ok bool)
49+
receivers []*Receiver[T]
4450
senders []func() T
4551
closed bool
4652
size int
@@ -62,7 +68,7 @@ func (c *channel[T]) Close() {
6268

6369
// Send zero value to pending receiver
6470
var v T
65-
r(v, false)
71+
r.Receive(v, false)
6672
}
6773
}
6874

@@ -119,10 +125,12 @@ func (c *channel[T]) Receive(ctx Context) (v T, ok bool) {
119125

120126
// Register handler to receive value once
121127
if !addedListener {
122-
cb := func(rv T, rok bool) {
123-
receivedValue = true
124-
v = rv
125-
ok = rok
128+
cb := &Receiver[T]{
129+
Receive: func(rv T, rok bool) {
130+
receivedValue = true
131+
v = rv
132+
ok = rok
133+
},
126134
}
127135

128136
c.receivers = append(c.receivers, cb)
@@ -176,7 +184,7 @@ func (c *channel[T]) trySend(v T) bool {
176184
c.receivers[0] = nil
177185
c.receivers = c.receivers[1:]
178186

179-
r(v, true)
187+
r.Receive(v, true)
180188

181189
return true
182190
}
@@ -223,10 +231,20 @@ func (c *channel[T]) hasCapacity() bool {
223231
return len(c.c) < c.size
224232
}
225233

226-
func (c *channel[T]) AddReceiveCallback(cb func(v T, ok bool)) {
234+
func (c *channel[T]) AddReceiveCallback(cb *Receiver[T]) {
227235
c.receivers = append(c.receivers, cb)
228236
}
229237

238+
func (c *channel[T]) RemoveReceiveCallback(cb *Receiver[T]) {
239+
for i, r := range c.receivers {
240+
if r == cb {
241+
c.receivers[i] = nil
242+
c.receivers = append(c.receivers[:i], c.receivers[i+1:]...)
243+
return
244+
}
245+
}
246+
}
247+
230248
func (c *channel[T]) Closed() bool {
231249
return c.closed
232250
}

internal/sync/channel_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,3 +413,60 @@ func Test_Channel_Buffered(t *testing.T) {
413413
})
414414
}
415415
}
416+
417+
func Test_CancellationHandler_Add(t *testing.T) {
418+
ctx, cancel := WithCancel(Background())
419+
420+
f := 1
421+
422+
c := NewChannel[int]()
423+
ic := c.(*channel[int])
424+
425+
cr := NewCoroutine(ctx, func(ctx Context) error {
426+
r := &Receiver[int]{
427+
Receive: func(_ int, ok bool) {
428+
f++
429+
},
430+
}
431+
432+
ic.AddReceiveCallback(r)
433+
434+
c.Send(ctx, 42)
435+
return nil
436+
})
437+
438+
cr.Execute()
439+
440+
cancel()
441+
442+
require.Equal(t, 2, f)
443+
}
444+
445+
func Test_CancellationHandler_Remove(t *testing.T) {
446+
ctx, cancel := WithCancel(Background())
447+
448+
f := 1
449+
450+
c := NewChannel[int]()
451+
ic := c.(*channel[int])
452+
453+
cr := NewCoroutine(ctx, func(ctx Context) error {
454+
r := &Receiver[int]{
455+
Receive: func(_ int, ok bool) {
456+
f++
457+
},
458+
}
459+
460+
ic.AddReceiveCallback(r)
461+
ic.RemoveReceiveCallback(r)
462+
463+
c.Send(ctx, 42)
464+
return nil
465+
})
466+
467+
cr.Execute()
468+
469+
cancel()
470+
471+
require.Equal(t, 1, f)
472+
}

workflow/subworkflow.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,22 +85,31 @@ func createSubWorkflowInstance[TResult any](ctx sync.Context, options SubWorkflo
8585
span.Marshal(metadata)
8686

8787
cmd := command.NewScheduleSubWorkflowCommand(scheduleEventID, wfState.Instance(), options.InstanceID, name, inputs, metadata)
88+
8889
wfState.AddCommand(cmd)
8990
wfState.TrackFuture(scheduleEventID, workflowstate.AsDecodingSettable(cv, f))
9091

9192
// Check if the channel is cancelable
9293
if c, cancelable := ctx.Done().(sync.CancelChannel); cancelable {
93-
c.AddReceiveCallback(func(v struct{}, ok bool) {
94-
cmd.Cancel()
95-
if cmd.State() == command.CommandState_Canceled {
96-
// Remove the sub-workflow future from the workflow state and mark it as canceled if it hasn't already fired
97-
if fi, ok := f.(sync.FutureInternal[TResult]); ok {
98-
if !fi.Ready() {
99-
wfState.RemoveFuture(scheduleEventID)
100-
f.Set(*new(TResult), sync.Canceled)
94+
cancelReceiver := &sync.Receiver[struct{}]{
95+
Receive: func(v struct{}, ok bool) {
96+
cmd.Cancel()
97+
if cmd.State() == command.CommandState_Canceled {
98+
// Remove the sub-workflow future from the workflow state and mark it as canceled if it hasn't already fired
99+
if fi, ok := f.(sync.FutureInternal[TResult]); ok {
100+
if !fi.Ready() {
101+
wfState.RemoveFuture(scheduleEventID)
102+
f.Set(*new(TResult), sync.Canceled)
103+
}
101104
}
102105
}
103-
}
106+
},
107+
}
108+
109+
c.AddReceiveCallback(cancelReceiver)
110+
111+
cmd.WhenDone(func() {
112+
c.RemoveReceiveCallback(cancelReceiver)
104113
})
105114
}
106115

workflow/timer.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,27 @@ func ScheduleTimer(ctx Context, delay time.Duration) Future[struct{}] {
2525

2626
scheduleEventID := wfState.GetNextScheduleEventID()
2727
at := Now(ctx).Add(delay)
28+
2829
timerCmd := command.NewScheduleTimerCommand(scheduleEventID, at)
2930
wfState.AddCommand(timerCmd)
30-
3131
wfState.TrackFuture(scheduleEventID, workflowstate.AsDecodingSettable(converter.GetConverter(ctx), f))
3232

33+
cancelReceiver := &sync.Receiver[struct{}]{
34+
Receive: func(v struct{}, ok bool) {
35+
timerCmd.Cancel()
36+
37+
// Remove the timer future from the workflow state and mark it as canceled if it hasn't already fired. This is different
38+
// from subworkflow behavior, where we want to wait for the subworkflow to complete before proceeding. Here we can
39+
// continue right away.
40+
if fi, ok := f.(sync.FutureInternal[struct{}]); ok {
41+
if !fi.Ready() {
42+
wfState.RemoveFuture(scheduleEventID)
43+
f.Set(v, sync.Canceled)
44+
}
45+
}
46+
},
47+
}
48+
3349
ctx, span := workflowtracer.Tracer(ctx).Start(ctx, "ScheduleTimer",
3450
trace.WithAttributes(
3551
attribute.Int64("duration_ms", int64(delay/time.Millisecond)),
@@ -42,26 +58,10 @@ func ScheduleTimer(ctx Context, delay time.Duration) Future[struct{}] {
4258
if c, cancelable := ctx.Done().(sync.CancelChannel); cancelable {
4359
// Register a callback for when it's canceled. The only operation on the `Done` channel
4460
// is that it's closed when the context is canceled.
45-
canceled := false
46-
47-
c.AddReceiveCallback(func(v struct{}, ok bool) {
48-
// Ignore any future cancelation events for this timer
49-
if canceled {
50-
return
51-
}
52-
canceled = true
61+
c.AddReceiveCallback(cancelReceiver)
5362

54-
timerCmd.Cancel()
55-
56-
// Remove the timer future from the workflow state and mark it as canceled if it hasn't already fired. This is different
57-
// from subworkflow behavior, where we want to wait for the subworkflow to complete before proceeding. Here we can
58-
// continue right away.
59-
if fi, ok := f.(sync.FutureInternal[struct{}]); ok {
60-
if !fi.Ready() {
61-
wfState.RemoveFuture(scheduleEventID)
62-
f.Set(v, sync.Canceled)
63-
}
64-
}
63+
timerCmd.WhenDone(func() {
64+
c.RemoveReceiveCallback(cancelReceiver)
6565
})
6666
}
6767

workflow/timer_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package workflow
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/benbjohnson/clock"
8+
"github.com/cschleiden/go-workflows/internal/converter"
9+
"github.com/cschleiden/go-workflows/internal/core"
10+
"github.com/cschleiden/go-workflows/internal/logger"
11+
"github.com/cschleiden/go-workflows/internal/sync"
12+
"github.com/cschleiden/go-workflows/internal/workflowstate"
13+
"github.com/cschleiden/go-workflows/internal/workflowtracer"
14+
"github.com/stretchr/testify/require"
15+
"go.opentelemetry.io/otel/trace"
16+
)
17+
18+
func Test_Timer_Cancellation(t *testing.T) {
19+
state := workflowstate.NewWorkflowState(core.NewWorkflowInstance("a", ""), logger.NewDefaultLogger(), clock.New())
20+
21+
ctx, cancel := sync.WithCancel(sync.Background())
22+
ctx = converter.WithConverter(ctx, converter.DefaultConverter)
23+
ctx = workflowstate.WithWorkflowState(ctx, state)
24+
ctx = workflowtracer.WithWorkflowTracer(ctx, workflowtracer.New(trace.NewNoopTracerProvider().Tracer("test")))
25+
26+
c := sync.NewCoroutine(ctx, func(ctx sync.Context) error {
27+
f := ScheduleTimer(ctx, time.Second*1)
28+
f.Get(ctx)
29+
30+
// Block workflow
31+
sync.NewFuture[int]().Get(ctx)
32+
33+
return nil
34+
})
35+
c.Execute()
36+
require.False(t, c.Finished())
37+
38+
// Fire timer
39+
cmd := state.CommandByScheduleEventID(1)
40+
cmd.Commit()
41+
cmd.Done()
42+
fs, ok := state.FutureByScheduleEventID(1)
43+
require.True(t, ok)
44+
fs(nil, nil)
45+
46+
c.Execute()
47+
require.False(t, c.Finished())
48+
49+
cancel()
50+
}

0 commit comments

Comments
 (0)