@@ -6,6 +6,7 @@ package retry
6
6
import (
7
7
"context"
8
8
"io"
9
+ "strconv"
9
10
"strings"
10
11
"sync"
11
12
"testing"
@@ -17,6 +18,7 @@ import (
17
18
"github.com/stretchr/testify/suite"
18
19
"google.golang.org/grpc"
19
20
"google.golang.org/grpc/codes"
21
+ "google.golang.org/grpc/metadata"
20
22
"google.golang.org/grpc/status"
21
23
)
22
24
@@ -432,3 +434,73 @@ func TestJitterUp(t *testing.T) {
432
434
assert .True (t , highCount != 0 , "at least one sample should reach to >%s" , high )
433
435
assert .True (t , lowCount != 0 , "at least one sample should to <%s" , low )
434
436
}
437
+
438
+ type failingClientStream struct {
439
+ RecvMsgErr error
440
+ }
441
+
442
+ func (s * failingClientStream ) Header () (metadata.MD , error ) {
443
+ return nil , nil
444
+ }
445
+
446
+ func (s * failingClientStream ) Trailer () metadata.MD {
447
+ return nil
448
+ }
449
+
450
+ func (s * failingClientStream ) CloseSend () error {
451
+ return nil
452
+ }
453
+
454
+ func (s * failingClientStream ) Context () context.Context {
455
+ return context .Background ()
456
+ }
457
+
458
+ func (s * failingClientStream ) SendMsg (m any ) error {
459
+ return nil
460
+ }
461
+
462
+ func (s * failingClientStream ) RecvMsg (m any ) error {
463
+ return s .RecvMsgErr
464
+ }
465
+
466
+ func TestStreamClientInterceptorAttemptMetadata (t * testing.T ) {
467
+ retryCount := 5
468
+ attempt := 0
469
+ recvMsgErr := status .Error (codes .Unavailable , "unavailable" )
470
+
471
+ var testStreamer grpc.Streamer = func (
472
+ ctx context.Context ,
473
+ desc * grpc.StreamDesc ,
474
+ cc * grpc.ClientConn ,
475
+ method string ,
476
+ opts ... grpc.CallOption ,
477
+ ) (grpc.ClientStream , error ) {
478
+ if attempt > 0 {
479
+ md , ok := metadata .FromOutgoingContext (ctx )
480
+ require .True (t , ok )
481
+
482
+ raw := md .Get (AttemptMetadataKey )
483
+ require .Len (t , raw , 1 )
484
+
485
+ attemptMetadataValue , err := strconv .Atoi (raw [0 ])
486
+ require .NoError (t , err )
487
+
488
+ require .Equal (t , attempt , attemptMetadataValue )
489
+ }
490
+
491
+ attempt ++
492
+
493
+ return & failingClientStream {
494
+ RecvMsgErr : recvMsgErr ,
495
+ }, nil
496
+ }
497
+
498
+ streamClientInterceptor := StreamClientInterceptor (WithCodes (codes .Unavailable ), WithMax (uint (retryCount )))
499
+ clientStream , err := streamClientInterceptor (context .Background (), & grpc.StreamDesc {}, nil , "some_method" , testStreamer )
500
+ require .NoError (t , err )
501
+
502
+ err = clientStream .RecvMsg (nil )
503
+ require .ErrorIs (t , err , recvMsgErr )
504
+
505
+ require .Equal (t , retryCount , attempt )
506
+ }
0 commit comments