Skip to content

Commit 678b74a

Browse files
committed
Check params when scheduling workflows or activities
1 parent 2e989be commit 678b74a

File tree

9 files changed

+182
-1
lines changed

9 files changed

+182
-1
lines changed

backend/options.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,14 @@ type Options struct {
2424

2525
StickyTimeout time.Duration
2626

27+
// WorkflowLockTimeout determines how long a workflow task can be locked for. If the workflow task is not completed
28+
// by that timeframe, it's considered abandoned and another worker might pick it up.
29+
//
30+
// For long running workflow tasks, combine this with heartbearts.
2731
WorkflowLockTimeout time.Duration
2832

33+
// ActivityLockTimeout determines how long an activity task can be locked for. If the activity task is not completed
34+
// by that timeframe, it's considered abandoned and another worker might pick it up
2935
ActivityLockTimeout time.Duration
3036
}
3137

client/client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ func New(backend backend.Backend) Client {
5555
}
5656

5757
func (c *client) CreateWorkflowInstance(ctx context.Context, options WorkflowInstanceOptions, wf workflow.Workflow, args ...interface{}) (*workflow.Instance, error) {
58+
// Check arguments
59+
if !fn.ParamsMatch(wf, 1, args...) {
60+
return nil, errors.New("arguments do not match workflow parameters")
61+
}
62+
5863
inputs, err := a.ArgsToInputs(c.backend.Converter(), args...)
5964
if err != nil {
6065
return nil, fmt.Errorf("converting arguments: %w", err)

client/client_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,33 @@ import (
1212
"github.com/cschleiden/go-workflows/internal/core"
1313
"github.com/cschleiden/go-workflows/internal/history"
1414
"github.com/cschleiden/go-workflows/internal/logger"
15+
"github.com/cschleiden/go-workflows/workflow"
1516
"github.com/google/uuid"
1617
"github.com/stretchr/testify/mock"
1718
"github.com/stretchr/testify/require"
1819
)
1920

21+
func Test_Client_CreateWorkflowInstance_ParamMismatch(t *testing.T) {
22+
wf := func(workflow.Context, int) (int, error) {
23+
return 0, nil
24+
}
25+
26+
ctx := context.Background()
27+
28+
b := &backend.MockBackend{}
29+
c := &client{
30+
backend: b,
31+
clock: clock.New(),
32+
}
33+
34+
result, err := c.CreateWorkflowInstance(ctx, WorkflowInstanceOptions{
35+
InstanceID: "id",
36+
}, wf, "foo")
37+
require.Zero(t, result)
38+
require.EqualError(t, err, "arguments do not match workflow parameters")
39+
b.AssertExpectations(t)
40+
}
41+
2042
func Test_Client_GetWorkflowResultTimeout(t *testing.T) {
2143
instance := core.NewWorkflowInstance(uuid.NewString(), "test")
2244

internal/fn/fn.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,22 @@ func ReturnTypeMatch[TResult any](fn interface{}) bool {
2929
t := *new(TResult)
3030
return fnType.Out(0) == reflect.TypeOf(t)
3131
}
32+
33+
func ParamsMatch(fn interface{}, skip int, args ...interface{}) bool {
34+
fnType := reflect.TypeOf(fn)
35+
if fnType.Kind() != reflect.Func {
36+
return false
37+
}
38+
39+
if fnType.NumIn() != skip+len(args) {
40+
return false
41+
}
42+
43+
for i, arg := range args {
44+
if fnType.In(skip+i) != reflect.TypeOf(arg) {
45+
return false
46+
}
47+
}
48+
49+
return true
50+
}

internal/fn/fn_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,69 @@ func TestReturnTypeMatch(t *testing.T) {
113113
})
114114
}
115115
}
116+
117+
func intParam(int) {
118+
}
119+
120+
func stringParam(string) {
121+
}
122+
123+
func mixedParams(context.Context, int, string) {
124+
}
125+
126+
func TestParamsMatch(t *testing.T) {
127+
tests := []struct {
128+
name string
129+
fn func() bool
130+
want bool
131+
}{
132+
{
133+
name: "int match",
134+
fn: func() bool {
135+
return ParamsMatch(intParam, 0, 42)
136+
},
137+
want: true,
138+
},
139+
{
140+
name: "int mismatch",
141+
fn: func() bool {
142+
return ParamsMatch(intParam, 0, "")
143+
},
144+
want: false,
145+
},
146+
{
147+
name: "string mismatch",
148+
fn: func() bool {
149+
return ParamsMatch(stringParam, 0, 42)
150+
},
151+
want: false,
152+
},
153+
{
154+
name: "mixed params",
155+
fn: func() bool {
156+
return ParamsMatch(mixedParams, 1, 42, "")
157+
},
158+
want: true,
159+
},
160+
{
161+
name: "mixed params - no skip",
162+
fn: func() bool {
163+
return ParamsMatch(mixedParams, 0, 42, "")
164+
},
165+
want: false,
166+
},
167+
{
168+
name: "mixed params - wrong params",
169+
fn: func() bool {
170+
return ParamsMatch(mixedParams, 1, "", 42)
171+
},
172+
want: false,
173+
},
174+
}
175+
for _, tt := range tests {
176+
t.Run(tt.name, func(t *testing.T) {
177+
got := tt.fn()
178+
require.Equal(t, tt.want, got)
179+
})
180+
}
181+
}

workflow/activity.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,18 @@ func executeActivity[TResult any](ctx Context, options ActivityOptions, attempt
3838
return f
3939
}
4040

41+
// Check return type
4142
if !fn.ReturnTypeMatch[TResult](activity) {
4243
f.Set(*new(TResult), fmt.Errorf("activity return type does not match expected type"))
4344
return f
4445
}
4546

47+
// Check arguments
48+
if !fn.ParamsMatch(activity, 1, args...) {
49+
f.Set(*new(TResult), fmt.Errorf("activity arguments do not match expected types"))
50+
return f
51+
}
52+
4653
cv := converter.GetConverter(ctx)
4754
inputs, err := a.ArgsToInputs(cv, args...)
4855
if err != nil {

workflow/activity_test.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"go.opentelemetry.io/otel/trace"
1515
)
1616

17-
func Test_executeActivity_ParamMismatch(t *testing.T) {
17+
func Test_executeActivity_ResultMismatch(t *testing.T) {
1818
a := func(ctx Context) (int, error) {
1919
return 42, nil
2020
}
@@ -38,3 +38,27 @@ func Test_executeActivity_ParamMismatch(t *testing.T) {
3838
c.Execute()
3939
require.True(t, c.Finished())
4040
}
41+
func Test_executeActivity_ParamMismatch(t *testing.T) {
42+
a := func(ctx Context, s string, n int) (int, error) {
43+
return 42, nil
44+
}
45+
46+
ctx := sync.Background()
47+
ctx = converter.WithConverter(ctx, converter.DefaultConverter)
48+
ctx = workflowstate.WithWorkflowState(
49+
ctx,
50+
workflowstate.NewWorkflowState(core.NewWorkflowInstance("a", ""), logger.NewDefaultLogger(), clock.New()),
51+
)
52+
ctx = workflowtracer.WithWorkflowTracer(ctx, workflowtracer.New(trace.NewNoopTracerProvider().Tracer("test")))
53+
54+
c := sync.NewCoroutine(ctx, func(ctx sync.Context) error {
55+
f := executeActivity[int](ctx, DefaultActivityOptions, 1, a)
56+
_, err := f.Get(ctx)
57+
require.Error(t, err)
58+
59+
return nil
60+
})
61+
62+
c.Execute()
63+
require.True(t, c.Finished())
64+
}

workflow/subworkflow.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,18 @@ func createSubWorkflowInstance[TResult any](ctx sync.Context, options SubWorkflo
4848
return f
4949
}
5050

51+
// Check return type
5152
if !fn.ReturnTypeMatch[TResult](wf) {
5253
f.Set(*new(TResult), fmt.Errorf("subworkflow return type does not match expected type"))
5354
return f
5455
}
5556

57+
// Check arguments
58+
if !fn.ParamsMatch(wf, 1, args...) {
59+
f.Set(*new(TResult), fmt.Errorf("subworkflow arguments do not match expected types"))
60+
return f
61+
}
62+
5663
name := fn.Name(wf)
5764

5865
cv := converter.GetConverter(ctx)

workflow/subworkflow_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,31 @@ import (
1515
)
1616

1717
func Test_createSubWorkflowInstance_ParamMismatch(t *testing.T) {
18+
wf := func(Context, int) (int, error) {
19+
return 42, nil
20+
}
21+
22+
ctx := sync.Background()
23+
ctx = converter.WithConverter(ctx, converter.DefaultConverter)
24+
ctx = workflowstate.WithWorkflowState(
25+
ctx,
26+
workflowstate.NewWorkflowState(core.NewWorkflowInstance("a", ""), logger.NewDefaultLogger(), clock.New()),
27+
)
28+
ctx = workflowtracer.WithWorkflowTracer(ctx, workflowtracer.New(trace.NewNoopTracerProvider().Tracer("test")))
29+
30+
c := sync.NewCoroutine(ctx, func(ctx sync.Context) error {
31+
f := createSubWorkflowInstance[int](ctx, DefaultSubWorkflowOptions, 1, wf, "foo")
32+
_, err := f.Get(ctx)
33+
require.Error(t, err)
34+
35+
return nil
36+
})
37+
38+
c.Execute()
39+
require.True(t, c.Finished())
40+
}
41+
42+
func Test_createSubWorkflowInstance_ReturnMismatch(t *testing.T) {
1843
wf := func(ctx Context) (int, error) {
1944
return 42, nil
2045
}

0 commit comments

Comments
 (0)