Skip to content

Commit a9bc485

Browse files
authored
Merge pull request #396 from cschleiden/stop-task-processing-heartbeat-failure
Stop processing a task when the heartbeat fails
2 parents 9793749 + 8e898a0 commit a9bc485

File tree

2 files changed: 69 additions (+69) and 16 deletions (-16).

2 files changed: 69 additions (+69) and 16 deletions (-16).

internal/worker/worker.go

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -178,22 +178,30 @@ func (w *Worker[Task, TaskResult]) dispatcher() {
178178
}
179179

180180
func (w *Worker[Task, TaskResult]) handle(ctx context.Context, t *Task) error {
181+
// Create a cancelable context for this task so we can abort processing on heartbeat failure
182+
taskCtx, cancelTask := context.WithCancel(ctx)
183+
defer cancelTask()
184+
181185
if w.options.HeartbeatInterval > 0 {
182-
// Start heartbeat while processing task
183-
heartbeatCtx, cancelHeartbeat := context.WithCancel(ctx)
184-
defer cancelHeartbeat()
185-
go w.heartbeatTask(heartbeatCtx, t)
186+
// Start heartbeat while processing task.
187+
// If Extend fails we assume we might not own the task anymore and cancel processing.
188+
go w.heartbeatTask(taskCtx, t, cancelTask)
186189
}
187190

188-
result, err := w.tw.Execute(ctx, t)
191+
result, err := w.tw.Execute(taskCtx, t)
189192
if err != nil {
193+
// If execution was canceled (e.g., because heartbeat extend failed), abort without completing.
194+
if errors.Is(err, context.Canceled) {
195+
return err
196+
}
197+
190198
return fmt.Errorf("executing task: %w", err)
191199
}
192200

193201
return w.tw.Complete(ctx, result, t)
194202
}
195203

196-
func (w *Worker[Task, TaskResult]) heartbeatTask(ctx context.Context, task *Task) {
204+
func (w *Worker[Task, TaskResult]) heartbeatTask(ctx context.Context, task *Task, cancel func()) {
197205
t := time.NewTicker(w.options.HeartbeatInterval)
198206
defer t.Stop()
199207

@@ -204,6 +212,13 @@ func (w *Worker[Task, TaskResult]) heartbeatTask(ctx context.Context, task *Task
204212
case <-t.C:
205213
if err := w.tw.Extend(ctx, task); err != nil {
206214
w.logger.ErrorContext(ctx, "could not heartbeat task", "error", err)
215+
216+
// We might not own the task anymore, abort processing
217+
if cancel != nil {
218+
cancel()
219+
}
220+
221+
return
207222
}
208223
}
209224
}

internal/worker/worker_test.go

Lines changed: 48 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -314,8 +314,8 @@ func TestWorker_Handle(t *testing.T) {
314314
task := &testTask{ID: 1, Data: "test"}
315315
result := &testResult{Output: "success"}
316316

317-
mockTaskWorker.On("Execute", ctx, task).Return(result, nil)
318-
mockTaskWorker.On("Complete", ctx, result, task).Return(nil)
317+
mockTaskWorker.On("Execute", mock.Anything, task).Return(result, nil)
318+
mockTaskWorker.On("Complete", mock.Anything, result, task).Return(nil)
319319

320320
err := worker.handle(ctx, task)
321321
assert.NoError(t, err)
@@ -339,8 +339,8 @@ func TestWorker_Handle(t *testing.T) {
339339
task := &testTask{ID: 1, Data: "test"}
340340
result := &testResult{Output: "success"}
341341

342-
mockTaskWorker.On("Execute", ctx, task).Return(result, nil)
343-
mockTaskWorker.On("Complete", ctx, result, task).Return(nil)
342+
mockTaskWorker.On("Execute", mock.Anything, task).Return(result, nil)
343+
mockTaskWorker.On("Complete", mock.Anything, result, task).Return(nil)
344344
// Heartbeat might be called during execution
345345
mockTaskWorker.On("Extend", mock.Anything, task).Return(nil).Maybe()
346346

@@ -365,7 +365,7 @@ func TestWorker_Handle(t *testing.T) {
365365
task := &testTask{ID: 1, Data: "test"}
366366
expectedErr := errors.New("execution error")
367367

368-
mockTaskWorker.On("Execute", ctx, task).Return(nil, expectedErr)
368+
mockTaskWorker.On("Execute", mock.Anything, task).Return(nil, expectedErr)
369369

370370
err := worker.handle(ctx, task)
371371
assert.Error(t, err)
@@ -391,15 +391,48 @@ func TestWorker_Handle(t *testing.T) {
391391
result := &testResult{Output: "success"}
392392
expectedErr := errors.New("completion error")
393393

394-
mockTaskWorker.On("Execute", ctx, task).Return(result, nil)
395-
mockTaskWorker.On("Complete", ctx, result, task).Return(expectedErr)
394+
mockTaskWorker.On("Execute", mock.Anything, task).Return(result, nil)
395+
mockTaskWorker.On("Complete", mock.Anything, result, task).Return(expectedErr)
396396

397397
err := worker.handle(ctx, task)
398398
assert.Error(t, err)
399399
assert.Equal(t, expectedErr, err)
400400

401401
mockTaskWorker.AssertExpectations(t)
402402
})
403+
404+
t.Run("abort processing on heartbeat extend failure", func(t *testing.T) {
405+
mockBackend := createMockBackend()
406+
mockTaskWorker := &mockTaskWorker{}
407+
408+
options := &WorkerOptions{
409+
Pollers: 1,
410+
MaxParallelTasks: 1,
411+
HeartbeatInterval: time.Millisecond * 5,
412+
}
413+
414+
worker := NewWorker(mockBackend, mockTaskWorker, options)
415+
416+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
417+
defer cancel()
418+
419+
task := &testTask{ID: 1, Data: "test"}
420+
421+
// Simulate Extend failing immediately so the heartbeat cancels processing
422+
mockTaskWorker.On("Extend", mock.Anything, task).Return(errors.New("extend failed")).Maybe()
423+
424+
// Execute should see canceled context and return context.Canceled or respect ctx.Done
425+
mockTaskWorker.On("Execute", mock.Anything, task).Return(nil, context.Canceled)
426+
427+
// Complete must NOT be called when execution is aborted due to lost ownership
428+
// No expectation set for Complete to ensure it's not invoked
429+
mockTaskWorker.AssertNotCalled(t, "Complete", mock.Anything, mock.Anything, mock.Anything)
430+
431+
err := worker.handle(ctx, task)
432+
require.Error(t, err)
433+
434+
mockTaskWorker.AssertExpectations(t)
435+
})
403436
}
404437

405438
func TestWorker_HeartbeatTask(t *testing.T) {
@@ -423,7 +456,7 @@ func TestWorker_HeartbeatTask(t *testing.T) {
423456
// Expect multiple heartbeat calls
424457
mockTaskWorker.On("Extend", ctx, task).Return(nil)
425458

426-
worker.heartbeatTask(ctx, task)
459+
worker.heartbeatTask(ctx, task, nil)
427460

428461
// Should have called Extend at least once
429462
mockTaskWorker.AssertExpectations(t)
@@ -450,7 +483,7 @@ func TestWorker_HeartbeatTask(t *testing.T) {
450483
mockTaskWorker.On("Extend", ctx, task).Return(expectedErr)
451484

452485
// Should not panic even with errors
453-
worker.heartbeatTask(ctx, task)
486+
worker.heartbeatTask(ctx, task, nil)
454487

455488
mockTaskWorker.AssertExpectations(t)
456489
})
@@ -475,7 +508,7 @@ func TestWorker_HeartbeatTask(t *testing.T) {
475508

476509
// Should exit quickly without calling Extend
477510
start := time.Now()
478-
worker.heartbeatTask(ctx, task)
511+
worker.heartbeatTask(ctx, task, nil)
479512
duration := time.Since(start)
480513

481514
assert.Less(t, duration, time.Millisecond*100)
@@ -504,6 +537,7 @@ func TestWorker_FullWorkflow(t *testing.T) {
504537
// Track processed tasks
505538
var processedTasks int32
506539
var taskResults []*testResult
540+
var mu sync.Mutex
507541

508542
task1 := &testTask{ID: 1, Data: "task1"}
509543
task2 := &testTask{ID: 2, Data: "task2"}
@@ -520,10 +554,14 @@ func TestWorker_FullWorkflow(t *testing.T) {
520554

521555
mockTaskWorker.On("Execute", mock.Anything, task1).Return(result1, nil).Run(func(args mock.Arguments) {
522556
atomic.AddInt32(&processedTasks, 1)
557+
mu.Lock()
558+
defer mu.Unlock()
523559
taskResults = append(taskResults, result1)
524560
})
525561
mockTaskWorker.On("Execute", mock.Anything, task2).Return(result2, nil).Run(func(args mock.Arguments) {
526562
atomic.AddInt32(&processedTasks, 1)
563+
mu.Lock()
564+
defer mu.Unlock()
527565
taskResults = append(taskResults, result2)
528566
})
529567

0 commit comments

Comments
 (0)