Skip to content

Commit 72d3fae

Browse files
authored
Merge pull request #67 from cschleiden/registry-error-checks
Better error checking when registering workflows and activities
2 parents 4cdf42f + 44c4f3d commit 72d3fae

File tree

3 files changed

+173
-18
lines changed

3 files changed

+173
-18
lines changed

internal/args/args.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func InputsToArgs(c converter.Converter, fn reflect.Value, inputs []payload.Payl
3737
argT := activityFnT.In(i)
3838

3939
// Insert context if requested
40-
if i == 0 && (isOwnContext(argT) || isContext(argT)) {
40+
if i == 0 && (IsOwnContext(argT) || isContext(argT)) {
4141
addContext = true
4242
continue
4343
}
@@ -66,7 +66,7 @@ func InputsToArgs(c converter.Converter, fn reflect.Value, inputs []payload.Payl
6666
return args, addContext, nil
6767
}
6868

69-
func isOwnContext(inType reflect.Type) bool {
69+
func IsOwnContext(inType reflect.Type) bool {
7070
contextElem := reflect.TypeOf((*sync.Context)(nil)).Elem()
7171
return inType != nil && inType.Implements(contextElem)
7272
}

internal/workflow/registry.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"reflect"
66
"sync"
77

8+
"github.com/cschleiden/go-workflows/internal/args"
89
"github.com/cschleiden/go-workflows/internal/fn"
910
)
1011

@@ -25,10 +26,53 @@ func NewRegistry() *Registry {
2526
}
2627
}
2728

29+
type ErrInvalidWorkflow struct {
30+
msg string
31+
}
32+
33+
func (e *ErrInvalidWorkflow) Error() string {
34+
return e.msg
35+
}
36+
37+
type ErrInvalidActivity struct {
38+
msg string
39+
}
40+
41+
func (e *ErrInvalidActivity) Error() string {
42+
return e.msg
43+
}
44+
2845
func (r *Registry) RegisterWorkflow(workflow Workflow) error {
2946
r.Lock()
3047
defer r.Unlock()
3148

49+
wfType := reflect.TypeOf(workflow)
50+
if wfType.Kind() != reflect.Func {
51+
return &ErrInvalidWorkflow{"workflow is not a function"}
52+
}
53+
54+
if wfType.NumIn() == 0 {
55+
return &ErrInvalidWorkflow{"workflow does not accept context parameter"}
56+
}
57+
58+
if !args.IsOwnContext(wfType.In(0)) {
59+
return &ErrInvalidWorkflow{"workflow does not accept context as first parameter"}
60+
}
61+
62+
if wfType.NumOut() == 0 {
63+
return &ErrInvalidWorkflow{"workflow must return error"}
64+
}
65+
66+
if wfType.NumOut() > 2 {
67+
return &ErrInvalidWorkflow{"workflow must return at most two values"}
68+
}
69+
70+
errType := reflect.TypeOf((*error)(nil)).Elem()
71+
if (wfType.NumOut() == 1 && !wfType.Out(0).Implements(errType)) ||
72+
(wfType.NumOut() == 2 && !wfType.Out(1).Implements(errType)) {
73+
return &ErrInvalidWorkflow{"workflow must return error as last return value"}
74+
}
75+
3276
name := fn.Name(workflow)
3377
r.workflowMap[name] = workflow
3478

@@ -47,6 +91,10 @@ func (r *Registry) RegisterActivity(activity interface{}) error {
4791
}
4892

4993
// Activity as function
94+
if err := checkActivity(reflect.TypeOf(activity)); err != nil {
95+
return err
96+
}
97+
5098
name := fn.Name(activity)
5199
r.activityMap[name] = activity
52100

@@ -66,13 +114,34 @@ func (r *Registry) registerActivitiesFromStruct(a interface{}) error {
66114
continue
67115
}
68116

117+
if err := checkActivity(mt.Type); err != nil {
118+
return err
119+
}
120+
69121
name := mt.Name
70122
r.activityMap[name] = mv.Interface()
71123
}
72124

73125
return nil
74126
}
75127

128+
func checkActivity(actType reflect.Type) error {
129+
if actType.Kind() != reflect.Func {
130+
return &ErrInvalidActivity{"activity not a func"}
131+
}
132+
133+
if actType.NumOut() == 0 {
134+
return &ErrInvalidActivity{"activity must return error"}
135+
}
136+
137+
errType := reflect.TypeOf((*error)(nil)).Elem()
138+
if !actType.Out(actType.NumOut() - 1).Implements(errType) {
139+
return &ErrInvalidWorkflow{"activity must return error as last return value"}
140+
}
141+
142+
return nil
143+
}
144+
76145
func (r *Registry) GetWorkflow(name string) (Workflow, error) {
77146
r.Lock()
78147
defer r.Unlock()

internal/workflow/registry_test.go

Lines changed: 102 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,79 @@ func reg_workflow1(ctx sync.Context) error {
1313
return nil
1414
}
1515

16-
func Test_WorkflowRegistration(t *testing.T) {
17-
r := NewRegistry()
18-
require.NotNil(t, r)
19-
20-
err := r.RegisterWorkflow(reg_workflow1)
21-
require.NoError(t, err)
22-
23-
x, err := r.GetWorkflow("reg_workflow1")
24-
require.NoError(t, err)
25-
26-
fn, ok := x.(func(context sync.Context) error)
27-
require.True(t, ok)
28-
require.NotNil(t, fn)
29-
30-
err = fn(sync.Background())
31-
require.NoError(t, err)
16+
func TestRegistry_RegisterWorkflow(t *testing.T) {
17+
type args struct {
18+
workflow Workflow
19+
}
20+
tests := []struct {
21+
name string
22+
args args
23+
wantName string
24+
wantErr bool
25+
}{
26+
{
27+
name: "valid workflow",
28+
args: args{
29+
workflow: reg_workflow1,
30+
},
31+
wantName: "reg_workflow1",
32+
},
33+
{
34+
name: "valid workflow with results",
35+
args: args{
36+
workflow: func(ctx sync.Context) (int, error) { return 42, nil },
37+
},
38+
},
39+
{
40+
name: "valid workflow with multiple parameters",
41+
args: args{
42+
workflow: func(ctx sync.Context, a, b int) (int, error) { return 42, nil },
43+
},
44+
},
45+
{
46+
name: "missing parameter",
47+
args: args{
48+
workflow: func(ctx context.Context) {},
49+
},
50+
wantErr: true,
51+
},
52+
{
53+
name: "missing error result",
54+
args: args{
55+
workflow: func(ctx sync.Context) {},
56+
},
57+
wantErr: true,
58+
},
59+
{
60+
name: "missing error with results",
61+
args: args{
62+
workflow: func(ctx sync.Context) int { return 42 },
63+
},
64+
wantErr: true,
65+
},
66+
{
67+
name: "missing error with results",
68+
args: args{
69+
workflow: func(ctx sync.Context) int { return 42 },
70+
},
71+
wantErr: true,
72+
},
73+
}
74+
for _, tt := range tests {
75+
t.Run(tt.name, func(t *testing.T) {
76+
r := NewRegistry()
77+
if err := r.RegisterWorkflow(tt.args.workflow); (err != nil) != tt.wantErr {
78+
t.Errorf("Registry.RegisterWorkflow() error = %v, wantErr %v", err, tt.wantErr)
79+
t.FailNow()
80+
}
81+
82+
if tt.wantName != "" {
83+
x, err := r.GetWorkflow(tt.wantName)
84+
require.NoError(t, err)
85+
require.NotNil(t, x)
86+
}
87+
})
88+
}
3289
}
3390

3491
func reg_activity(ctx context.Context) error {
@@ -53,6 +110,17 @@ func Test_ActivityRegistration(t *testing.T) {
53110
require.NoError(t, err)
54111
}
55112

113+
func reg_activity_invalid(ctx context.Context) {
114+
}
115+
116+
func Test_ActivityRegistration_Invalid(t *testing.T) {
117+
r := NewRegistry()
118+
require.NotNil(t, r)
119+
120+
err := r.RegisterActivity(reg_activity_invalid)
121+
require.Error(t, err)
122+
}
123+
56124
type reg_activities struct {
57125
SomeValue string
58126
}
@@ -92,3 +160,21 @@ func Test_ActivityRegistrationOnStruct(t *testing.T) {
92160
require.NoError(t, err)
93161
require.Equal(t, "test", v)
94162
}
163+
164+
type reg_invalid_activities struct {
165+
SomeValue string
166+
}
167+
168+
func (r *reg_invalid_activities) Activity1(ctx context.Context) {
169+
}
170+
171+
func Test_ActivityRegistrationOnStruct_Invalid(t *testing.T) {
172+
r := NewRegistry()
173+
require.NotNil(t, r)
174+
175+
a := &reg_invalid_activities{
176+
SomeValue: "test",
177+
}
178+
err := r.RegisterActivity(a)
179+
require.Error(t, err)
180+
}

0 commit comments

Comments
 (0)