Skip to content

Commit a793796

Browse files
committed
Support workflow cancellation after commands are done
1 parent 6463ca2 commit a793796

File tree

6 files changed

+107
-29
lines changed

6 files changed

+107
-29
lines changed

internal/command/cancelablecommand.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ func (c *cancelableCommand) Done() {
3838
switch c.state {
3939
case CommandState_Committed, CommandState_Canceled:
4040
c.state = CommandState_Done
41+
if c.whenDone != nil {
42+
c.whenDone()
43+
}
44+
4145
default:
4246
c.invalidStateTransition(CommandState_Done)
4347
}

internal/command/command.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ type Command interface {
2424
// Done marks the command as done. This transitions the state to done and indicates that the result
2525
// of this command has been applied.
2626
Done()
27+
28+
WhenDone(fn func())
2729
}
2830

2931
type CommandResult struct {
@@ -40,6 +42,8 @@ type command struct {
4042
state CommandState
4143

4244
id int64
45+
46+
whenDone func()
4347
}
4448

4549
func (c *command) ID() int64 {
@@ -67,11 +71,19 @@ func (c *command) Done() {
6771
switch c.state {
6872
case CommandState_Committed:
6973
c.state = CommandState_Done
74+
75+
if c.whenDone != nil {
76+
c.whenDone()
77+
}
7078
default:
7179
c.invalidStateTransition(CommandState_Done)
7280
}
7381
}
7482

83+
func (c *command) WhenDone(fn func()) {
84+
c.whenDone = fn
85+
}
86+
7587
func (c *command) invalidStateTransition(state CommandState) {
7688
panic(fmt.Errorf("invalid state transition for command %s: %s -> %s", c.name, c.State().String(), state.String()))
7789
}

internal/command/sideeffect.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ func (c *SideEffectCommand) Done() {
6565
switch c.state {
6666
case CommandState_Pending, CommandState_Committed:
6767
c.state = CommandState_Done
68+
if c.whenDone != nil {
69+
c.whenDone()
70+
}
6871

6972
default:
7073
c.invalidStateTransition(CommandState_Done)

workflow/subworkflow.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,22 +85,31 @@ func createSubWorkflowInstance[TResult any](ctx sync.Context, options SubWorkflo
8585
span.Marshal(metadata)
8686

8787
cmd := command.NewScheduleSubWorkflowCommand(scheduleEventID, wfState.Instance(), options.InstanceID, name, inputs, metadata)
88+
8889
wfState.AddCommand(cmd)
8990
wfState.TrackFuture(scheduleEventID, workflowstate.AsDecodingSettable(cv, f))
9091

9192
// Check if the channel is cancelable
9293
if c, cancelable := ctx.Done().(sync.CancelChannel); cancelable {
93-
c.AddReceiveCallback(func(v struct{}, ok bool) {
94-
cmd.Cancel()
95-
if cmd.State() == command.CommandState_Canceled {
96-
// Remove the sub-workflow future from the workflow state and mark it as canceled if it hasn't already fired
97-
if fi, ok := f.(sync.FutureInternal[TResult]); ok {
98-
if !fi.Ready() {
99-
wfState.RemoveFuture(scheduleEventID)
100-
f.Set(*new(TResult), sync.Canceled)
94+
cancelReceiver := &sync.Receiver[struct{}]{
95+
Receive: func(v struct{}, ok bool) {
96+
cmd.Cancel()
97+
if cmd.State() == command.CommandState_Canceled {
98+
// Remove the sub-workflow future from the workflow state and mark it as canceled if it hasn't already fired
99+
if fi, ok := f.(sync.FutureInternal[TResult]); ok {
100+
if !fi.Ready() {
101+
wfState.RemoveFuture(scheduleEventID)
102+
f.Set(*new(TResult), sync.Canceled)
103+
}
101104
}
102105
}
103-
}
106+
},
107+
}
108+
109+
c.AddReceiveCallback(cancelReceiver)
110+
111+
cmd.WhenDone(func() {
112+
c.RemoveReceiveCallback(cancelReceiver)
104113
})
105114
}
106115

workflow/timer.go

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,27 @@ func ScheduleTimer(ctx Context, delay time.Duration) Future[struct{}] {
2525

2626
scheduleEventID := wfState.GetNextScheduleEventID()
2727
at := Now(ctx).Add(delay)
28+
2829
timerCmd := command.NewScheduleTimerCommand(scheduleEventID, at)
2930
wfState.AddCommand(timerCmd)
30-
3131
wfState.TrackFuture(scheduleEventID, workflowstate.AsDecodingSettable(converter.GetConverter(ctx), f))
3232

33+
cancelReceiver := &sync.Receiver[struct{}]{
34+
Receive: func(v struct{}, ok bool) {
35+
timerCmd.Cancel()
36+
37+
// Remove the timer future from the workflow state and mark it as canceled if it hasn't already fired. This is different
38+
// from subworkflow behavior, where we want to wait for the subworkflow to complete before proceeding. Here we can
39+
// continue right away.
40+
if fi, ok := f.(sync.FutureInternal[struct{}]); ok {
41+
if !fi.Ready() {
42+
wfState.RemoveFuture(scheduleEventID)
43+
f.Set(v, sync.Canceled)
44+
}
45+
}
46+
},
47+
}
48+
3349
ctx, span := workflowtracer.Tracer(ctx).Start(ctx, "ScheduleTimer",
3450
trace.WithAttributes(
3551
attribute.Int64("duration_ms", int64(delay/time.Millisecond)),
@@ -42,26 +58,10 @@ func ScheduleTimer(ctx Context, delay time.Duration) Future[struct{}] {
4258
if c, cancelable := ctx.Done().(sync.CancelChannel); cancelable {
4359
// Register a callback for when it's canceled. The only operation on the `Done` channel
4460
// is that it's closed when the context is canceled.
45-
canceled := false
46-
47-
c.AddReceiveCallback(func(v struct{}, ok bool) {
48-
// Ignore any future cancelation events for this timer
49-
if canceled {
50-
return
51-
}
52-
canceled = true
61+
c.AddReceiveCallback(cancelReceiver)
5362

54-
timerCmd.Cancel()
55-
56-
// Remove the timer future from the workflow state and mark it as canceled if it hasn't already fired. This is different
57-
// from subworkflow behavior, where we want to wait for the subworkflow to complete before proceeding. Here we can
58-
// continue right away.
59-
if fi, ok := f.(sync.FutureInternal[struct{}]); ok {
60-
if !fi.Ready() {
61-
wfState.RemoveFuture(scheduleEventID)
62-
f.Set(v, sync.Canceled)
63-
}
64-
}
63+
timerCmd.WhenDone(func() {
64+
c.RemoveReceiveCallback(cancelReceiver)
6565
})
6666
}
6767

workflow/timer_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package workflow
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/benbjohnson/clock"
8+
"github.com/cschleiden/go-workflows/internal/converter"
9+
"github.com/cschleiden/go-workflows/internal/core"
10+
"github.com/cschleiden/go-workflows/internal/logger"
11+
"github.com/cschleiden/go-workflows/internal/sync"
12+
"github.com/cschleiden/go-workflows/internal/workflowstate"
13+
"github.com/cschleiden/go-workflows/internal/workflowtracer"
14+
"github.com/stretchr/testify/require"
15+
"go.opentelemetry.io/otel/trace"
16+
)
17+
18+
func Test_Timer_Cancellation(t *testing.T) {
19+
state := workflowstate.NewWorkflowState(core.NewWorkflowInstance("a", ""), logger.NewDefaultLogger(), clock.New())
20+
21+
ctx, cancel := sync.WithCancel(sync.Background())
22+
ctx = converter.WithConverter(ctx, converter.DefaultConverter)
23+
ctx = workflowstate.WithWorkflowState(ctx, state)
24+
ctx = workflowtracer.WithWorkflowTracer(ctx, workflowtracer.New(trace.NewNoopTracerProvider().Tracer("test")))
25+
26+
c := sync.NewCoroutine(ctx, func(ctx sync.Context) error {
27+
f := ScheduleTimer(ctx, time.Second*1)
28+
f.Get(ctx)
29+
30+
// Block workflow
31+
sync.NewFuture[int]().Get(ctx)
32+
33+
return nil
34+
})
35+
c.Execute()
36+
require.False(t, c.Finished())
37+
38+
// Fire timer
39+
cmd := state.CommandByScheduleEventID(1)
40+
cmd.Commit()
41+
cmd.Done()
42+
fs, ok := state.FutureByScheduleEventID(1)
43+
require.True(t, ok)
44+
fs(nil, nil)
45+
46+
c.Execute()
47+
require.False(t, c.Finished())
48+
49+
cancel()
50+
}

0 commit comments

Comments
 (0)