Skip to content

Commit eff6532

Browse files
authored
feat(worker): Added middleware support for worker registration (#380)
1 parent d5a7e12 commit eff6532

File tree

8 files changed

+245
-1
lines changed

8 files changed

+245
-1
lines changed

worker/README.md

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* [Documentation](#documentation)
1414
* [Workers](#workers)
1515
* [WorkerPool](#workerpool)
16+
* [Middlewares](#middlewares)
1617
* [Logging](#logging)
1718
* [Tracing](#tracing)
1819
* [Metrics](#metrics)
@@ -156,6 +157,115 @@ func main() {
156157
}
157158
```
158159

160+
### Middlewares
161+
162+
This module provides middleware support for workers, allowing you to add behaviors without modifying the worker's core implementation.
163+
164+
Middlewares wrap a worker's `Run` method and can perform actions before and after the worker execution, or even modify the execution flow.
165+
166+
#### Using Middlewares
167+
168+
Here's an example of a middleware that returns an error if a worker takes too long to execute:
169+
170+
```go
171+
package main
172+
173+
import (
174+
"context"
175+
"errors"
176+
"time"
177+
178+
"github.com/ankorstore/yokai/worker"
179+
)
180+
181+
// SimpleWorker is a basic worker that sleeps for a specified duration
182+
type SimpleWorker struct {
183+
sleepDuration time.Duration
184+
}
185+
186+
// NewSimpleWorker creates a new SimpleWorker
187+
func NewSimpleWorker(sleepDuration time.Duration) *SimpleWorker {
188+
return &SimpleWorker{
189+
sleepDuration: sleepDuration,
190+
}
191+
}
192+
193+
// Name returns the worker name
194+
func (w *SimpleWorker) Name() string {
195+
return "simple-worker"
196+
}
197+
198+
// Run executes the worker
199+
func (w *SimpleWorker) Run(ctx context.Context) error {
200+
// Simulate work by sleeping
201+
time.Sleep(w.sleepDuration)
202+
203+
return nil
204+
}
205+
206+
// TimeoutMiddleware implements the worker.Middleware interface
207+
type TimeoutMiddleware struct {
208+
timeout time.Duration
209+
}
210+
211+
// NewTimeoutMiddleware creates a new TimeoutMiddleware with the specified timeout
212+
func NewTimeoutMiddleware(timeout time.Duration) *TimeoutMiddleware {
213+
return &TimeoutMiddleware{
214+
timeout: timeout,
215+
}
216+
}
217+
218+
// Name returns the middleware name
219+
func (m *TimeoutMiddleware) Name() string {
220+
return "timeout-middleware"
221+
}
222+
223+
// Handle returns the middleware function
224+
func (m *TimeoutMiddleware) Handle() worker.MiddlewareFunc {
225+
return func(next worker.HandlerFunc) worker.HandlerFunc {
226+
return func(ctx context.Context) error {
227+
// Create a timeout context
228+
timeoutCtx, cancel := context.WithTimeout(ctx, m.timeout)
229+
defer cancel()
230+
231+
// Create a channel to receive the result of the worker execution
232+
done := make(chan error)
233+
234+
// Execute the worker in a goroutine
235+
go func() {
236+
done <- next(timeoutCtx)
237+
}()
238+
239+
// Wait for either the worker to complete or the timeout to occur
240+
select {
241+
case err := <-done:
242+
return err
243+
case <-timeoutCtx.Done():
244+
if errors.Is(timeoutCtx.Err(), context.DeadlineExceeded) {
245+
return errors.New("worker execution timed out")
246+
}
247+
return timeoutCtx.Err()
248+
}
249+
}
250+
}
251+
}
252+
253+
func main() {
254+
// Create a worker pool with a worker and the timeout middleware
255+
pool, _ := worker.NewDefaultWorkerPoolFactory().Create(
256+
worker.WithWorker(
257+
NewSimpleWorker(10 * time.Second), // Worker that takes 10 seconds to complete
258+
worker.WithMiddlewares(NewTimeoutMiddleware(5 * time.Second)), // Middleware that times out after 5 seconds
259+
),
260+
)
261+
262+
// Start the pool
263+
pool.Start(context.Background())
264+
265+
// The worker will be interrupted after 5 seconds with a timeout error
266+
}
267+
```
268+
159269
### Logging
160270

161271
You can use the [CtxLogger()](context.go) function to retrieve the

worker/execution.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ type WorkerExecution struct {
1515
maxExecutionsAttempts int
1616
deferredStartThreshold float64
1717
events []*WorkerExecutionEvent
18+
middlewares []Middleware
1819
}
1920

2021
// NewWorkerExecution returns a new [WorkerExecution].
@@ -27,6 +28,7 @@ func NewWorkerExecution(id string, name string, options ExecutionOptions) *Worke
2728
maxExecutionsAttempts: options.MaxExecutionsAttempts,
2829
deferredStartThreshold: options.DeferredStartThreshold,
2930
events: []*WorkerExecutionEvent{},
31+
middlewares: options.Middlewares,
3032
}
3133
}
3234

@@ -157,3 +159,10 @@ func (e *WorkerExecution) HasEvent(message string) bool {
157159

158160
return false
159161
}
162+
163+
func (e *WorkerExecution) Middlewares() []Middleware {
164+
e.mutex.Lock()
165+
defer e.mutex.Unlock()
166+
167+
return e.middlewares
168+
}

worker/middleware.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package worker
2+
3+
import "context"
4+
5+
// Middleware is the interface to implement to provide worker middlewares.
6+
type Middleware interface {
7+
Name() string
8+
Handle() MiddlewareFunc
9+
}
10+
11+
// MiddlewareFunc wraps handlers in the middleware chain to perform operations before and after execution.
12+
type MiddlewareFunc func(next HandlerFunc) HandlerFunc
13+
14+
// HandlerFunc executes a worker's Run method or another middleware in the chain with context information.
15+
type HandlerFunc func(ctx context.Context) error

worker/middleware_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package worker_test
2+
3+
import "github.com/ankorstore/yokai/worker"
4+
5+
type TestMiddleware struct {
6+
Func worker.MiddlewareFunc
7+
}
8+
9+
func (m *TestMiddleware) Name() string {
10+
return "TestMiddleware"
11+
}
12+
13+
func (m *TestMiddleware) Handle() worker.MiddlewareFunc {
14+
return m.Func
15+
}

worker/option.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,15 @@ func WithWorker(worker Worker, options ...WorkerExecutionOption) WorkerPoolOptio
7171
type ExecutionOptions struct {
7272
DeferredStartThreshold float64
7373
MaxExecutionsAttempts int
74+
Middlewares []Middleware
7475
}
7576

7677
// DefaultWorkerExecutionOptions are the default options for the [Worker] executions.
7778
func DefaultWorkerExecutionOptions() ExecutionOptions {
7879
return ExecutionOptions{
7980
DeferredStartThreshold: DefaultDeferredStartThreshold,
8081
MaxExecutionsAttempts: DefaultMaxExecutionsAttempts,
82+
Middlewares: []Middleware{},
8183
}
8284
}
8385

@@ -97,3 +99,13 @@ func WithMaxExecutionsAttempts(l int) WorkerExecutionOption {
9799
o.MaxExecutionsAttempts = l
98100
}
99101
}
102+
103+
// WithMiddlewares is used to add middlewares to a worker registration.
104+
func WithMiddlewares(middlewares ...Middleware) WorkerExecutionOption {
105+
return func(o *ExecutionOptions) {
106+
if o.Middlewares == nil {
107+
o.Middlewares = []Middleware{}
108+
}
109+
o.Middlewares = append(o.Middlewares, middlewares...)
110+
}
111+
}

worker/option_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,15 @@ func TestWorkerExecutionOptionsWithMaxExecutionsAttempts(t *testing.T) {
7979

8080
assert.Equal(t, 2, opt.MaxExecutionsAttempts)
8181
}
82+
83+
func TestWorkerExecutionOptionsWithMiddlewares(t *testing.T) {
84+
t.Parallel()
85+
86+
opt := worker.DefaultWorkerExecutionOptions()
87+
88+
assert.Len(t, opt.Middlewares, 0)
89+
90+
worker.WithMiddlewares(&TestMiddleware{}, &TestMiddleware{})(&opt)
91+
92+
assert.Len(t, opt.Middlewares, 2)
93+
}

worker/pool.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,14 @@ func (p *WorkerPool) startWorkerRegistration(ctx context.Context, registration *
214214

215215
p.metrics.IncrementWorkerExecutionStart(registration.Worker().Name())
216216

217-
if err := registration.Worker().Run(ctx); err != nil {
217+
runFunc := registration.Worker().Run
218+
219+
for i := len(execution.Middlewares()) - 1; i >= 0; i-- {
220+
middleware := execution.Middlewares()[i]
221+
runFunc = middleware.Handle()(runFunc)
222+
}
223+
224+
if err := runFunc(ctx); err != nil {
218225
message = fmt.Sprintf(
219226
"stopping execution attempt %d/%d with error: %v",
220227
workerExecution.CurrentExecutionAttempt(),

worker/pool_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,3 +544,67 @@ func TestExecutionWithRestartingPanicWorker(t *testing.T) {
544544
)
545545
assert.NoError(t, err)
546546
}
547+
548+
func TestExecutionWithMiddlewares(t *testing.T) {
549+
t.Parallel()
550+
551+
// Create a slice to track the execution order
552+
executionOrder := []string{}
553+
554+
// Create a worker
555+
testWorker := workers.NewClassicWorker()
556+
557+
// Create the first middleware
558+
middleware1 := &TestMiddleware{
559+
Func: func(next worker.HandlerFunc) worker.HandlerFunc {
560+
return func(ctx context.Context) error {
561+
executionOrder = append(executionOrder, "middleware1 before")
562+
err := next(ctx)
563+
executionOrder = append(executionOrder, "middleware1 after")
564+
565+
return err
566+
}
567+
},
568+
}
569+
570+
// Create the second middleware
571+
middleware2 := &TestMiddleware{
572+
Func: func(next worker.HandlerFunc) worker.HandlerFunc {
573+
return func(ctx context.Context) error {
574+
executionOrder = append(executionOrder, "middleware2 before")
575+
err := next(ctx)
576+
executionOrder = append(executionOrder, "middleware2 after")
577+
578+
return err
579+
}
580+
},
581+
}
582+
583+
// Create a worker pool with the worker and middlewares
584+
pool, err := worker.NewDefaultWorkerPoolFactory().Create(
585+
worker.WithWorker(
586+
testWorker,
587+
worker.WithMiddlewares(middleware1, middleware2),
588+
),
589+
)
590+
assert.NoError(t, err)
591+
592+
// Start the worker pool
593+
err = pool.Start(context.Background())
594+
assert.NoError(t, err)
595+
596+
// Wait for the worker to complete
597+
time.Sleep(10 * time.Millisecond)
598+
599+
// Stop the worker pool
600+
err = pool.Stop()
601+
assert.NoError(t, err)
602+
603+
// Verify the execution order
604+
assert.Equal(t, []string{
605+
"middleware1 before",
606+
"middleware2 before",
607+
"middleware2 after",
608+
"middleware1 after",
609+
}, executionOrder)
610+
}

0 commit comments

Comments
 (0)