@@ -16,6 +16,7 @@ package kafka
1616import (
1717 "context"
1818 "errors"
19+ "fmt"
1920 "strconv"
2021 "sync/atomic"
2122 "testing"
@@ -26,6 +27,7 @@ import (
2627 "github.com/stretchr/testify/require"
2728
2829 "github.com/dapr/components-contrib/common/component/kafka/mocks"
30+ "github.com/dapr/components-contrib/pubsub"
2931 "github.com/dapr/kit/logger"
3032)
3133
@@ -234,6 +236,115 @@ func Test_reloadConsumerGroup(t *testing.T) {
234236 assert .Equal (t , int64 (2 ), cancelCalled .Load ())
235237 assert .Equal (t , int64 (2 ), consumeCalled .Load ())
236238 })
239+
240+ t .Run ("Cancel context whit shutdown error closes consumer group with one subscriber" , func (t * testing.T ) {
241+ var consumeCalled atomic.Int64
242+ var cancelCalled atomic.Int64
243+ var closeCalled atomic.Int64
244+ waitCh := make (chan struct {})
245+ cg := mocks .NewConsumerGroup ().
246+ WithConsumeFn (func (ctx context.Context , _ []string , _ sarama.ConsumerGroupHandler ) error {
247+ consumeCalled .Add (1 )
248+ <- ctx .Done ()
249+ cancelCalled .Add (1 )
250+ return nil
251+ }).WithCloseFn (func () error {
252+ closeCalled .Add (1 )
253+ waitCh <- struct {}{}
254+ return nil
255+ })
256+
257+ k := & Kafka {
258+ logger : logger .NewLogger ("test" ),
259+ mockConsumerGroup : cg ,
260+ consumerCancel : nil ,
261+ closeCh : make (chan struct {}),
262+ subscribeTopics : map [string ]SubscriptionHandlerConfig {"foo" : {}},
263+ consumeRetryInterval : time .Millisecond ,
264+ }
265+ c , err := k .latestClients ()
266+ require .NoError (t , err )
267+
268+ k .clients = c
269+ ctx , cancel := context .WithCancelCause (t .Context ())
270+ k .Subscribe (ctx , SubscriptionHandlerConfig {}, "foo" )
271+ assert .Eventually (t , func () bool {
272+ return consumeCalled .Load () == 1
273+ }, time .Second , time .Millisecond )
274+ assert .Equal (t , int64 (0 ), cancelCalled .Load ())
275+ cancel (pubsub .ErrGracefulShutdown )
276+ <- waitCh
277+ assert .Equal (t , int64 (1 ), closeCalled .Load ())
278+ })
279+
280+ t .Run ("Cancel context whit shutdown error closes consumer group with multiple subscriber" , func (t * testing.T ) {
281+ var closeCalled atomic.Int64
282+ waitCh := make (chan struct {})
283+ cg := mocks .NewConsumerGroup ().WithCloseFn (func () error {
284+ closeCalled .Add (1 )
285+ waitCh <- struct {}{}
286+ return nil
287+ })
288+
289+ k := & Kafka {
290+ logger : logger .NewLogger ("test" ),
291+ mockConsumerGroup : cg ,
292+ consumerCancel : nil ,
293+ closeCh : make (chan struct {}),
294+ subscribeTopics : map [string ]SubscriptionHandlerConfig {"foo" : {}},
295+ consumeRetryInterval : time .Millisecond ,
296+ }
297+ c , err := k .latestClients ()
298+ require .NoError (t , err )
299+
300+ k .clients = c
301+
302+ cancelFns := make ([]context.CancelCauseFunc , 0 , 100 )
303+ for i := range 100 {
304+ ctx , cancel := context .WithCancelCause (t .Context ())
305+ cancelFns = append (cancelFns , cancel )
306+ k .Subscribe (ctx , SubscriptionHandlerConfig {}, fmt .Sprintf ("foo%d" , i ))
307+ }
308+ cancelFns [0 ](pubsub .ErrGracefulShutdown )
309+ <- waitCh
310+ assert .Equal (t , int64 (1 ), closeCalled .Load ())
311+ })
312+
313+ t .Run ("Closing subscriptions with no error or no ErrGracefulShutdown does not close consumer group" , func (t * testing.T ) {
314+ var closeCalled atomic.Int64
315+ waitCh := make (chan struct {})
316+ cg := mocks .NewConsumerGroup ().WithCloseFn (func () error {
317+ closeCalled .Add (1 )
318+ waitCh <- struct {}{}
319+ return nil
320+ })
321+
322+ k := & Kafka {
323+ logger : logger .NewLogger ("test" ),
324+ mockConsumerGroup : cg ,
325+ consumerCancel : nil ,
326+ closeCh : make (chan struct {}),
327+ subscribeTopics : map [string ]SubscriptionHandlerConfig {"foo" : {}},
328+ consumeRetryInterval : time .Millisecond ,
329+ }
330+ c , err := k .latestClients ()
331+ require .NoError (t , err )
332+
333+ k .clients = c
334+ cancelFns := make ([]context.CancelCauseFunc , 0 , 100 )
335+ for i := range 100 {
336+ ctx , cancel := context .WithCancelCause (t .Context ())
337+ cancelFns = append (cancelFns , cancel )
338+ k .Subscribe (ctx , SubscriptionHandlerConfig {}, fmt .Sprintf ("foo%d" , i ))
339+ }
340+ cancelFns [0 ](errors .New ("some error" ))
341+ time .Sleep (1 * time .Second )
342+ assert .Equal (t , int64 (0 ), closeCalled .Load ())
343+
344+ cancelFns [4 ](nil )
345+ time .Sleep (1 * time .Second )
346+ assert .Equal (t , int64 (0 ), closeCalled .Load ())
347+ })
237348}
238349
239350func Test_Subscribe (t * testing.T ) {
0 commit comments