Skip to content

Commit c37d9a5

Browse files
committed
Client gracefully stops
This ensures that when the client wants to stop it will ensure that all goroutines have stopped.
1 parent 9107c50 commit c37d9a5

File tree

2 files changed

+109
-46
lines changed

2 files changed

+109
-46
lines changed

client.go

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"fmt"
88
"log/slog"
9+
"sync"
910
"sync/atomic"
1011
"time"
1112

@@ -361,7 +362,11 @@ func (c *Client) runOnce() error {
361362
notifyInputChClose := inputCh.NotifyClose(make(chan *amqp.Error, 1))
362363
notifyOutputChClose := outputCh.NotifyClose(make(chan *amqp.Error, 1))
363364

364-
err = c.runRepliesConsumer(inputCh)
365+
// We will wait on this to ensure that all go routines are done before we
366+
// exit this function.
367+
wg := sync.WaitGroup{}
368+
369+
repliesConsumerTag, err := c.runRepliesConsumer(inputCh, &wg)
365370
if err != nil {
366371
return err
367372
}
@@ -376,13 +381,18 @@ func (c *Client) runOnce() error {
376381
return err
377382
}
378383

384+
wg.Add(1) // Confirms consumer.
385+
379386
go c.runConfirmsConsumer(
380387
outputCh.NotifyPublish(make(chan amqp.Confirmation)),
381388
outputCh.NotifyReturn(make(chan amqp.Return)),
389+
&wg,
382390
)
383391
}
384392

385-
go c.runPublisher(outputCh)
393+
wg.Add(1) // Publisher.
394+
395+
go c.runPublisher(outputCh, &wg)
386396

387397
_, err = monitorAndWait(
388398
make(chan struct{}),
@@ -393,28 +403,39 @@ func (c *Client) runOnce() error {
393403
notifyOutputChClose,
394404
)
395405
if err != nil {
406+
// We don't have a graceful exit, just return the error.
396407
return err
397408
}
398409

410+
// 1. Stop the publisher by closing the output channel. This also closes
411+
// the confirms consumer if it's running.
412+
outputCh.Close()
413+
414+
// 2. Stop the replies consumer by canceling the consumer.
415+
err = inputCh.Cancel(repliesConsumerTag, false)
416+
if err != nil {
417+
return err
418+
}
419+
420+
// 3. The consumer is stopped, we can now close the input channel.
421+
inputCh.Close()
422+
423+
// 3. Wait for all the go routines to finish.
424+
wg.Wait()
425+
399426
return nil
400427
}
401428

402429
// runPublisher consumes messages from chan requests and publishes them on the
403430
// amqp exchange. The method will stop consuming if the underlying amqp channel
404431
// is closed for any reason, and when this happens the messages will be put back
405432
// in chan requests unless we have retried to many times.
406-
func (c *Client) runPublisher(ouputChan *amqp.Channel) {
407-
c.logger.Debug("running publisher...")
433+
func (c *Client) runPublisher(ouputChan *amqp.Channel, wg *sync.WaitGroup) {
434+
defer wg.Done()
408435

409-
// Monitor the closing of this channel. We need to do this in a separate,
410-
// goroutine to ensure we won't get a deadlock inside the select below
411-
// which can itself close this channel.
412-
onClose := make(chan struct{})
436+
onClose := ouputChan.NotifyClose(make(chan *amqp.Error, 1))
413437

414-
go func() {
415-
<-ouputChan.NotifyClose(make(chan *amqp.Error))
416-
close(onClose)
417-
}()
438+
c.logger.Debug("running publisher...")
418439

419440
// Delivery tags always starts at 1 but we increase it before we do any
420441
// .Publish() on the channel.
@@ -456,6 +477,9 @@ func (c *Client) runPublisher(ouputChan *amqp.Channel) {
456477
request.Publishing,
457478
)
458479
if err != nil {
480+
// Normally a Publish that results in an error will
481+
// automatically close the channel and connection. But if the
482+
// error occurs during a flush, that doesn't happen.
459483
ouputChan.Close()
460484

461485
c.retryRequest(request, err)
@@ -518,7 +542,9 @@ func (c *Client) retryRequest(request *Request, err error) {
518542
// runConfirmsConsumer will consume both confirmations and returns and since
519543
// returns always arrives before confirmations we want to finish handling any
520544
// return before we handle any confirmations.
521-
func (c *Client) runConfirmsConsumer(confirms chan amqp.Confirmation, returns chan amqp.Return) {
545+
func (c *Client) runConfirmsConsumer(confirms chan amqp.Confirmation, returns chan amqp.Return, wg *sync.WaitGroup) {
546+
defer wg.Done()
547+
522548
for {
523549
select {
524550
case ret, ok := <-returns:
@@ -623,7 +649,7 @@ func (c *Client) respondToRequest(request *Request, d *amqp.Delivery, err error)
623649
// runRepliesConsumer will declare and start consuming from the queue where we
624650
// expect replies to come back. The method will stop consuming if the
625651
// underlying amqp channel is closed for any reason.
626-
func (c *Client) runRepliesConsumer(inChan *amqp.Channel) error {
652+
func (c *Client) runRepliesConsumer(inChan *amqp.Channel, wg *sync.WaitGroup) (consumerTag string, err error) {
627653
// RabbitMQ will soon no longer support what they call "non-exclusive
628654
// transient queues". We want to support reconnects and so we cannot set
629655
// the exclusive flag since that would delete the queue on automatically on disconnect.
@@ -637,23 +663,29 @@ func (c *Client) runRepliesConsumer(inChan *amqp.Channel) error {
637663
c.replyToQueueDeclareArgs,
638664
)
639665
if err != nil {
640-
return err
666+
return "", err
641667
}
642668

669+
tag := uuid.NewString()
670+
643671
messages, err := inChan.Consume(
644672
queue.Name,
645-
"", // consumer tag. Auto-generated by the server.
673+
tag, // consumer tag.
646674
true, // auto-ack. We don't support manual ack for the reply-to queue.
647675
true, // exclusive. We must be the only consumer.
648676
false, // no-local.
649677
false, // no-wait.
650678
c.replyToConsumerArgs,
651679
)
652680
if err != nil {
653-
return err
681+
return "", err
654682
}
655683

684+
wg.Add(1)
685+
656686
go func() {
687+
defer wg.Done()
688+
657689
c.logger.Debug("running replies consumer...")
658690

659691
for response := range messages {
@@ -680,7 +712,7 @@ func (c *Client) runRepliesConsumer(inChan *amqp.Channel) error {
680712
c.logger.Debug("replies consumer is done")
681713
}()
682714

683-
return nil
715+
return tag, nil
684716
}
685717

686718
// Send will send a Request by using a amqp.Publishing.

client_test.go

Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ func TestClient_ConfirmsConsumer(t *testing.T) {
118118
returns := make(chan amqp.Return)
119119
confirms := make(chan amqp.Confirmation)
120120

121-
go client.runConfirmsConsumer(confirms, returns)
121+
wg := sync.WaitGroup{}
122+
wg.Add(1)
123+
124+
go client.runConfirmsConsumer(confirms, returns, &wg)
122125

123126
t.Cleanup(func() {
124127
close(confirms)
@@ -246,12 +249,11 @@ func TestClient_ConfirmsConsumer(t *testing.T) {
246249
t.Run("closing returns will not stop select", func(t *testing.T) {
247250
confirms := make(chan amqp.Confirmation)
248251
returns := make(chan amqp.Return)
249-
finished := make(chan struct{})
250252

251-
go func() {
252-
client.runConfirmsConsumer(confirms, returns)
253-
close(finished)
254-
}()
253+
wg := sync.WaitGroup{}
254+
wg.Add(1)
255+
256+
go client.runConfirmsConsumer(confirms, returns, &wg)
255257

256258
close(returns)
257259

@@ -272,6 +274,13 @@ func TestClient_ConfirmsConsumer(t *testing.T) {
272274

273275
close(confirms)
274276

277+
finished := make(chan struct{})
278+
279+
go func() {
280+
wg.Wait()
281+
close(finished)
282+
}()
283+
275284
select {
276285
case <-finished:
277286
case <-time.After(5 * time.Second):
@@ -280,7 +289,35 @@ func TestClient_ConfirmsConsumer(t *testing.T) {
280289
})
281290
}
282291

292+
func TestClientStop(t *testing.T) {
293+
t.Parallel()
294+
295+
client := NewClient(testURL)
296+
297+
request := NewRequest().
298+
WithResponse(false)
299+
300+
// Ensure that the client has started.
301+
_, err := client.Send(request)
302+
require.ErrorIs(t, err, ErrRequestReturned)
303+
304+
stopped := make(chan struct{})
305+
306+
go func() {
307+
client.Stop()
308+
close(stopped)
309+
}()
310+
311+
select {
312+
case <-stopped:
313+
case <-time.After(5 * time.Second):
314+
t.Fatal("did not exit")
315+
}
316+
}
317+
283318
func TestClientStopWhenCannotStart(t *testing.T) {
319+
t.Parallel()
320+
284321
client := NewClient(testURL)
285322

286323
request := NewRequest().
@@ -290,43 +327,37 @@ func TestClientStopWhenCannotStart(t *testing.T) {
290327
_, err := client.Send(request)
291328
require.Error(t, err)
292329

293-
var stopped sync.WaitGroup
294-
295-
stopped.Add(1)
330+
stopped := make(chan struct{})
296331

297332
go func() {
298333
client.Stop()
299-
stopped.Done()
334+
close(stopped)
300335
}()
301336

302-
assert.Eventually(t, func() bool {
303-
stopped.Wait()
304-
return true
305-
},
306-
1*time.Second,
307-
500*time.Millisecond,
308-
)
337+
select {
338+
case <-stopped:
339+
case <-time.After(5 * time.Second):
340+
t.Fatal("did not exit")
341+
}
309342
}
310343

311344
func TestClientStopWhenNeverStarted(t *testing.T) {
312-
client := NewClient(testURL)
345+
t.Parallel()
313346

314-
var stopped sync.WaitGroup
347+
client := NewClient(testURL)
315348

316-
stopped.Add(1)
349+
stopped := make(chan struct{})
317350

318351
go func() {
319352
client.Stop()
320-
stopped.Done()
353+
close(stopped)
321354
}()
322355

323-
assert.Eventually(t, func() bool {
324-
stopped.Wait()
325-
return true
326-
},
327-
1*time.Second,
328-
500*time.Millisecond,
329-
)
356+
select {
357+
case <-stopped:
358+
case <-time.After(5 * time.Second):
359+
t.Fatal("did not exit")
360+
}
330361
}
331362

332363
func TestClientConfig(t *testing.T) {

0 commit comments

Comments
 (0)