diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 3dea23573518..d954a64c38f4 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status if err == nil { // transport has not been closed // Note: The trailer fields are compressed with hpack after this call returns. // No WireLength field is set here. + s.hdrMu.Lock() for _, sh := range ht.stats { sh.HandleRPC(s.Context(), &stats.OutTrailer{ Trailer: s.trailer.Copy(), }) } + s.hdrMu.Unlock() } ht.Close(errors.New("finished writing status")) return err diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 911022834322..e64af27411da 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt" @@ -246,7 +247,26 @@ type handleStreamTest struct { ht *serverHandlerTransport } -func newHandleStreamTest(t *testing.T) *handleStreamTest { +type mockStatsHandler struct { + rpcStatsCh chan stats.RPCStats +} + +func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + return ctx +} + +func (h *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + h.rpcStatsCh <- s +} + +func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} + +func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) { +} + +func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest { bodyr, bodyw := io.Pipe() req := &http.Request{ ProtoMajor: 2, @@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) - ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool()) + ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool()) if err != nil { t.Fatal(err) } @@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { } func (s) TestHandlerTransport_HandleStreams(t *testing.T) { - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) handleStream := func(s *ServerStream) { if want := "/service/foo.bar"; s.method != want { t.Errorf("stream method = %q; want %q", s.method, want) @@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { } func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) handleStream := func(s *ServerStream) { s.WriteStatus(status.New(statusCode, msg)) @@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { } func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) { - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) t.Cleanup(cancel) st.ht.HandleStreams( @@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { t.Fatal(err) } - hst := newHandleStreamTest(t) + hst := newHandleStreamTest(t, nil) handleStream := func(s *ServerStream) { s.WriteStatus(st) } @@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) } +// Tests the use of stats handlers and ensures there are no data races while +// accessing trailers. +func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) { + errDetails := []protoadapt.MessageV1{ + &epb.RetryInfo{ + RetryDelay: &durationpb.Duration{Seconds: 60}, + }, + &epb.ResourceInfo{ + ResourceType: "foo bar", + ResourceName: "service.foo.bar", + Owner: "User", + }, + } + + statusCode := codes.ResourceExhausted + msg := "you are being throttled" + st, err := status.New(statusCode, msg).WithDetails(errDetails...) + if err != nil { + t.Fatal(err) + } + + stBytes, err := proto.Marshal(st.Proto()) + if err != nil { + t.Fatal(err) + } + // Add mock stats handlers to exercise the stats handler code path. + statsHandler := &mockStatsHandler{ + rpcStatsCh: make(chan stats.RPCStats, 2), + } + hst := newHandleStreamTest(t, []stats.Handler{statsHandler}) + handleStream := func(s *ServerStream) { + if err := s.SendHeader(metadata.New(map[string]string{})); err != nil { + t.Error(err) + } + if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil { + t.Error(err) + } + s.WriteStatus(st) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + hst.ht.HandleStreams( + ctx, func(s *ServerStream) { go handleStream(s) }, + ) + wantHeader := http.Header{ + "Date": nil, + "Content-Type": {"application/grpc"}, + "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, + } + wantTrailer := http.Header{ + "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, + "Grpc-Message": {encodeGrpcMessage(msg)}, + "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, + "Custom-Trailer": []string{"Custom trailer value"}, + } + + checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) + wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}} + for _, wantType := range wantStatTypes { + select { + case <-ctx.Done(): + t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.") + case s := <-statsHandler.rpcStatsCh: + if reflect.TypeOf(s) != reflect.TypeOf(wantType) { + t.Fatalf("Received RPCStats of type %T, want %T", s, wantType) + } + } + } +} + // TestHandlerTransport_Drain verifies that Drain() is not implemented // by `serverHandlerTransport`. func (s) TestHandlerTransport_Drain(t *testing.T) { defer func() { recover() }() - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) st.ht.Drain("whatever") t.Errorf("serverHandlerTransport.Drain() should have panicked") }