From 2fe9a4f87c1ba792a316a8825d04fcf5ec0acf85 Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Mon, 18 Aug 2025 16:21:26 +0530 Subject: [PATCH 1/2] acquire header mutex while copying trailers --- internal/transport/handler_server.go | 2 ++ internal/transport/handler_server_test.go | 31 ++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 3dea23573518..d954a64c38f4 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status if err == nil { // transport has not been closed // Note: The trailer fields are compressed with hpack after this call returns. // No WireLength field is set here. + s.hdrMu.Lock() for _, sh := range ht.stats { sh.HandleRPC(s.Context(), &stats.OutTrailer{ Trailer: s.trailer.Copy(), }) } + s.hdrMu.Unlock() } ht.Close(errors.New("finished writing status")) return err diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 911022834322..0f2c9ca0d245 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/protoadapt" @@ -246,6 +247,22 @@ type handleStreamTest struct { ht *serverHandlerTransport } +type mockStatsHandler struct{} + +func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { + return ctx +} + +func (h *mockStatsHandler) HandleRPC(context.Context, stats.RPCStats) { +} + +func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { + return ctx +} + +func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) { +} + func newHandleStreamTest(t *testing.T) *handleStreamTest { bodyr, bodyw := io.Pipe() req := &http.Request{ @@ -260,7 +277,12 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) - ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool()) + // Add mock stats handlers to exercise the stats handler code path. + statsHandlers := make([]stats.Handler, 0, 5) + for range 5 { + statsHandlers = append(statsHandlers, &mockStatsHandler{}) + } + ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool()) if err != nil { t.Fatal(err) } @@ -485,6 +507,12 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { hst := newHandleStreamTest(t) handleStream := func(s *ServerStream) { + if err := s.SendHeader(metadata.New(map[string]string{})); err != nil { + t.Error(err) + } + if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil { + t.Error(err) + } s.WriteStatus(st) } ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) @@ -501,6 +529,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, "Grpc-Message": {encodeGrpcMessage(msg)}, "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, + "Custom-Trailer": []string{"Custom trailer value"}, } checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) From 31a6f265a06537ed46d789b2413d233e771e532e Mon Sep 17 00:00:00 2001 From: Arjan Bal Date: Tue, 19 Aug 2025 15:44:05 +0530 Subject: [PATCH 2/2] Use a new test --- internal/transport/handler_server_test.go | 87 +++++++++++++++++++---- 1 file changed, 74 insertions(+), 13 deletions(-) diff --git a/internal/transport/handler_server_test.go b/internal/transport/handler_server_test.go index 0f2c9ca0d245..e64af27411da 100644 --- a/internal/transport/handler_server_test.go +++ b/internal/transport/handler_server_test.go @@ -247,13 +247,16 @@ type handleStreamTest struct { ht *serverHandlerTransport } -type mockStatsHandler struct{} +type mockStatsHandler struct { + rpcStatsCh chan stats.RPCStats +} func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx } -func (h *mockStatsHandler) HandleRPC(context.Context, stats.RPCStats) { +func (h *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { + h.rpcStatsCh <- s } func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { @@ -263,7 +266,7 @@ func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) co func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) { } -func newHandleStreamTest(t *testing.T) *handleStreamTest { +func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest { bodyr, bodyw := io.Pipe() req := &http.Request{ ProtoMajor: 2, @@ -277,11 +280,6 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { Body: bodyr, } rw := newTestHandlerResponseWriter().(testHandlerResponseWriter) - // Add mock stats handlers to exercise the stats handler code path. - statsHandlers := make([]stats.Handler, 0, 5) - for range 5 { - statsHandlers = append(statsHandlers, &mockStatsHandler{}) - } ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool()) if err != nil { t.Fatal(err) @@ -295,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest { } func (s) TestHandlerTransport_HandleStreams(t *testing.T) { - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) handleStream := func(s *ServerStream) { if want := "/service/foo.bar"; s.method != want { t.Errorf("stream method = %q; want %q", s.method, want) @@ -364,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) { } func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) { - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) handleStream := func(s *ServerStream) { s.WriteStatus(status.New(statusCode, msg)) @@ -473,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) { } func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) { - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) t.Cleanup(cancel) st.ht.HandleStreams( @@ -505,7 +503,59 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { t.Fatal(err) } - hst := newHandleStreamTest(t) + hst := newHandleStreamTest(t, nil) + handleStream := func(s *ServerStream) { + s.WriteStatus(st) + } + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + hst.ht.HandleStreams( + ctx, func(s *ServerStream) { go handleStream(s) }, + ) + wantHeader := http.Header{ + "Date": nil, + "Content-Type": {"application/grpc"}, + "Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"}, + } + wantTrailer := http.Header{ + "Grpc-Status": {fmt.Sprint(uint32(statusCode))}, + "Grpc-Message": {encodeGrpcMessage(msg)}, + "Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)}, + } + + checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) +} + +// Tests the use of stats handlers and ensures there are no data races while +// accessing trailers. +func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) { + errDetails := []protoadapt.MessageV1{ + &epb.RetryInfo{ + RetryDelay: &durationpb.Duration{Seconds: 60}, + }, + &epb.ResourceInfo{ + ResourceType: "foo bar", + ResourceName: "service.foo.bar", + Owner: "User", + }, + } + + statusCode := codes.ResourceExhausted + msg := "you are being throttled" + st, err := status.New(statusCode, msg).WithDetails(errDetails...) + if err != nil { + t.Fatal(err) + } + + stBytes, err := proto.Marshal(st.Proto()) + if err != nil { + t.Fatal(err) + } + // Add mock stats handlers to exercise the stats handler code path. + statsHandler := &mockStatsHandler{ + rpcStatsCh: make(chan stats.RPCStats, 2), + } + hst := newHandleStreamTest(t, []stats.Handler{statsHandler}) handleStream := func(s *ServerStream) { if err := s.SendHeader(metadata.New(map[string]string{})); err != nil { t.Error(err) @@ -533,13 +583,24 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { } checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer) + wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}} + for _, wantType := range wantStatTypes { + select { + case <-ctx.Done(): + t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.") + case s := <-statsHandler.rpcStatsCh: + if reflect.TypeOf(s) != reflect.TypeOf(wantType) { + t.Fatalf("Received RPCStats of type %T, want %T", s, wantType) + } + } + } } // TestHandlerTransport_Drain verifies that Drain() is not implemented // by `serverHandlerTransport`. func (s) TestHandlerTransport_Drain(t *testing.T) { defer func() { recover() }() - st := newHandleStreamTest(t) + st := newHandleStreamTest(t, nil) st.ht.Drain("whatever") t.Errorf("serverHandlerTransport.Drain() should have panicked") }