@@ -35,6 +35,7 @@ import (
35
35
"google.golang.org/grpc/codes"
36
36
"google.golang.org/grpc/mem"
37
37
"google.golang.org/grpc/metadata"
38
+ "google.golang.org/grpc/stats"
38
39
"google.golang.org/grpc/status"
39
40
"google.golang.org/protobuf/proto"
40
41
"google.golang.org/protobuf/protoadapt"
@@ -246,7 +247,26 @@ type handleStreamTest struct {
246
247
ht * serverHandlerTransport
247
248
}
248
249
249
- func newHandleStreamTest (t * testing.T ) * handleStreamTest {
250
+ type mockStatsHandler struct {
251
+ rpcStatsCh chan stats.RPCStats
252
+ }
253
+
254
+ func (h * mockStatsHandler ) TagRPC (ctx context.Context , _ * stats.RPCTagInfo ) context.Context {
255
+ return ctx
256
+ }
257
+
258
+ func (h * mockStatsHandler ) HandleRPC (_ context.Context , s stats.RPCStats ) {
259
+ h .rpcStatsCh <- s
260
+ }
261
+
262
+ func (h * mockStatsHandler ) TagConn (ctx context.Context , _ * stats.ConnTagInfo ) context.Context {
263
+ return ctx
264
+ }
265
+
266
+ func (h * mockStatsHandler ) HandleConn (context.Context , stats.ConnStats ) {
267
+ }
268
+
269
+ func newHandleStreamTest (t * testing.T , statsHandlers []stats.Handler ) * handleStreamTest {
250
270
bodyr , bodyw := io .Pipe ()
251
271
req := & http.Request {
252
272
ProtoMajor : 2 ,
@@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
260
280
Body : bodyr ,
261
281
}
262
282
rw := newTestHandlerResponseWriter ().(testHandlerResponseWriter )
263
- ht , err := NewServerHandlerTransport (rw , req , nil , mem .DefaultBufferPool ())
283
+ ht , err := NewServerHandlerTransport (rw , req , statsHandlers , mem .DefaultBufferPool ())
264
284
if err != nil {
265
285
t .Fatal (err )
266
286
}
@@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
273
293
}
274
294
275
295
func (s ) TestHandlerTransport_HandleStreams (t * testing.T ) {
276
- st := newHandleStreamTest (t )
296
+ st := newHandleStreamTest (t , nil )
277
297
handleStream := func (s * ServerStream ) {
278
298
if want := "/service/foo.bar" ; s .method != want {
279
299
t .Errorf ("stream method = %q; want %q" , s .method , want )
@@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
342
362
}
343
363
344
364
func handleStreamCloseBodyTest (t * testing.T , statusCode codes.Code , msg string ) {
345
- st := newHandleStreamTest (t )
365
+ st := newHandleStreamTest (t , nil )
346
366
347
367
handleStream := func (s * ServerStream ) {
348
368
s .WriteStatus (status .New (statusCode , msg ))
@@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
451
471
}
452
472
453
473
func testHandlerTransportHandleStreams (t * testing.T , handleStream func (st * handleStreamTest , s * ServerStream )) {
454
- st := newHandleStreamTest (t )
474
+ st := newHandleStreamTest (t , nil )
455
475
ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
456
476
t .Cleanup (cancel )
457
477
st .ht .HandleStreams (
@@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
483
503
t .Fatal (err )
484
504
}
485
505
486
- hst := newHandleStreamTest (t )
506
+ hst := newHandleStreamTest (t , nil )
487
507
handleStream := func (s * ServerStream ) {
488
508
s .WriteStatus (st )
489
509
}
@@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
506
526
checkHeaderAndTrailer (t , hst .rw , wantHeader , wantTrailer )
507
527
}
508
528
529
+ // Tests the use of stats handlers and ensures there are no data races while
530
+ // accessing trailers.
531
+ func (s ) TestHandlerTransport_HandleStreams_StatsHandlers (t * testing.T ) {
532
+ errDetails := []protoadapt.MessageV1 {
533
+ & epb.RetryInfo {
534
+ RetryDelay : & durationpb.Duration {Seconds : 60 },
535
+ },
536
+ & epb.ResourceInfo {
537
+ ResourceType : "foo bar" ,
538
+ ResourceName : "service.foo.bar" ,
539
+ Owner : "User" ,
540
+ },
541
+ }
542
+
543
+ statusCode := codes .ResourceExhausted
544
+ msg := "you are being throttled"
545
+ st , err := status .New (statusCode , msg ).WithDetails (errDetails ... )
546
+ if err != nil {
547
+ t .Fatal (err )
548
+ }
549
+
550
+ stBytes , err := proto .Marshal (st .Proto ())
551
+ if err != nil {
552
+ t .Fatal (err )
553
+ }
554
+ // Add mock stats handlers to exercise the stats handler code path.
555
+ statsHandler := & mockStatsHandler {
556
+ rpcStatsCh : make (chan stats.RPCStats , 2 ),
557
+ }
558
+ hst := newHandleStreamTest (t , []stats.Handler {statsHandler })
559
+ handleStream := func (s * ServerStream ) {
560
+ if err := s .SendHeader (metadata .New (map [string ]string {})); err != nil {
561
+ t .Error (err )
562
+ }
563
+ if err := s .SetTrailer (metadata .Pairs ("custom-trailer" , "Custom trailer value" )); err != nil {
564
+ t .Error (err )
565
+ }
566
+ s .WriteStatus (st )
567
+ }
568
+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
569
+ defer cancel ()
570
+ hst .ht .HandleStreams (
571
+ ctx , func (s * ServerStream ) { go handleStream (s ) },
572
+ )
573
+ wantHeader := http.Header {
574
+ "Date" : nil ,
575
+ "Content-Type" : {"application/grpc" },
576
+ "Trailer" : {"Grpc-Status" , "Grpc-Message" , "Grpc-Status-Details-Bin" },
577
+ }
578
+ wantTrailer := http.Header {
579
+ "Grpc-Status" : {fmt .Sprint (uint32 (statusCode ))},
580
+ "Grpc-Message" : {encodeGrpcMessage (msg )},
581
+ "Grpc-Status-Details-Bin" : {encodeBinHeader (stBytes )},
582
+ "Custom-Trailer" : []string {"Custom trailer value" },
583
+ }
584
+
585
+ checkHeaderAndTrailer (t , hst .rw , wantHeader , wantTrailer )
586
+ wantStatTypes := []stats.RPCStats {& stats.OutHeader {}, & stats.OutTrailer {}}
587
+ for _ , wantType := range wantStatTypes {
588
+ select {
589
+ case <- ctx .Done ():
590
+ t .Fatal ("Context timed out waiting for statsHandler.HandleRPC() to be called." )
591
+ case s := <- statsHandler .rpcStatsCh :
592
+ if reflect .TypeOf (s ) != reflect .TypeOf (wantType ) {
593
+ t .Fatalf ("Received RPCStats of type %T, want %T" , s , wantType )
594
+ }
595
+ }
596
+ }
597
+ }
598
+
509
599
// TestHandlerTransport_Drain verifies that Drain() is not implemented
510
600
// by `serverHandlerTransport`.
511
601
func (s ) TestHandlerTransport_Drain (t * testing.T ) {
512
602
defer func () { recover () }()
513
- st := newHandleStreamTest (t )
603
+ st := newHandleStreamTest (t , nil )
514
604
st .ht .Drain ("whatever" )
515
605
t .Errorf ("serverHandlerTransport.Drain() should have panicked" )
516
606
}
0 commit comments