Skip to content

Commit 5ed7cf6

Browse files
authored
transport: ensure header mutex is held while copying trailers in handler_server (#8519)
Fixes: #8514. The mutex that guards the trailers should be held while copying the trailers. We do lock the mutex in [the regular gRPC server transport](https://github.com/grpc/grpc-go/blob/9ac0ec87ca2ecc66b3c0c084708aef768637aef6/internal/transport/http2_server.go#L1140-L1142), but have missed it in the std lib http/2 transport. The only place where a write happens in `writeStatus()` is when the status contains a proto. https://github.com/grpc/grpc-go/blob/4375c784450aa7e43ff15b8b2879c896d0917130/internal/transport/handler_server.go#L251-L252 RELEASE NOTES: * transport: Fix a data race while copying headers for stats handlers in the std lib http2 server transport.
1 parent fa0d658 commit 5ed7cf6

File tree

2 files changed

+99
-7
lines changed

2 files changed

+99
-7
lines changed

internal/transport/handler_server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
277277
if err == nil { // transport has not been closed
278278
// Note: The trailer fields are compressed with hpack after this call returns.
279279
// No WireLength field is set here.
280+
s.hdrMu.Lock()
280281
for _, sh := range ht.stats {
281282
sh.HandleRPC(s.Context(), &stats.OutTrailer{
282283
Trailer: s.trailer.Copy(),
283284
})
284285
}
286+
s.hdrMu.Unlock()
285287
}
286288
ht.Close(errors.New("finished writing status"))
287289
return err

internal/transport/handler_server_test.go

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"google.golang.org/grpc/codes"
3636
"google.golang.org/grpc/mem"
3737
"google.golang.org/grpc/metadata"
38+
"google.golang.org/grpc/stats"
3839
"google.golang.org/grpc/status"
3940
"google.golang.org/protobuf/proto"
4041
"google.golang.org/protobuf/protoadapt"
@@ -246,7 +247,26 @@ type handleStreamTest struct {
246247
ht *serverHandlerTransport
247248
}
248249

249-
func newHandleStreamTest(t *testing.T) *handleStreamTest {
250+
type mockStatsHandler struct {
251+
rpcStatsCh chan stats.RPCStats
252+
}
253+
254+
func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
255+
return ctx
256+
}
257+
258+
func (h *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
259+
h.rpcStatsCh <- s
260+
}
261+
262+
func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
263+
return ctx
264+
}
265+
266+
func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
267+
}
268+
269+
func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest {
250270
bodyr, bodyw := io.Pipe()
251271
req := &http.Request{
252272
ProtoMajor: 2,
@@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
260280
Body: bodyr,
261281
}
262282
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
263-
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
283+
ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
264284
if err != nil {
265285
t.Fatal(err)
266286
}
@@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
273293
}
274294

275295
func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
276-
st := newHandleStreamTest(t)
296+
st := newHandleStreamTest(t, nil)
277297
handleStream := func(s *ServerStream) {
278298
if want := "/service/foo.bar"; s.method != want {
279299
t.Errorf("stream method = %q; want %q", s.method, want)
@@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
342362
}
343363

344364
func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
345-
st := newHandleStreamTest(t)
365+
st := newHandleStreamTest(t, nil)
346366

347367
handleStream := func(s *ServerStream) {
348368
s.WriteStatus(status.New(statusCode, msg))
@@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
451471
}
452472

453473
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
454-
st := newHandleStreamTest(t)
474+
st := newHandleStreamTest(t, nil)
455475
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
456476
t.Cleanup(cancel)
457477
st.ht.HandleStreams(
@@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
483503
t.Fatal(err)
484504
}
485505

486-
hst := newHandleStreamTest(t)
506+
hst := newHandleStreamTest(t, nil)
487507
handleStream := func(s *ServerStream) {
488508
s.WriteStatus(st)
489509
}
@@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
506526
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
507527
}
508528

529+
// Tests the use of stats handlers and ensures there are no data races while
530+
// accessing trailers.
531+
func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) {
532+
errDetails := []protoadapt.MessageV1{
533+
&epb.RetryInfo{
534+
RetryDelay: &durationpb.Duration{Seconds: 60},
535+
},
536+
&epb.ResourceInfo{
537+
ResourceType: "foo bar",
538+
ResourceName: "service.foo.bar",
539+
Owner: "User",
540+
},
541+
}
542+
543+
statusCode := codes.ResourceExhausted
544+
msg := "you are being throttled"
545+
st, err := status.New(statusCode, msg).WithDetails(errDetails...)
546+
if err != nil {
547+
t.Fatal(err)
548+
}
549+
550+
stBytes, err := proto.Marshal(st.Proto())
551+
if err != nil {
552+
t.Fatal(err)
553+
}
554+
// Add mock stats handlers to exercise the stats handler code path.
555+
statsHandler := &mockStatsHandler{
556+
rpcStatsCh: make(chan stats.RPCStats, 2),
557+
}
558+
hst := newHandleStreamTest(t, []stats.Handler{statsHandler})
559+
handleStream := func(s *ServerStream) {
560+
if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
561+
t.Error(err)
562+
}
563+
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
564+
t.Error(err)
565+
}
566+
s.WriteStatus(st)
567+
}
568+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
569+
defer cancel()
570+
hst.ht.HandleStreams(
571+
ctx, func(s *ServerStream) { go handleStream(s) },
572+
)
573+
wantHeader := http.Header{
574+
"Date": nil,
575+
"Content-Type": {"application/grpc"},
576+
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
577+
}
578+
wantTrailer := http.Header{
579+
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
580+
"Grpc-Message": {encodeGrpcMessage(msg)},
581+
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
582+
"Custom-Trailer": []string{"Custom trailer value"},
583+
}
584+
585+
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
586+
wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}}
587+
for _, wantType := range wantStatTypes {
588+
select {
589+
case <-ctx.Done():
590+
t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.")
591+
case s := <-statsHandler.rpcStatsCh:
592+
if reflect.TypeOf(s) != reflect.TypeOf(wantType) {
593+
t.Fatalf("Received RPCStats of type %T, want %T", s, wantType)
594+
}
595+
}
596+
}
597+
}
598+
509599
// TestHandlerTransport_Drain verifies that Drain() is not implemented
510600
// by `serverHandlerTransport`.
511601
func (s) TestHandlerTransport_Drain(t *testing.T) {
512602
defer func() { recover() }()
513-
st := newHandleStreamTest(t)
603+
st := newHandleStreamTest(t, nil)
514604
st.ht.Drain("whatever")
515605
t.Errorf("serverHandlerTransport.Drain() should have panicked")
516606
}

0 commit comments

Comments
 (0)