Skip to content

Commit 6aea589

Browse files
Boklazhenkoa.boklazhenko
andauthored
x-retry-attempt to StreamClientInterceptor (#733)
* x-retry-attempt to StreamClientInterceptor * unit test for StreamClientInterceptor AttemptMetadata --------- Co-authored-by: a.boklazhenko <[email protected]>
1 parent ba6f8b9 commit 6aea589

File tree

2 files changed

+84
-2
lines changed

2 files changed

+84
-2
lines changed

interceptors/retry/retry.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,15 @@ func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientIntercepto
100100
callOpts.onRetryCallback(parentCtx, attempt, lastErr)
101101
}
102102
var newStreamer grpc.ClientStream
103-
newStreamer, lastErr = streamer(parentCtx, desc, cc, method, grpcOpts...)
103+
newStreamer, lastErr = streamer(perStreamContext(parentCtx, callOpts, attempt), desc, cc, method, grpcOpts...)
104104
if lastErr == nil {
105105
retryingStreamer := &serverStreamingRetryingStream{
106106
ClientStream: newStreamer,
107107
callOpts: callOpts,
108108
parentCtx: parentCtx,
109109
streamerCall: func(ctx context.Context) (grpc.ClientStream, error) {
110-
return streamer(ctx, desc, cc, method, grpcOpts...)
110+
attempt++
111+
return streamer(perStreamContext(ctx, callOpts, attempt), desc, cc, method, grpcOpts...)
111112
},
112113
}
113114
return retryingStreamer, nil
@@ -296,6 +297,15 @@ func perCallContext(parentCtx context.Context, callOpts *options, attempt uint)
296297
return ctx, cancel
297298
}
298299

300+
func perStreamContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context {
301+
ctx := parentCtx
302+
if attempt > 0 && callOpts.includeHeader {
303+
mdClone := metadata.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt))
304+
ctx = mdClone.ToOutgoing(ctx)
305+
}
306+
return ctx
307+
}
308+
299309
func contextErrToGrpcErr(err error) error {
300310
switch err {
301311
case context.DeadlineExceeded:

interceptors/retry/retry_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package retry
66
import (
77
"context"
88
"io"
9+
"strconv"
910
"strings"
1011
"sync"
1112
"testing"
@@ -17,6 +18,7 @@ import (
1718
"github.com/stretchr/testify/suite"
1819
"google.golang.org/grpc"
1920
"google.golang.org/grpc/codes"
21+
"google.golang.org/grpc/metadata"
2022
"google.golang.org/grpc/status"
2123
)
2224

@@ -432,3 +434,73 @@ func TestJitterUp(t *testing.T) {
432434
assert.True(t, highCount != 0, "at least one sample should reach to >%s", high)
433435
assert.True(t, lowCount != 0, "at least one sample should to <%s", low)
434436
}
437+
438+
type failingClientStream struct {
439+
RecvMsgErr error
440+
}
441+
442+
func (s *failingClientStream) Header() (metadata.MD, error) {
443+
return nil, nil
444+
}
445+
446+
func (s *failingClientStream) Trailer() metadata.MD {
447+
return nil
448+
}
449+
450+
func (s *failingClientStream) CloseSend() error {
451+
return nil
452+
}
453+
454+
func (s *failingClientStream) Context() context.Context {
455+
return context.Background()
456+
}
457+
458+
func (s *failingClientStream) SendMsg(m any) error {
459+
return nil
460+
}
461+
462+
func (s *failingClientStream) RecvMsg(m any) error {
463+
return s.RecvMsgErr
464+
}
465+
466+
func TestStreamClientInterceptorAttemptMetadata(t *testing.T) {
467+
retryCount := 5
468+
attempt := 0
469+
recvMsgErr := status.Error(codes.Unavailable, "unavailable")
470+
471+
var testStreamer grpc.Streamer = func(
472+
ctx context.Context,
473+
desc *grpc.StreamDesc,
474+
cc *grpc.ClientConn,
475+
method string,
476+
opts ...grpc.CallOption,
477+
) (grpc.ClientStream, error) {
478+
if attempt > 0 {
479+
md, ok := metadata.FromOutgoingContext(ctx)
480+
require.True(t, ok)
481+
482+
raw := md.Get(AttemptMetadataKey)
483+
require.Len(t, raw, 1)
484+
485+
attemptMetadataValue, err := strconv.Atoi(raw[0])
486+
require.NoError(t, err)
487+
488+
require.Equal(t, attempt, attemptMetadataValue)
489+
}
490+
491+
attempt++
492+
493+
return &failingClientStream{
494+
RecvMsgErr: recvMsgErr,
495+
}, nil
496+
}
497+
498+
streamClientInterceptor := StreamClientInterceptor(WithCodes(codes.Unavailable), WithMax(uint(retryCount)))
499+
clientStream, err := streamClientInterceptor(context.Background(), &grpc.StreamDesc{}, nil, "some_method", testStreamer)
500+
require.NoError(t, err)
501+
502+
err = clientStream.RecvMsg(nil)
503+
require.ErrorIs(t, err, recvMsgErr)
504+
505+
require.Equal(t, retryCount, attempt)
506+
}

0 commit comments

Comments
 (0)