Skip to content

Commit c528042

Browse files
authored
Merge pull request #168 from cschleiden/cschleiden/check-for-context
Move param check to args package
2 parents 4151dd5 + babf53e commit c528042

File tree

7 files changed

+228
-211
lines changed

7 files changed

+228
-211
lines changed

client/client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func New(backend backend.Backend) Client {
5656

5757
func (c *client) CreateWorkflowInstance(ctx context.Context, options WorkflowInstanceOptions, wf workflow.Workflow, args ...interface{}) (*workflow.Instance, error) {
5858
// Check arguments
59-
if err := fn.ParamsMatch(wf, 1, args...); err != nil {
59+
if err := a.ParamsMatch(wf, args...); err != nil {
6060
return nil, err
6161
}
6262

internal/args/args.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package args
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"reflect"
78

@@ -66,6 +67,80 @@ func InputsToArgs(c converter.Converter, fn reflect.Value, inputs []payload.Payl
6667
return args, addContext, nil
6768
}
6869

70+
func ReturnTypeMatch[TResult any](fn interface{}) error {
71+
fnType := reflect.TypeOf(fn)
72+
if fnType.Kind() != reflect.Func {
73+
return errors.New("not a function")
74+
}
75+
76+
if fnType.NumOut() < 1 {
77+
return errors.New("function has no return value, must return at least (error) or (result, error)")
78+
}
79+
80+
if fnType.NumOut() > 2 {
81+
return errors.New("function has too many return values, must return at most (error) or (result, error)")
82+
}
83+
84+
errorPosition := 0
85+
if fnType.NumOut() == 2 {
86+
errorPosition = 1
87+
88+
t := *new(TResult)
89+
if fnType.Out(0) != reflect.TypeOf(t) {
90+
return fmt.Errorf("function must return %s, got %s", reflect.TypeOf(t), fnType.Out(0))
91+
}
92+
}
93+
94+
// Check if return is error
95+
if fnType.Out(errorPosition) != reflect.TypeOf((*error)(nil)).Elem() {
96+
return fmt.Errorf("function must return error, got %s", fnType.Out(errorPosition))
97+
}
98+
99+
return nil
100+
}
101+
102+
func ParamsMatch(fn interface{}, args ...interface{}) error {
103+
fnType := reflect.TypeOf(fn)
104+
if fnType.Kind() != reflect.Func {
105+
return errors.New("not a function")
106+
}
107+
108+
requiredArguments := fnType.NumIn()
109+
needsContext := false
110+
if fnType.NumIn() > 0 {
111+
argT := fnType.In(0)
112+
113+
if IsOwnContext(argT) || isContext(argT) {
114+
needsContext = true
115+
requiredArguments--
116+
}
117+
}
118+
119+
if requiredArguments != len(args) {
120+
return fmt.Errorf("mismatched argument count: expected %d, got %d", requiredArguments, len(args))
121+
}
122+
123+
targetIdx := 0
124+
if needsContext {
125+
targetIdx = 1
126+
}
127+
128+
for _, arg := range args {
129+
// if target is interface{} skip
130+
if fnType.In(targetIdx).Kind() == reflect.Interface {
131+
continue
132+
}
133+
134+
if fnType.In(targetIdx) != reflect.TypeOf(arg) {
135+
return fmt.Errorf("mismatched argument type: expected %s, got %s", fnType.In(targetIdx), reflect.TypeOf(arg))
136+
}
137+
138+
targetIdx++
139+
}
140+
141+
return nil
142+
}
143+
69144
func IsOwnContext(inType reflect.Type) bool {
70145
contextElem := reflect.TypeOf((*sync.Context)(nil)).Elem()
71146
return inType != nil && inType.Implements(contextElem)

internal/args/args_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,151 @@ func TestInputsToArgs(t *testing.T) {
8989
})
9090
}
9191
}
92+
93+
func intReturn() (int, error) {
94+
return 0, nil
95+
}
96+
97+
func stringReturn() (string, error) {
98+
return "", nil
99+
}
100+
101+
func errorReturn() error {
102+
return nil
103+
}
104+
105+
func TestReturnTypeMatch(t *testing.T) {
106+
tests := []struct {
107+
name string
108+
fn func() error
109+
want string
110+
}{
111+
{
112+
name: "int match",
113+
fn: func() error {
114+
return ReturnTypeMatch[int](intReturn)
115+
},
116+
want: "",
117+
},
118+
{
119+
name: "string match",
120+
fn: func() error {
121+
return ReturnTypeMatch[string](stringReturn)
122+
},
123+
want: "",
124+
},
125+
{
126+
name: "int mismatch",
127+
fn: func() error {
128+
return ReturnTypeMatch[string](intReturn)
129+
},
130+
want: "function must return string, got int",
131+
},
132+
{
133+
name: "no param",
134+
fn: func() error {
135+
return ReturnTypeMatch[any](errorReturn)
136+
},
137+
want: "",
138+
},
139+
{
140+
name: "no param mismatch",
141+
fn: func() error {
142+
return ReturnTypeMatch[int](errorReturn)
143+
},
144+
want: "",
145+
},
146+
}
147+
for _, tt := range tests {
148+
t.Run(tt.name, func(t *testing.T) {
149+
got := tt.fn()
150+
if tt.want == "" {
151+
require.NoError(t, got)
152+
} else {
153+
require.Error(t, got)
154+
require.Equal(t, tt.want, got.Error())
155+
}
156+
})
157+
}
158+
}
159+
160+
func intParam(int) {
161+
}
162+
163+
func stringParam(string) {
164+
}
165+
166+
func interfaceParam(string, interface{}, int) {
167+
}
168+
169+
func mixedParams(context.Context, int, string) {
170+
}
171+
172+
func TestParamsMatch(t *testing.T) {
173+
tests := []struct {
174+
name string
175+
fn func() error
176+
want string
177+
}{
178+
{
179+
name: "int match",
180+
fn: func() error {
181+
return ParamsMatch(intParam, 42)
182+
},
183+
want: "",
184+
},
185+
{
186+
name: "int mismatch",
187+
fn: func() error {
188+
return ParamsMatch(intParam, "")
189+
},
190+
want: "mismatched argument type: expected int, got string",
191+
},
192+
{
193+
name: "string mismatch",
194+
fn: func() error {
195+
return ParamsMatch(stringParam, 42)
196+
},
197+
want: "mismatched argument type: expected string, got int",
198+
},
199+
{
200+
name: "interface{} ignored",
201+
fn: func() error {
202+
return ParamsMatch(interfaceParam, "", 23, 42)
203+
},
204+
want: "",
205+
},
206+
{
207+
name: "mixed params",
208+
fn: func() error {
209+
return ParamsMatch(mixedParams, 42, "")
210+
},
211+
want: "",
212+
},
213+
{
214+
name: "context",
215+
fn: func() error {
216+
return ParamsMatch(mixedParams, 42, "", 23)
217+
},
218+
want: "mismatched argument count: expected 2, got 3",
219+
},
220+
{
221+
name: "mixed params - wrong params",
222+
fn: func() error {
223+
return ParamsMatch(mixedParams, "", 42)
224+
},
225+
want: "mismatched argument type: expected int, got string",
226+
},
227+
}
228+
for _, tt := range tests {
229+
t.Run(tt.name, func(t *testing.T) {
230+
got := tt.fn()
231+
if tt.want == "" {
232+
require.NoError(t, got)
233+
} else {
234+
require.Error(t, got)
235+
require.Equal(t, tt.want, got.Error())
236+
}
237+
})
238+
}
239+
}

internal/fn/fn.go

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
package fn
22

33
import (
4-
"errors"
5-
"fmt"
64
"reflect"
75
"runtime"
86
"strings"
@@ -17,59 +15,3 @@ func Name(i interface{}) string {
1715

1816
return strings.TrimSuffix(fnName, "-fm")
1917
}
20-
21-
func ReturnTypeMatch[TResult any](fn interface{}) error {
22-
fnType := reflect.TypeOf(fn)
23-
if fnType.Kind() != reflect.Func {
24-
return errors.New("not a function")
25-
}
26-
27-
if fnType.NumOut() < 1 {
28-
return errors.New("function has no return value, must return at least (error) or (result, error)")
29-
}
30-
31-
if fnType.NumOut() > 2 {
32-
return errors.New("function has too many return values, must return at most (error) or (result, error)")
33-
}
34-
35-
errorPosition := 0
36-
if fnType.NumOut() == 2 {
37-
errorPosition = 1
38-
39-
t := *new(TResult)
40-
if fnType.Out(0) != reflect.TypeOf(t) {
41-
return fmt.Errorf("function must return %s, got %s", reflect.TypeOf(t), fnType.Out(0))
42-
}
43-
}
44-
45-
// Check if return is error
46-
if fnType.Out(errorPosition) != reflect.TypeOf((*error)(nil)).Elem() {
47-
return fmt.Errorf("function must return error, got %s", fnType.Out(errorPosition))
48-
}
49-
50-
return nil
51-
}
52-
53-
func ParamsMatch(fn interface{}, skip int, args ...interface{}) error {
54-
fnType := reflect.TypeOf(fn)
55-
if fnType.Kind() != reflect.Func {
56-
return errors.New("not a function")
57-
}
58-
59-
if fnType.NumIn() != skip+len(args) {
60-
return fmt.Errorf("mismatched argument count: expected %d, got %d", fnType.NumIn()-skip, len(args))
61-
}
62-
63-
for i, arg := range args {
64-
// if target is interface{} skip
65-
if fnType.In(skip+i).Kind() == reflect.Interface {
66-
continue
67-
}
68-
69-
if fnType.In(skip+i) != reflect.TypeOf(arg) {
70-
return fmt.Errorf("mismatched argument type: expected %s, got %s", fnType.In(skip+i), reflect.TypeOf(arg))
71-
}
72-
}
73-
74-
return nil
75-
}

0 commit comments

Comments
 (0)