Skip to content

Commit 9452d21

Browse files
authored
Simplify sub-workflow cancellation propagation
1 parent 3a9fa1b commit 9452d21

File tree

5 files changed

+73
-98
lines changed

5 files changed

+73
-98
lines changed

backend/mysql/mysql.go

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -102,35 +102,16 @@ func (b *mysqlBackend) CancelWorkflowInstance(ctx context.Context, instance *wor
102102
// Cancel workflow instance
103103
// TODO: Combine this with the event insertion
104104
res := tx.QueryRowContext(ctx, "SELECT 1 FROM `instances` WHERE instance_id = ? LIMIT 1", instanceID)
105-
if err := res.Scan(nil); err == sql.ErrNoRows {
106-
return backend.ErrInstanceNotFound
107-
}
108-
109-
// Recursively, find any sub-workflow instance to cancel
110-
toCancel := []string{instance.InstanceID}
111-
112-
for len(toCancel) > 0 {
113-
toCancelID := toCancel[0]
114-
toCancel = toCancel[1:]
115-
116-
if err := insertNewEvents(ctx, tx, toCancelID, []history.Event{*event}); err != nil {
117-
return fmt.Errorf("inserting cancellation event: %w", err)
118-
}
119-
120-
rows, err := tx.QueryContext(ctx, "SELECT instance_id FROM `instances` WHERE parent_instance_id = ? AND completed_at IS NULL", toCancelID)
121-
defer rows.Close()
122-
if err != nil {
123-
return fmt.Errorf("finding sub-workflow instances: %w", err)
105+
if err := res.Scan(new(int)); err != nil {
106+
if err == sql.ErrNoRows {
107+
return backend.ErrInstanceNotFound
124108
}
125109

126-
for rows.Next() {
127-
var subWorkflowInstanceID string
128-
if err := rows.Scan(&subWorkflowInstanceID); err != nil {
129-
return fmt.Errorf("geting workflow instance for canceling: %w", err)
130-
}
110+
return err
111+
}
131112

132-
toCancel = append(toCancel, subWorkflowInstanceID)
133-
}
113+
if err := insertNewEvents(ctx, tx, instanceID, []history.Event{*event}); err != nil {
114+
return fmt.Errorf("inserting cancellation event: %w", err)
134115
}
135116

136117
return tx.Commit()

backend/redis/instance.go

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -79,30 +79,15 @@ func (rb *redisBackend) GetWorkflowInstanceState(ctx context.Context, instance *
7979
}
8080

8181
func (rb *redisBackend) CancelWorkflowInstance(ctx context.Context, instance *core.WorkflowInstance, event *history.Event) error {
82+
// Read the instance to check if it exists
8283
_, err := readInstance(ctx, rb.rdb, instance.InstanceID)
8384
if err != nil {
8485
return err
8586
}
8687

87-
// Recursively, find any sub-workflow instances to cancel
88-
toCancel := make([]*core.WorkflowInstance, 0)
89-
toCancel = append(toCancel, instance)
90-
for len(toCancel) > 0 {
91-
instance := toCancel[0]
92-
toCancel = toCancel[1:]
93-
94-
// Cancel instance
95-
if err := rb.addWorkflowInstanceEvent(ctx, instance, event); err != nil {
96-
return fmt.Errorf("adding cancellation event to workflow instance: %w", err)
97-
}
98-
99-
// Find sub-workflows
100-
subInstances, err := subWorkflowInstances(ctx, rb.rdb, instance)
101-
if err != nil {
102-
return fmt.Errorf("finding sub-workflow instances for cancellation: %w", err)
103-
}
104-
105-
toCancel = append(toCancel, subInstances...)
88+
// Cancel instance
89+
if err := rb.addWorkflowInstanceEvent(ctx, instance, event); err != nil {
90+
return fmt.Errorf("adding cancellation event to workflow instance: %w", err)
10691
}
10792

10893
return nil
@@ -132,7 +117,7 @@ func createInstance(ctx context.Context, rdb redis.UniversalClient, instance *co
132117

133118
ok, err := rdb.SetNX(ctx, key, string(b), 0).Result()
134119
if err != nil {
135-
return fmt.Errorf("storeing instance: %w", err)
120+
return fmt.Errorf("storing instance: %w", err)
136121
}
137122

138123
if !ignoreDuplicate && !ok {

backend/sqlite/sqlite.go

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -135,35 +135,16 @@ func (sb *sqliteBackend) CancelWorkflowInstance(ctx context.Context, instance *w
135135

136136
// TODO: Combine with event insertion
137137
res := tx.QueryRowContext(ctx, "SELECT 1 FROM `instances` WHERE id = ? LIMIT 1", instanceID)
138-
if err := res.Scan(nil); err == sql.ErrNoRows {
139-
return backend.ErrInstanceNotFound
140-
}
141-
142-
// Recursively, find any sub-workflow instance to cancel
143-
toCancel := []string{instance.InstanceID}
144-
145-
for len(toCancel) > 0 {
146-
toCancelID := toCancel[0]
147-
toCancel = toCancel[1:]
148-
149-
if err := insertNewEvents(ctx, tx, toCancelID, []history.Event{*event}); err != nil {
150-
return fmt.Errorf("inserting cancellation event: %w", err)
151-
}
152-
153-
rows, err := tx.QueryContext(ctx, "SELECT id FROM `instances` WHERE parent_instance_id = ? AND completed_at IS NULL", toCancelID)
154-
defer rows.Close()
155-
if err != nil {
156-
return fmt.Errorf("finding sub-workflow instances: %w", err)
138+
if err := res.Scan(new(int)); err != nil {
139+
if err == sql.ErrNoRows {
140+
return backend.ErrInstanceNotFound
157141
}
158142

159-
for rows.Next() {
160-
var subWorkflowInstanceID string
161-
if err := rows.Scan(&subWorkflowInstanceID); err != nil {
162-
return fmt.Errorf("geting workflow instance for canceling: %w", err)
163-
}
143+
return err
144+
}
164145

165-
toCancel = append(toCancel, subWorkflowInstanceID)
166-
}
146+
if err := insertNewEvents(ctx, tx, instanceID, []history.Event{*event}); err != nil {
147+
return fmt.Errorf("inserting cancellation event: %w", err)
167148
}
168149

169150
return tx.Commit()

backend/test/backendtest.go

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -210,32 +210,6 @@ func BackendTest(t *testing.T, setup func() backend.Backend, teardown func(b bac
210210
require.Equal(t, history.EventType_WorkflowExecutionCanceled, task.NewEvents[len(task.NewEvents)-1].Type)
211211
},
212212
},
213-
{
214-
name: "CancelWorkflow_CancelsSpawnedSubWorkflows",
215-
f: func(t *testing.T, ctx context.Context, b backend.Backend) {
216-
c := client.New(b)
217-
instance := core.NewWorkflowInstance(uuid.NewString(), uuid.NewString())
218-
startWorkflow(t, ctx, b, c, instance)
219-
220-
subInstance1 := core.NewSubWorkflowInstance(uuid.NewString(), uuid.NewString(), instance.InstanceID, 1)
221-
startWorkflow(t, ctx, b, c, subInstance1)
222-
223-
subInstance2 := core.NewSubWorkflowInstance(uuid.NewString(), uuid.NewString(), instance.InstanceID, 2)
224-
startWorkflow(t, ctx, b, c, subInstance2)
225-
226-
err := c.CancelWorkflowInstance(ctx, instance)
227-
require.NoError(t, err)
228-
229-
for i := 0; i < 3; i++ {
230-
task, err := b.GetWorkflowTask(ctx)
231-
require.NoError(t, err)
232-
require.Equal(t, history.EventType_WorkflowExecutionCanceled, task.NewEvents[len(task.NewEvents)-1].Type)
233-
234-
err = b.CompleteWorkflowTask(ctx, task.ID, task.WorkflowInstance, backend.WorkflowStateActive, task.NewEvents, []history.Event{}, []history.WorkflowEvent{})
235-
require.NoError(t, err)
236-
}
237-
},
238-
},
239213
{
240214
name: "CompleteWorkflowTask_SendsInstanceEvents",
241215
f: func(t *testing.T, ctx context.Context, b backend.Backend) {

backend/test/e2e.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,60 @@ func EndToEndBackendTest(t *testing.T, setup func() backend.Backend, teardown fu
9191
require.ErrorContains(t, err, "converting activity inputs: mismatched argument count: expected 2, got 1")
9292
},
9393
},
94+
{
95+
name: "SubWorkflow_PropagateCancellation",
96+
f: func(t *testing.T, ctx context.Context, c client.Client, w worker.Worker) {
97+
canceled := 0
98+
99+
swf := func(ctx workflow.Context, i int) (int, error) {
100+
err := workflow.Sleep(ctx, time.Second*10)
101+
if err != nil {
102+
if err != workflow.Canceled {
103+
return 0, err
104+
}
105+
}
106+
107+
if ctx.Err() != nil && ctx.Err() == workflow.Canceled {
108+
canceled++
109+
}
110+
111+
return i * 2, nil
112+
}
113+
wf := func(ctx workflow.Context) (int, error) {
114+
swfs := make([]workflow.Future[int], 0)
115+
116+
swfs = append(swfs, workflow.CreateSubWorkflowInstance[int](ctx, workflow.DefaultSubWorkflowOptions, swf, 1))
117+
swfs = append(swfs, workflow.CreateSubWorkflowInstance[int](ctx, workflow.DefaultSubWorkflowOptions, swf, 2))
118+
119+
r := 0
120+
121+
for _, f := range swfs {
122+
sr, err := f.Get(ctx)
123+
if err != nil && err != workflow.Canceled {
124+
return 0, err
125+
}
126+
127+
r = r + sr
128+
}
129+
130+
if ctx.Err() != nil && ctx.Err() == workflow.Canceled {
131+
canceled++
132+
}
133+
134+
return r, nil
135+
}
136+
register(t, ctx, w, []interface{}{wf, swf}, nil)
137+
138+
instance := runWorkflow(t, ctx, c, wf)
139+
require.NoError(t, c.CancelWorkflowInstance(ctx, instance))
140+
141+
r, err := client.GetWorkflowResult[int](ctx, c, instance, time.Second*5)
142+
require.NoError(t, err)
143+
require.Equal(t, 6, r)
144+
145+
require.Equal(t, 3, canceled)
146+
},
147+
},
94148
}
95149

96150
for _, tt := range tests {

0 commit comments

Comments
 (0)