Skip to content

Commit 4f8cd8b

Browse files
fix: Fix kafka consumer shutdown (#3907)
Signed-off-by: Javier Aliaga <[email protected]> Co-authored-by: Yaron Schneider <[email protected]>
1 parent eae3312 commit 4f8cd8b

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

common/component/kafka/subscriber.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import (
1717
"context"
1818
"errors"
1919
"time"
20+
21+
"github.com/dapr/components-contrib/pubsub"
2022
)
2123

2224
// Subscribe adds a handler and configuration for a topic, and subscribes.
@@ -35,9 +37,26 @@ func (k *Kafka) Subscribe(ctx context.Context, handlerConfig SubscriptionHandler
3537
k.wg.Add(1)
3638
go func() {
3739
defer k.wg.Done()
40+
postAction := func() {}
41+
3842
select {
3943
case <-ctx.Done():
44+
err := context.Cause(ctx)
45+
if errors.Is(err, pubsub.ErrGracefulShutdown) {
46+
k.logger.Debugf("Kafka component is closing. Context is done due to shutdown process.")
47+
postAction = func() {
48+
if k.clients != nil && k.clients.consumerGroup != nil {
49+
k.logger.Debugf("Kafka component is closing. Closing consumer group.")
50+
err := k.clients.consumerGroup.Close()
51+
if err != nil {
52+
k.logger.Errorf("failed to close consumer group: %w", err)
53+
}
54+
}
55+
}
56+
}
57+
4058
case <-k.closeCh:
59+
k.logger.Debugf("Kafka component is closing. Channel is closed.")
4160
}
4261

4362
k.subscribeLock.Lock()
@@ -50,6 +69,7 @@ func (k *Kafka) Subscribe(ctx context.Context, handlerConfig SubscriptionHandler
5069
}
5170

5271
k.reloadConsumerGroup()
72+
postAction()
5373
}()
5474
}
5575

@@ -87,9 +107,11 @@ func (k *Kafka) consume(ctx context.Context, topics []string, consumer *consumer
87107
clients, err := k.latestClients()
88108
if err != nil || clients == nil {
89109
k.logger.Errorf("failed to get latest Kafka clients: %w", err)
110+
return
90111
}
91112
if clients.consumerGroup == nil {
92113
k.logger.Errorf("component is closed")
114+
return
93115
}
94116
err = clients.consumerGroup.Consume(ctx, topics, consumer)
95117
if errors.Is(err, context.Canceled) {

common/component/kafka/subscriber_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package kafka
1616
import (
1717
"context"
1818
"errors"
19+
"fmt"
1920
"strconv"
2021
"sync/atomic"
2122
"testing"
@@ -26,6 +27,7 @@ import (
2627
"github.com/stretchr/testify/require"
2728

2829
"github.com/dapr/components-contrib/common/component/kafka/mocks"
30+
"github.com/dapr/components-contrib/pubsub"
2931
"github.com/dapr/kit/logger"
3032
)
3133

@@ -234,6 +236,115 @@ func Test_reloadConsumerGroup(t *testing.T) {
234236
assert.Equal(t, int64(2), cancelCalled.Load())
235237
assert.Equal(t, int64(2), consumeCalled.Load())
236238
})
239+
240+
t.Run("Cancel context whit shutdown error closes consumer group with one subscriber", func(t *testing.T) {
241+
var consumeCalled atomic.Int64
242+
var cancelCalled atomic.Int64
243+
var closeCalled atomic.Int64
244+
waitCh := make(chan struct{})
245+
cg := mocks.NewConsumerGroup().
246+
WithConsumeFn(func(ctx context.Context, _ []string, _ sarama.ConsumerGroupHandler) error {
247+
consumeCalled.Add(1)
248+
<-ctx.Done()
249+
cancelCalled.Add(1)
250+
return nil
251+
}).WithCloseFn(func() error {
252+
closeCalled.Add(1)
253+
waitCh <- struct{}{}
254+
return nil
255+
})
256+
257+
k := &Kafka{
258+
logger: logger.NewLogger("test"),
259+
mockConsumerGroup: cg,
260+
consumerCancel: nil,
261+
closeCh: make(chan struct{}),
262+
subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}},
263+
consumeRetryInterval: time.Millisecond,
264+
}
265+
c, err := k.latestClients()
266+
require.NoError(t, err)
267+
268+
k.clients = c
269+
ctx, cancel := context.WithCancelCause(t.Context())
270+
k.Subscribe(ctx, SubscriptionHandlerConfig{}, "foo")
271+
assert.Eventually(t, func() bool {
272+
return consumeCalled.Load() == 1
273+
}, time.Second, time.Millisecond)
274+
assert.Equal(t, int64(0), cancelCalled.Load())
275+
cancel(pubsub.ErrGracefulShutdown)
276+
<-waitCh
277+
assert.Equal(t, int64(1), closeCalled.Load())
278+
})
279+
280+
t.Run("Cancel context whit shutdown error closes consumer group with multiple subscriber", func(t *testing.T) {
281+
var closeCalled atomic.Int64
282+
waitCh := make(chan struct{})
283+
cg := mocks.NewConsumerGroup().WithCloseFn(func() error {
284+
closeCalled.Add(1)
285+
waitCh <- struct{}{}
286+
return nil
287+
})
288+
289+
k := &Kafka{
290+
logger: logger.NewLogger("test"),
291+
mockConsumerGroup: cg,
292+
consumerCancel: nil,
293+
closeCh: make(chan struct{}),
294+
subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}},
295+
consumeRetryInterval: time.Millisecond,
296+
}
297+
c, err := k.latestClients()
298+
require.NoError(t, err)
299+
300+
k.clients = c
301+
302+
cancelFns := make([]context.CancelCauseFunc, 0, 100)
303+
for i := range 100 {
304+
ctx, cancel := context.WithCancelCause(t.Context())
305+
cancelFns = append(cancelFns, cancel)
306+
k.Subscribe(ctx, SubscriptionHandlerConfig{}, fmt.Sprintf("foo%d", i))
307+
}
308+
cancelFns[0](pubsub.ErrGracefulShutdown)
309+
<-waitCh
310+
assert.Equal(t, int64(1), closeCalled.Load())
311+
})
312+
313+
t.Run("Closing subscriptions with no error or no ErrGracefulShutdown does not close consumer group", func(t *testing.T) {
314+
var closeCalled atomic.Int64
315+
waitCh := make(chan struct{})
316+
cg := mocks.NewConsumerGroup().WithCloseFn(func() error {
317+
closeCalled.Add(1)
318+
waitCh <- struct{}{}
319+
return nil
320+
})
321+
322+
k := &Kafka{
323+
logger: logger.NewLogger("test"),
324+
mockConsumerGroup: cg,
325+
consumerCancel: nil,
326+
closeCh: make(chan struct{}),
327+
subscribeTopics: map[string]SubscriptionHandlerConfig{"foo": {}},
328+
consumeRetryInterval: time.Millisecond,
329+
}
330+
c, err := k.latestClients()
331+
require.NoError(t, err)
332+
333+
k.clients = c
334+
cancelFns := make([]context.CancelCauseFunc, 0, 100)
335+
for i := range 100 {
336+
ctx, cancel := context.WithCancelCause(t.Context())
337+
cancelFns = append(cancelFns, cancel)
338+
k.Subscribe(ctx, SubscriptionHandlerConfig{}, fmt.Sprintf("foo%d", i))
339+
}
340+
cancelFns[0](errors.New("some error"))
341+
time.Sleep(1 * time.Second)
342+
assert.Equal(t, int64(0), closeCalled.Load())
343+
344+
cancelFns[4](nil)
345+
time.Sleep(1 * time.Second)
346+
assert.Equal(t, int64(0), closeCalled.Load())
347+
})
237348
}
238349

239350
func Test_Subscribe(t *testing.T) {

pubsub/pubsub.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"github.com/dapr/components-contrib/metadata"
2323
)
2424

25+
var ErrGracefulShutdown = errors.New("pubsub shutdown")
26+
2527
// PubSub is the interface for message buses.
2628
type PubSub interface {
2729
metadata.ComponentWithMetadata

0 commit comments

Comments
 (0)