Skip to content

Commit 6eafe96

Browse files
authored
Merge pull request #274 from cschleiden/unify-worker
Unify worker code
2 parents c54d9f0 + 6be2817 commit 6eafe96

File tree

10 files changed

+350
-380
lines changed

10 files changed

+350
-380
lines changed

backend/redis/expire_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func Test_AutoExpiration(t *testing.T) {
2525
b := setup()
2626

2727
c := client.New(b)
28-
w := worker.New(b, &worker.DefaultWorkerOptions)
28+
w := worker.New(b, nil)
2929

3030
ctx, cancel := context.WithCancel(context.Background())
3131

@@ -66,7 +66,7 @@ func Test_AutoExpiration_SubWorkflow(t *testing.T) {
6666
b := setup()
6767

6868
c := client.New(b)
69-
w := worker.New(b, &worker.DefaultWorkerOptions)
69+
w := worker.New(b, nil)
7070

7171
ctx, cancel := context.WithCancel(context.Background())
7272

backend/test/e2e.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ func EndToEndBackendTest(t *testing.T, setup func(options ...backend.BackendOpti
704704
}
705705
}
706706

707-
options := worker.DefaultWorkerOptions
707+
options := worker.DefaultOptions
708708

709709
// Run with cache
710710
run("", options)

bench/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func main() {
4646
)
4747
defer ba.Close()
4848

49-
wo := worker.DefaultWorkerOptions
49+
wo := worker.DefaultOptions
5050
wo.WorkflowExecutorCacheSize = *cacheSize
5151
w := worker.New(ba, &wo)
5252

internal/worker/activity.go

Lines changed: 37 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package worker
22

33
import (
44
"context"
5-
"errors"
65
"log/slog"
7-
"sync"
86
"time"
97

108
"github.com/benbjohnson/clock"
@@ -14,197 +12,82 @@ import (
1412
"github.com/cschleiden/go-workflows/backend/payload"
1513
"github.com/cschleiden/go-workflows/internal/activity"
1614
"github.com/cschleiden/go-workflows/internal/metrickeys"
17-
mi "github.com/cschleiden/go-workflows/internal/metrics"
15+
im "github.com/cschleiden/go-workflows/internal/metrics"
1816
"github.com/cschleiden/go-workflows/internal/workflow"
1917
"github.com/cschleiden/go-workflows/internal/workflowerrors"
2018
)
2119

22-
type ActivityWorker struct {
23-
backend backend.Backend
20+
func NewActivityWorker(b backend.Backend, registry *workflow.Registry, clock clock.Clock, options WorkerOptions) *Worker[backend.ActivityTask, history.Event] {
21+
ae := activity.NewExecutor(b.Logger(), b.Tracer(), b.Converter(), b.ContextPropagators(), registry)
2422

25-
options *Options
26-
27-
activityTaskQueue chan *backend.ActivityTask
28-
activityTaskExecutor *activity.Executor
29-
30-
wg sync.WaitGroup
31-
pollersWg sync.WaitGroup
32-
33-
clock clock.Clock
34-
logger *slog.Logger
35-
}
36-
37-
func NewActivityWorker(b backend.Backend, registry *workflow.Registry, clock clock.Clock, options *Options) *ActivityWorker {
38-
return &ActivityWorker{
39-
backend: b,
40-
41-
options: options,
42-
43-
activityTaskQueue: make(chan *backend.ActivityTask),
44-
activityTaskExecutor: activity.NewExecutor(b.Logger(), b.Tracer(), b.Converter(), b.ContextPropagators(), registry),
45-
46-
clock: clock,
47-
logger: b.Logger(),
48-
}
49-
}
50-
51-
func (aw *ActivityWorker) Start(ctx context.Context) error {
52-
aw.pollersWg.Add(aw.options.ActivityPollers)
53-
54-
for i := 0; i < aw.options.ActivityPollers; i++ {
55-
go aw.runPoll(ctx)
23+
tw := &ActivityTaskWorker{
24+
backend: b,
25+
activityTaskExecutor: ae,
26+
clock: clock,
27+
logger: b.Logger(),
5628
}
5729

58-
go aw.runDispatcher()
59-
60-
return nil
30+
return NewWorker[backend.ActivityTask, history.Event](b, tw, &options)
6131
}
6232

63-
func (aw *ActivityWorker) WaitForCompletion() error {
64-
// Wait for task pollers to finish
65-
aw.pollersWg.Wait()
66-
67-
// Wait for tasks to finish
68-
aw.wg.Wait()
69-
close(aw.activityTaskQueue)
70-
71-
return nil
72-
}
73-
74-
func (aw *ActivityWorker) runPoll(ctx context.Context) {
75-
defer aw.pollersWg.Done()
76-
77-
ticker := time.NewTicker(aw.options.ActivityPollingInterval)
78-
defer ticker.Stop()
79-
for {
80-
task, err := aw.poll(ctx, 30*time.Second)
81-
if err != nil {
82-
aw.logger.ErrorContext(ctx, "error while polling for activity task", "error", err)
83-
}
84-
if task != nil {
85-
aw.wg.Add(1)
86-
aw.activityTaskQueue <- task
87-
continue // check for new tasks right away
88-
}
89-
90-
select {
91-
case <-ctx.Done():
92-
return
93-
case <-ticker.C:
94-
}
95-
}
33+
type ActivityTaskWorker struct {
34+
backend backend.Backend
35+
activityTaskExecutor *activity.Executor
36+
clock clock.Clock
37+
logger *slog.Logger
9638
}
9739

98-
func (aw *ActivityWorker) runDispatcher() {
99-
var sem chan struct{}
100-
if aw.options.MaxParallelActivityTasks > 0 {
101-
sem = make(chan struct{}, aw.options.MaxParallelActivityTasks)
40+
func (atw *ActivityTaskWorker) Complete(ctx context.Context, event *history.Event, task *backend.ActivityTask) error {
41+
if err := atw.backend.CompleteActivityTask(ctx, task.WorkflowInstance, task.ID, event); err != nil {
42+
atw.backend.Logger().Error("completing activity task", "error", err)
10243
}
10344

104-
for task := range aw.activityTaskQueue {
105-
if sem != nil {
106-
sem <- struct{}{}
107-
}
108-
109-
task := task
110-
111-
go func() {
112-
defer aw.wg.Done()
113-
114-
// Create new context to allow activities to complete when root context is canceled
115-
taskCtx := context.Background()
116-
aw.handleTask(taskCtx, task)
117-
118-
if sem != nil {
119-
<-sem
120-
}
121-
}()
122-
}
45+
return nil
12346
}
12447

125-
func (aw *ActivityWorker) handleTask(ctx context.Context, task *backend.ActivityTask) {
48+
func (atw *ActivityTaskWorker) Execute(ctx context.Context, task *backend.ActivityTask) (*history.Event, error) {
12649
a := task.Event.Attributes.(*history.ActivityScheduledAttributes)
127-
ametrics := aw.backend.Metrics().WithTags(metrics.Tags{metrickeys.ActivityName: a.Name})
50+
ametrics := atw.backend.Metrics().WithTags(metrics.Tags{metrickeys.ActivityName: a.Name})
12851

12952
// Record how long this task was in the queue
13053
scheduledAt := task.Event.Timestamp
13154
timeInQueue := time.Since(scheduledAt)
13255
ametrics.Distribution(metrickeys.ActivityTaskDelay, metrics.Tags{}, float64(timeInQueue/time.Millisecond))
13356

134-
// Start heartbeat while activity is running
135-
if aw.options.ActivityHeartbeatInterval > 0 {
136-
heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
137-
defer cancelHeartbeat()
138-
139-
go func(ctx context.Context) {
140-
t := time.NewTicker(aw.options.ActivityHeartbeatInterval)
141-
defer t.Stop()
142-
143-
for {
144-
select {
145-
case <-ctx.Done():
146-
return
147-
case <-t.C:
148-
if err := aw.backend.ExtendActivityTask(ctx, task.ID); err != nil {
149-
if !errors.Is(err, context.Canceled) {
150-
aw.backend.Logger().Error("extending activity task", "error", err)
151-
panic("extending activity task")
152-
}
153-
}
154-
}
155-
}
156-
}(heartbeatCtx)
157-
}
158-
159-
timer := mi.NewTimer(ametrics, metrickeys.ActivityTaskProcessed, metrics.Tags{})
57+
timer := im.NewTimer(ametrics, metrickeys.ActivityTaskProcessed, metrics.Tags{})
16058
defer timer.Stop()
16159

162-
result, err := aw.activityTaskExecutor.ExecuteActivity(ctx, task)
163-
event := aw.resultToEvent(task.Event.ScheduleEventID, result, err)
60+
result, err := atw.activityTaskExecutor.ExecuteActivity(ctx, task)
61+
event := atw.resultToEvent(task.Event.ScheduleEventID, result, err)
16462

165-
if err := aw.backend.CompleteActivityTask(ctx, task.WorkflowInstance, task.ID, event); err != nil {
166-
aw.backend.Logger().Error("completing activity task", "error", err)
167-
panic("completing activity task")
168-
}
63+
return event, nil
64+
}
65+
66+
func (atw *ActivityTaskWorker) Extend(ctx context.Context, task *backend.ActivityTask) error {
67+
return atw.backend.ExtendActivityTask(ctx, task.ID)
16968
}
17069

171-
func (aw *ActivityWorker) resultToEvent(ScheduleEventID int64, result payload.Payload, err error) *history.Event {
70+
func (atw *ActivityTaskWorker) Get(ctx context.Context) (*backend.ActivityTask, error) {
71+
return atw.backend.GetActivityTask(ctx)
72+
}
73+
74+
func (atw *ActivityTaskWorker) resultToEvent(scheduleEventID int64, result payload.Payload, err error) *history.Event {
17275
if err != nil {
17376
return history.NewPendingEvent(
174-
aw.clock.Now(),
77+
atw.clock.Now(),
17578
history.EventType_ActivityFailed,
17679
&history.ActivityFailedAttributes{
17780
Error: workflowerrors.FromError(err),
17881
},
179-
history.ScheduleEventID(ScheduleEventID),
82+
history.ScheduleEventID(scheduleEventID),
18083
)
18184
}
18285

18386
return history.NewPendingEvent(
184-
aw.clock.Now(),
87+
atw.clock.Now(),
18588
history.EventType_ActivityCompleted,
18689
&history.ActivityCompletedAttributes{
18790
Result: result,
18891
},
189-
history.ScheduleEventID(ScheduleEventID))
190-
}
191-
192-
func (aw *ActivityWorker) poll(ctx context.Context, timeout time.Duration) (*backend.ActivityTask, error) {
193-
if timeout == 0 {
194-
timeout = 30 * time.Second
195-
}
196-
197-
ctx, cancel := context.WithTimeout(ctx, timeout)
198-
defer cancel()
199-
200-
task, err := aw.backend.GetActivityTask(ctx)
201-
if err != nil {
202-
if errors.Is(err, context.DeadlineExceeded) {
203-
return nil, nil
204-
}
205-
206-
return nil, err
207-
}
208-
209-
return task, nil
92+
history.ScheduleEventID(scheduleEventID))
21093
}

0 commit comments

Comments
 (0)