Skip to content

Commit 2e989be

Browse files
authored
Merge pull request #165 from cschleiden/cschleiden/check-return-type
Check return types when scheduling activities and sub workflows
2 parents 53042a4 + efe7fb7 commit 2e989be

File tree

6 files changed

+165
-0
lines changed

6 files changed

+165
-0
lines changed

internal/fn/fn.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,17 @@ func Name(i interface{}) string {
1515

1616
return strings.TrimSuffix(fnName, "-fm")
1717
}
18+
19+
func ReturnTypeMatch[TResult any](fn interface{}) bool {
20+
fnType := reflect.TypeOf(fn)
21+
if fnType.Kind() != reflect.Func {
22+
return false
23+
}
24+
25+
if fnType.NumOut() == 1 {
26+
return true
27+
}
28+
29+
t := *new(TResult)
30+
return fnType.Out(0) == reflect.TypeOf(t)
31+
}

internal/fn/fn_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,64 @@ func Test_GetFunctionName(t *testing.T) {
5252
})
5353
}
5454
}
55+
56+
func intReturn() (int, error) {
57+
return 0, nil
58+
}
59+
60+
func stringReturn() (string, error) {
61+
return "", nil
62+
}
63+
64+
func errorReturn() error {
65+
return nil
66+
}
67+
68+
func TestReturnTypeMatch(t *testing.T) {
69+
tests := []struct {
70+
name string
71+
fn func() bool
72+
want bool
73+
}{
74+
{
75+
name: "int match",
76+
fn: func() bool {
77+
return ReturnTypeMatch[int](intReturn)
78+
},
79+
want: true,
80+
},
81+
{
82+
name: "string match",
83+
fn: func() bool {
84+
return ReturnTypeMatch[string](stringReturn)
85+
},
86+
want: true,
87+
},
88+
{
89+
name: "int mismatch",
90+
fn: func() bool {
91+
return ReturnTypeMatch[string](intReturn)
92+
},
93+
},
94+
{
95+
name: "no param",
96+
fn: func() bool {
97+
return ReturnTypeMatch[any](errorReturn)
98+
},
99+
want: true,
100+
},
101+
{
102+
name: "no param mismatch",
103+
fn: func() bool {
104+
return ReturnTypeMatch[int](errorReturn)
105+
},
106+
want: true,
107+
},
108+
}
109+
for _, tt := range tests {
110+
t.Run(tt.name, func(t *testing.T) {
111+
got := tt.fn()
112+
require.Equal(t, tt.want, got)
113+
})
114+
}
115+
}

workflow/activity.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ func executeActivity[TResult any](ctx Context, options ActivityOptions, attempt
3838
return f
3939
}
4040

41+
if !fn.ReturnTypeMatch[TResult](activity) {
42+
f.Set(*new(TResult), fmt.Errorf("activity return type does not match expected type"))
43+
return f
44+
}
45+
4146
cv := converter.GetConverter(ctx)
4247
inputs, err := a.ArgsToInputs(cv, args...)
4348
if err != nil {

workflow/activity_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package workflow
2+
3+
import (
4+
"testing"
5+
6+
"github.com/benbjohnson/clock"
7+
"github.com/cschleiden/go-workflows/internal/converter"
8+
"github.com/cschleiden/go-workflows/internal/core"
9+
"github.com/cschleiden/go-workflows/internal/logger"
10+
"github.com/cschleiden/go-workflows/internal/sync"
11+
"github.com/cschleiden/go-workflows/internal/workflowstate"
12+
"github.com/cschleiden/go-workflows/internal/workflowtracer"
13+
"github.com/stretchr/testify/require"
14+
"go.opentelemetry.io/otel/trace"
15+
)
16+
17+
func Test_executeActivity_ParamMismatch(t *testing.T) {
18+
a := func(ctx Context) (int, error) {
19+
return 42, nil
20+
}
21+
22+
ctx := sync.Background()
23+
ctx = converter.WithConverter(ctx, converter.DefaultConverter)
24+
ctx = workflowstate.WithWorkflowState(
25+
ctx,
26+
workflowstate.NewWorkflowState(core.NewWorkflowInstance("a", ""), logger.NewDefaultLogger(), clock.New()),
27+
)
28+
ctx = workflowtracer.WithWorkflowTracer(ctx, workflowtracer.New(trace.NewNoopTracerProvider().Tracer("test")))
29+
30+
c := sync.NewCoroutine(ctx, func(ctx sync.Context) error {
31+
f := executeActivity[string](ctx, DefaultActivityOptions, 1, a)
32+
_, err := f.Get(ctx)
33+
require.Error(t, err)
34+
35+
return nil
36+
})
37+
38+
c.Execute()
39+
require.True(t, c.Finished())
40+
}

workflow/subworkflow.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ func createSubWorkflowInstance[TResult any](ctx sync.Context, options SubWorkflo
4848
return f
4949
}
5050

51+
if !fn.ReturnTypeMatch[TResult](wf) {
52+
f.Set(*new(TResult), fmt.Errorf("subworkflow return type does not match expected type"))
53+
return f
54+
}
55+
5156
name := fn.Name(wf)
5257

5358
cv := converter.GetConverter(ctx)

workflow/subworkflow_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package workflow
2+
3+
import (
4+
"testing"
5+
6+
"github.com/benbjohnson/clock"
7+
"github.com/cschleiden/go-workflows/internal/converter"
8+
"github.com/cschleiden/go-workflows/internal/core"
9+
"github.com/cschleiden/go-workflows/internal/logger"
10+
"github.com/cschleiden/go-workflows/internal/sync"
11+
"github.com/cschleiden/go-workflows/internal/workflowstate"
12+
"github.com/cschleiden/go-workflows/internal/workflowtracer"
13+
"github.com/stretchr/testify/require"
14+
"go.opentelemetry.io/otel/trace"
15+
)
16+
17+
func Test_createSubWorkflowInstance_ParamMismatch(t *testing.T) {
18+
wf := func(ctx Context) (int, error) {
19+
return 42, nil
20+
}
21+
22+
ctx := sync.Background()
23+
ctx = converter.WithConverter(ctx, converter.DefaultConverter)
24+
ctx = workflowstate.WithWorkflowState(
25+
ctx,
26+
workflowstate.NewWorkflowState(core.NewWorkflowInstance("a", ""), logger.NewDefaultLogger(), clock.New()),
27+
)
28+
ctx = workflowtracer.WithWorkflowTracer(ctx, workflowtracer.New(trace.NewNoopTracerProvider().Tracer("test")))
29+
30+
c := sync.NewCoroutine(ctx, func(ctx sync.Context) error {
31+
f := createSubWorkflowInstance[string](ctx, DefaultSubWorkflowOptions, 1, wf)
32+
_, err := f.Get(ctx)
33+
require.Error(t, err)
34+
35+
return nil
36+
})
37+
38+
c.Execute()
39+
require.True(t, c.Finished())
40+
}

0 commit comments

Comments
 (0)