Skip to content

Commit 2cc2eb7

Browse files
committed
Check for more error conditions
1 parent 6d09814 commit 2cc2eb7

File tree

6 files changed

+87
-54
lines changed

6 files changed

+87
-54
lines changed

backend/test/e2e.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,12 @@ func EndToEndBackendTest(t *testing.T, setup func() TestBackend, teardown func(b
8282
}
8383
register(t, ctx, w, []interface{}{wf}, nil)
8484

85-
output, err := runWorkflowWithResult[int](t, ctx, c, wf)
85+
instance, err := c.CreateWorkflowInstance(ctx, client.WorkflowInstanceOptions{
86+
InstanceID: uuid.NewString(),
87+
}, wf)
8688

87-
require.Zero(t, output)
88-
require.ErrorContains(t, err, "converting workflow inputs: mismatched argument count: expected 1, got 0")
89+
require.Nil(t, instance)
90+
require.ErrorContains(t, err, "mismatched argument count: expected 1, got 0")
8991
},
9092
},
9193
{
@@ -123,7 +125,7 @@ func EndToEndBackendTest(t *testing.T, setup func() TestBackend, teardown func(b
123125
output, err := runWorkflowWithResult[int](t, ctx, c, wf)
124126

125127
require.Zero(t, output)
126-
require.ErrorContains(t, err, "converting activity inputs: mismatched argument count: expected 2, got 1")
128+
require.ErrorContains(t, err, "mismatched argument count: expected 2, got 1")
127129
},
128130
},
129131
{

client/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ func New(backend backend.Backend) Client {
5656

5757
func (c *client) CreateWorkflowInstance(ctx context.Context, options WorkflowInstanceOptions, wf workflow.Workflow, args ...interface{}) (*workflow.Instance, error) {
5858
// Check arguments
59-
if !fn.ParamsMatch(wf, 1, args...) {
60-
return nil, errors.New("arguments do not match workflow parameters")
59+
if err := fn.ParamsMatch(wf, 1, args...); err != nil {
60+
return nil, err
6161
}
6262

6363
inputs, err := a.ArgsToInputs(c.backend.Converter(), args...)

internal/fn/fn.go

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package fn
22

33
import (
4+
"errors"
5+
"fmt"
46
"reflect"
57
"runtime"
68
"strings"
@@ -16,28 +18,46 @@ func Name(i interface{}) string {
1618
return strings.TrimSuffix(fnName, "-fm")
1719
}
1820

19-
func ReturnTypeMatch[TResult any](fn interface{}) bool {
21+
func ReturnTypeMatch[TResult any](fn interface{}) error {
2022
fnType := reflect.TypeOf(fn)
2123
if fnType.Kind() != reflect.Func {
22-
return false
24+
return errors.New("not a function")
2325
}
2426

25-
if fnType.NumOut() == 1 {
26-
return true
27+
if fnType.NumOut() < 1 {
28+
return errors.New("function has no return value, must return at least (error) or (result, error)")
2729
}
2830

29-
t := *new(TResult)
30-
return fnType.Out(0) == reflect.TypeOf(t)
31+
if fnType.NumOut() > 2 {
32+
return errors.New("function has too many return values, must return at most (error) or (result, error)")
33+
}
34+
35+
errorPosition := 0
36+
if fnType.NumOut() == 2 {
37+
errorPosition = 1
38+
39+
t := *new(TResult)
40+
if fnType.Out(0) != reflect.TypeOf(t) {
41+
return fmt.Errorf("function must return %s, got %s", reflect.TypeOf(t), fnType.Out(0))
42+
}
43+
}
44+
45+
// Check if return is error
46+
if fnType.Out(errorPosition) != reflect.TypeOf((*error)(nil)).Elem() {
47+
return fmt.Errorf("function must return error, got %s", fnType.Out(errorPosition))
48+
}
49+
50+
return nil
3151
}
3252

33-
func ParamsMatch(fn interface{}, skip int, args ...interface{}) bool {
53+
func ParamsMatch(fn interface{}, skip int, args ...interface{}) error {
3454
fnType := reflect.TypeOf(fn)
3555
if fnType.Kind() != reflect.Func {
36-
return false
56+
return errors.New("not a function")
3757
}
3858

3959
if fnType.NumIn() != skip+len(args) {
40-
return false
60+
return fmt.Errorf("mismatched argument count: expected %d, got %d", fnType.NumIn()-skip, len(args))
4161
}
4262

4363
for i, arg := range args {
@@ -47,9 +67,9 @@ func ParamsMatch(fn interface{}, skip int, args ...interface{}) bool {
4767
}
4868

4969
if fnType.In(skip+i) != reflect.TypeOf(arg) {
50-
return false
70+
return fmt.Errorf("mismatched argument type: expected %s, got %s", fnType.In(skip+i), reflect.TypeOf(arg))
5171
}
5272
}
5373

54-
return true
74+
return nil
5575
}

internal/fn/fn_test.go

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -68,48 +68,54 @@ func errorReturn() error {
6868
func TestReturnTypeMatch(t *testing.T) {
6969
tests := []struct {
7070
name string
71-
fn func() bool
72-
want bool
71+
fn func() error
72+
want string
7373
}{
7474
{
7575
name: "int match",
76-
fn: func() bool {
76+
fn: func() error {
7777
return ReturnTypeMatch[int](intReturn)
7878
},
79-
want: true,
79+
want: "",
8080
},
8181
{
8282
name: "string match",
83-
fn: func() bool {
83+
fn: func() error {
8484
return ReturnTypeMatch[string](stringReturn)
8585
},
86-
want: true,
86+
want: "",
8787
},
8888
{
8989
name: "int mismatch",
90-
fn: func() bool {
90+
fn: func() error {
9191
return ReturnTypeMatch[string](intReturn)
9292
},
93+
want: "function must return string, got int",
9394
},
9495
{
9596
name: "no param",
96-
fn: func() bool {
97+
fn: func() error {
9798
return ReturnTypeMatch[any](errorReturn)
9899
},
99-
want: true,
100+
want: "",
100101
},
101102
{
102103
name: "no param mismatch",
103-
fn: func() bool {
104+
fn: func() error {
104105
return ReturnTypeMatch[int](errorReturn)
105106
},
106-
want: true,
107+
want: "",
107108
},
108109
}
109110
for _, tt := range tests {
110111
t.Run(tt.name, func(t *testing.T) {
111112
got := tt.fn()
112-
require.Equal(t, tt.want, got)
113+
if tt.want == "" {
114+
require.NoError(t, got)
115+
} else {
116+
require.Error(t, got)
117+
require.Equal(t, tt.want, got.Error())
118+
}
113119
})
114120
}
115121
}
@@ -129,63 +135,68 @@ func mixedParams(context.Context, int, string) {
129135
func TestParamsMatch(t *testing.T) {
130136
tests := []struct {
131137
name string
132-
fn func() bool
133-
want bool
138+
fn func() error
139+
want string
134140
}{
135141
{
136142
name: "int match",
137-
fn: func() bool {
143+
fn: func() error {
138144
return ParamsMatch(intParam, 0, 42)
139145
},
140-
want: true,
146+
want: "",
141147
},
142148
{
143149
name: "int mismatch",
144-
fn: func() bool {
150+
fn: func() error {
145151
return ParamsMatch(intParam, 0, "")
146152
},
147-
want: false,
153+
want: "mismatched argument type: expected int, got string",
148154
},
149155
{
150156
name: "string mismatch",
151-
fn: func() bool {
157+
fn: func() error {
152158
return ParamsMatch(stringParam, 0, 42)
153159
},
154-
want: false,
160+
want: "mismatched argument type: expected string, got int",
155161
},
156162
{
157163
name: "interface{} ignored",
158-
fn: func() bool {
164+
fn: func() error {
159165
return ParamsMatch(interfaceParam, 0, "", 23, 42)
160166
},
161-
want: true,
167+
want: "",
162168
},
163169
{
164170
name: "mixed params",
165-
fn: func() bool {
171+
fn: func() error {
166172
return ParamsMatch(mixedParams, 1, 42, "")
167173
},
168-
want: true,
174+
want: "",
169175
},
170176
{
171177
name: "mixed params - no skip",
172-
fn: func() bool {
178+
fn: func() error {
173179
return ParamsMatch(mixedParams, 0, 42, "")
174180
},
175-
want: false,
181+
want: "mismatched argument count: expected 2, got 3",
176182
},
177183
{
178184
name: "mixed params - wrong params",
179-
fn: func() bool {
185+
fn: func() error {
180186
return ParamsMatch(mixedParams, 1, "", 42)
181187
},
182-
want: false,
188+
want: "mismatched argument type: expected int, got string",
183189
},
184190
}
185191
for _, tt := range tests {
186192
t.Run(tt.name, func(t *testing.T) {
187193
got := tt.fn()
188-
require.Equal(t, tt.want, got)
194+
if tt.want == "" {
195+
require.NoError(t, got)
196+
} else {
197+
require.Error(t, got)
198+
require.Equal(t, tt.want, got.Error())
199+
}
189200
})
190201
}
191202
}

workflow/activity.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ func executeActivity[TResult any](ctx Context, options ActivityOptions, attempt
3939
}
4040

4141
// Check return type
42-
if !fn.ReturnTypeMatch[TResult](activity) {
43-
f.Set(*new(TResult), fmt.Errorf("activity return type does not match expected type"))
42+
if err := fn.ReturnTypeMatch[TResult](activity); err != nil {
43+
f.Set(*new(TResult), err)
4444
return f
4545
}
4646

4747
// Check arguments
48-
if !fn.ParamsMatch(activity, 1, args...) {
49-
f.Set(*new(TResult), fmt.Errorf("activity arguments do not match expected types"))
48+
if err := fn.ParamsMatch(activity, 1, args...); err != nil {
49+
f.Set(*new(TResult), err)
5050
return f
5151
}
5252

workflow/subworkflow.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ func createSubWorkflowInstance[TResult any](ctx sync.Context, options SubWorkflo
4949
}
5050

5151
// Check return type
52-
if !fn.ReturnTypeMatch[TResult](wf) {
53-
f.Set(*new(TResult), fmt.Errorf("subworkflow return type does not match expected type"))
52+
if err := fn.ReturnTypeMatch[TResult](wf); err != nil {
53+
f.Set(*new(TResult), err)
5454
return f
5555
}
5656

5757
// Check arguments
58-
if !fn.ParamsMatch(wf, 1, args...) {
59-
f.Set(*new(TResult), fmt.Errorf("subworkflow arguments do not match expected types"))
58+
if err := fn.ParamsMatch(wf, 1, args...); err != nil {
59+
f.Set(*new(TResult), err)
6060
return f
6161
}
6262

0 commit comments

Comments
 (0)