Skip to content

Commit 6cd9f2e

Browse files
committed
Use proper context deadline to cancel requests
This enables a request to be canceled when its context is canceled, for example due to an external timeout. This also removes chanSendWaitTime in favor of context cancellation. When a user sets a shorter deadline with a cause on the context, that cause will be returned by Send when the method returns.
1 parent 8582edf commit 6cd9f2e

File tree

5 files changed

+187
-56
lines changed

5 files changed

+187
-56
lines changed

client.go

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,6 @@ import (
1313
amqp "github.com/rabbitmq/amqp091-go"
1414
)
1515

16-
const (
17-
// chanSendWaitTime is the maximum time we will wait when sending a
18-
// response, confirm or error on the corresponding channels. This is so that
19-
// we won't block forever if the listening goroutine has stopped listening.
20-
chanSendWaitTime = 10 * time.Second
21-
)
22-
2316
var (
2417
// ErrRequestReturned can be returned by Client#Send() when the server
2518
// returns the message. For example when mandatory is set but the message
@@ -442,7 +435,7 @@ func (c *Client) runPublisher(ouputChan *amqp.Channel) {
442435
request.Publishing.ReplyTo = ""
443436
}
444437

445-
c.logger.Debug("publishing request",
438+
c.logger.DebugContext(request.Context, "publishing request",
446439
slog.String("correlation_id", request.Publishing.CorrelationId),
447440
)
448441

@@ -467,7 +460,7 @@ func (c *Client) runPublisher(ouputChan *amqp.Channel) {
467460

468461
c.retryRequest(request, err)
469462

470-
c.logger.Error(
463+
c.logger.ErrorContext(request.Context,
471464
"publisher stopped because of error",
472465
slog.Any("error", err),
473466
slogGroupFor("request", slogAttrsForRequest(request)),
@@ -498,7 +491,7 @@ func (c *Client) runPublisher(ouputChan *amqp.Channel) {
498491
func (c *Client) retryRequest(request *Request, err error) {
499492
if request.numRetries >= c.maxRetries {
500493
// We have already retried too many times
501-
c.logger.Error(
494+
c.logger.ErrorContext(request.Context,
502495
"could not publish, giving up",
503496
slog.Any("error", err),
504497
slogGroupFor("request", slogAttrsForRequest(request)),
@@ -517,16 +510,16 @@ func (c *Client) retryRequest(request *Request, err error) {
517510
request.numRetries++
518511

519512
go func() {
520-
c.logger.Debug("queuing request for retry",
513+
c.logger.DebugContext(request.Context, "queuing request for retry",
521514
slog.Any("error", err),
522515
slogGroupFor("request", slogAttrsForRequest(request)),
523516
)
524517

525518
select {
526519
case c.requests <- request:
527-
case <-request.AfterTimeout():
528-
c.logger.Debug("request timed out while waiting for retry",
529-
slog.Any("error", err),
520+
case <-request.Context.Done():
521+
c.logger.ErrorContext(request.Context, "canceled while waiting for retry",
522+
slog.Any("error", errors.Join(context.Cause(request.Context), err)),
530523
slogGroupFor("request", slogAttrsForRequest(request)),
531524
)
532525
}
@@ -561,7 +554,7 @@ func (c *Client) runConfirmsConsumer(confirms chan amqp.Confirmation, returns ch
561554
continue
562555
}
563556

564-
c.logger.Debug("publishing is returned by server",
557+
c.logger.DebugContext(request.Context, "publishing is returned by server",
565558
slog.String("correlation_id", ret.CorrelationId),
566559
)
567560

@@ -587,7 +580,7 @@ func (c *Client) runConfirmsConsumer(confirms chan amqp.Confirmation, returns ch
587580
continue
588581
}
589582

590-
c.logger.Debug("confirming request",
583+
c.logger.DebugContext(request.Context, "confirming request",
591584
slog.String("correlation_id", request.Publishing.CorrelationId),
592585
)
593586

@@ -629,8 +622,8 @@ func (c *Client) respondToRequest(request *Request, response *amqp.Delivery) {
629622
select {
630623
case request.response <- response:
631624
return
632-
case <-time.After(chanSendWaitTime):
633-
c.logger.Error(
625+
case <-request.Context.Done():
626+
c.logger.ErrorContext(request.Context,
634627
"nobody is waiting for response",
635628
slogGroupFor("request", slogAttrsForRequest(request)),
636629
slogGroupFor("delivery", slogAttrsForDelivery(response)),
@@ -643,8 +636,8 @@ func (c *Client) respondErrorToRequest(request *Request, err error) {
643636
select {
644637
case request.errChan <- err:
645638
return
646-
case <-time.After(chanSendWaitTime):
647-
c.logger.Error(
639+
case <-request.Context.Done():
640+
c.logger.ErrorContext(request.Context,
648641
"nobody is waiting for error",
649642
slog.Any("error", err),
650643
slogGroupFor("request", slogAttrsForRequest(request)),
@@ -658,8 +651,8 @@ func (c *Client) confirmRequest(request *Request) {
658651
select {
659652
case request.confirmed <- struct{}{}:
660653
return
661-
case <-time.After(chanSendWaitTime):
662-
c.logger.Error(
654+
case <-request.Context.Done():
655+
c.logger.ErrorContext(request.Context,
663656
"nobody is waiting for confirmation",
664657
slogGroupFor("request", slogAttrsForRequest(request)),
665658
)
@@ -716,16 +709,16 @@ func (c *Client) runRepliesConsumer(inChan *amqp.Channel) error {
716709
continue
717710
}
718711

719-
c.logger.Debug("forwarding reply",
712+
c.logger.DebugContext(request.Context, "forwarding reply",
720713
slog.String("correlation_id", response.CorrelationId),
721714
)
722715

723716
responseCopy := response
724717

725718
select {
726719
case request.response <- &responseCopy:
727-
case <-time.After(chanSendWaitTime):
728-
c.logger.Error(
720+
case <-request.Context.Done():
721+
c.logger.ErrorContext(request.Context,
729722
"nobody is waiting on response on request",
730723
slogGroupFor("request", slogAttrsForRequest(request)),
731724
)
@@ -773,23 +766,28 @@ func (c *Client) send(r *Request) (*amqp.Delivery, error) {
773766

774767
defer c.requestsMap.Delete(r)
775768

776-
r.startTimeout(c.timeout)
777-
timeoutChan := r.AfterTimeout()
769+
cancel := r.startTimeout(c.timeout)
770+
defer cancel()
778771

779772
c.logger.Debug("queuing request", slog.String("correlation_id", r.Publishing.CorrelationId))
780773

781774
select {
782775
case c.requests <- r:
783776
// successful send.
784-
case <-timeoutChan:
785-
c.logger.Debug("timeout while waiting for request queue %s",
777+
case <-r.Context.Done():
778+
err := context.Cause(r.Context)
779+
780+
c.logger.DebugContext(r.Context,
781+
"canceled while waiting for request queue",
782+
slog.Any("error", err),
786783
slog.String("correlation_id", r.Publishing.CorrelationId),
787784
)
788785

789-
return nil, fmt.Errorf("%w while waiting for request queue", ErrRequestTimeout)
786+
return nil, fmt.Errorf("%w while waiting for request queue", err)
790787
}
791788

792-
c.logger.Debug("waiting for reply",
789+
c.logger.DebugContext(r.Context,
790+
"waiting for reply",
793791
slog.String("correlation_id", r.Publishing.CorrelationId),
794792
)
795793

@@ -798,34 +796,44 @@ func (c *Client) send(r *Request) (*amqp.Delivery, error) {
798796
select {
799797
case <-r.confirmed:
800798
// got confirmation.
801-
case <-timeoutChan:
802-
c.logger.Debug("timeout while waiting for request confirmation",
799+
case <-r.Context.Done():
800+
err := context.Cause(r.Context)
801+
802+
c.logger.DebugContext(r.Context,
803+
"canceled while waiting for request confirmation",
804+
slog.Any("error", err),
803805
slog.String("correlation_id", r.Publishing.CorrelationId),
804806
)
805807

806-
return nil, fmt.Errorf("%w while waiting for confirmation", ErrRequestTimeout)
808+
return nil, fmt.Errorf("%w while waiting for confirmation", err)
807809
}
808810

809811
// All responses are published on the requests response channel. Hang here
810812
// until a response is received and close the channel when it's read.
811813
select {
812814
case err := <-r.errChan:
813-
c.logger.Debug("error for request",
815+
c.logger.DebugContext(r.Context,
816+
"error for request",
814817
slog.Any("error", err),
815818
slog.String("correlation_id", r.Publishing.CorrelationId),
816819
)
817820

818821
return nil, err
819822

820-
case <-timeoutChan:
821-
c.logger.Debug("timeout for request",
823+
case <-r.Context.Done():
824+
err := context.Cause(r.Context)
825+
826+
c.logger.DebugContext(r.Context,
827+
"canceled while waiting for response",
828+
slog.Any("error", err),
822829
slog.String("correlation_id", r.Publishing.CorrelationId),
823830
)
824831

825-
return nil, fmt.Errorf("%w while waiting for response", ErrRequestTimeout)
832+
return nil, fmt.Errorf("%w while waiting for response", err)
826833

827834
case delivery := <-r.response:
828-
c.logger.Debug("got response",
835+
c.logger.DebugContext(r.Context,
836+
"got response",
829837
slog.String("correlation_id", r.Publishing.CorrelationId),
830838
)
831839

connection.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,27 +30,36 @@ func monitorAndWait(
3030
select {
3131
case <-restartChan:
3232
return true, nil
33+
3334
case <-stopChan:
3435
return false, nil
36+
3537
case err, ok := <-inputConnClose:
3638
if !ok {
3739
return false, ErrUnexpectedConnClosed
3840
}
41+
3942
return false, err
43+
4044
case err, ok := <-outputConnClose:
4145
if !ok {
4246
return false, ErrUnexpectedConnClosed
4347
}
48+
4449
return false, err
50+
4551
case err, ok := <-inputChClose:
4652
if !ok {
4753
return false, ErrUnexpectedConnClosed
4854
}
55+
4956
return false, err
57+
5058
case err, ok := <-outputChClose:
5159
if !ok {
5260
return false, ErrUnexpectedConnClosed
5361
}
62+
5463
return false, err
5564
}
5665
}

logging.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,10 @@ func slogAttrsForRequest(v *Request) []slog.Attr {
185185
vals = append(vals, slog.Bool("mandatory", v.Mandatory))
186186
}
187187

188-
if v.Timeout > 0 {
189-
vals = append(vals, slog.Duration("timeout", v.Timeout))
188+
if v.Context != nil {
189+
if deadline, ok := v.Context.Deadline(); ok {
190+
vals = append(vals, slog.Time("deadline", deadline))
191+
}
190192
}
191193

192194
return vals

request.go

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package amqprpc
22

33
import (
44
"context"
5-
"fmt"
5+
"strconv"
66
"sync"
77
"time"
88

@@ -32,10 +32,6 @@ type Request struct {
3232
// we assume the request got lost.
3333
Timeout time.Duration
3434

35-
// timeoutAt is the exact time when the request times out. This is set by
36-
// the client when starting the countdown.
37-
timeoutAt time.Time
38-
3935
// Publishing is the publishing that is going to be published.
4036
Publishing amqp.Publishing
4137

@@ -184,21 +180,39 @@ func (r *Request) AddMiddleware(m ClientMiddlewareFunc) *Request {
184180
// startTimeout will start the timeout counter. It will also set the Expiration
185181
// field for the Publishing so that amqp won't hold on to the message in the
186182
// queue after the timeout has happened.
187-
func (r *Request) startTimeout(defaultTimeout time.Duration) {
188-
if r.Timeout.Nanoseconds() == 0 {
189-
r.WithTimeout(defaultTimeout)
183+
func (r *Request) startTimeout(defaultTimeout time.Duration) context.CancelFunc {
184+
if r.Context == nil {
185+
r.Context = context.Background()
190186
}
191187

192-
if r.Reply {
193-
r.Publishing.Expiration = fmt.Sprintf("%d", r.Timeout.Nanoseconds()/1e6)
188+
timeout := r.Timeout
189+
if timeout == 0 {
190+
timeout = defaultTimeout
194191
}
195192

196-
r.timeoutAt = time.Now().Add(r.Timeout)
197-
}
193+
var cancel context.CancelFunc
194+
195+
if timeout > 0 {
196+
r.Context, cancel = context.WithTimeoutCause(r.Context, timeout, ErrRequestTimeout)
197+
} else {
198+
r.Context, cancel = context.WithCancel(r.Context)
199+
}
200+
201+
if r.Reply {
202+
if deadline, ok := r.Context.Deadline(); ok {
203+
// When a request requires a reply, there is no point in executing the
204+
// request if the client has stopped waiting.
205+
// We make sure that we round up. 1001μs should be rounded to 2ms.
206+
// The expiration is the number of milliseconds until the message
207+
// is expired, counted from when it arrives in the queue. This is
208+
// always later than now so the message will expire a little later
209+
// than the deadline.
210+
etaMs := (time.Until(deadline) + time.Millisecond).Milliseconds()
211+
r.Publishing.Expiration = strconv.FormatInt(etaMs, 10)
212+
}
213+
}
198214

199-
// AfterTimeout waits for the duration of the timeout.
200-
func (r *Request) AfterTimeout() <-chan time.Time {
201-
return time.After(time.Until(r.timeoutAt))
215+
return cancel
202216
}
203217

204218
// RequestMap keeps track of requests based on their DeliveryTag and/or

0 commit comments

Comments
 (0)