Skip to content

Commit 44c4f3d

Browse files
committed
Error check when registering activities
1 parent f243911 commit 44c4f3d

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

internal/workflow/registry.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ func (e *ErrInvalidWorkflow) Error() string {
3434
return e.msg
3535
}
3636

37+
type ErrInvalidActivity struct {
38+
msg string
39+
}
40+
41+
func (e *ErrInvalidActivity) Error() string {
42+
return e.msg
43+
}
44+
3745
func (r *Registry) RegisterWorkflow(workflow Workflow) error {
3846
r.Lock()
3947
defer r.Unlock()
@@ -83,6 +91,10 @@ func (r *Registry) RegisterActivity(activity interface{}) error {
8391
}
8492

8593
// Activity as function
94+
if err := checkActivity(reflect.TypeOf(activity)); err != nil {
95+
return err
96+
}
97+
8698
name := fn.Name(activity)
8799
r.activityMap[name] = activity
88100

@@ -102,13 +114,34 @@ func (r *Registry) registerActivitiesFromStruct(a interface{}) error {
102114
continue
103115
}
104116

117+
if err := checkActivity(mt.Type); err != nil {
118+
return err
119+
}
120+
105121
name := mt.Name
106122
r.activityMap[name] = mv.Interface()
107123
}
108124

109125
return nil
110126
}
111127

128+
func checkActivity(actType reflect.Type) error {
129+
if actType.Kind() != reflect.Func {
130+
return &ErrInvalidActivity{"activity not a func"}
131+
}
132+
133+
if actType.NumOut() == 0 {
134+
return &ErrInvalidActivity{"activity must return error"}
135+
}
136+
137+
errType := reflect.TypeOf((*error)(nil)).Elem()
138+
if !actType.Out(actType.NumOut() - 1).Implements(errType) {
139+
return &ErrInvalidWorkflow{"activity must return error as last return value"}
140+
}
141+
142+
return nil
143+
}
144+
112145
func (r *Registry) GetWorkflow(name string) (Workflow, error) {
113146
r.Lock()
114147
defer r.Unlock()

internal/workflow/registry_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ func Test_ActivityRegistration(t *testing.T) {
110110
require.NoError(t, err)
111111
}
112112

113+
func reg_activity_invalid(ctx context.Context) {
114+
}
115+
116+
func Test_ActivityRegistration_Invalid(t *testing.T) {
117+
r := NewRegistry()
118+
require.NotNil(t, r)
119+
120+
err := r.RegisterActivity(reg_activity_invalid)
121+
require.Error(t, err)
122+
}
123+
113124
type reg_activities struct {
114125
SomeValue string
115126
}
@@ -149,3 +160,21 @@ func Test_ActivityRegistrationOnStruct(t *testing.T) {
149160
require.NoError(t, err)
150161
require.Equal(t, "test", v)
151162
}
163+
164+
type reg_invalid_activities struct {
165+
SomeValue string
166+
}
167+
168+
func (r *reg_invalid_activities) Activity1(ctx context.Context) {
169+
}
170+
171+
func Test_ActivityRegistrationOnStruct_Invalid(t *testing.T) {
172+
r := NewRegistry()
173+
require.NotNil(t, r)
174+
175+
a := &reg_invalid_activities{
176+
SomeValue: "test",
177+
}
178+
err := r.RegisterActivity(a)
179+
require.Error(t, err)
180+
}

0 commit comments

Comments
 (0)