Skip to content

Commit 2fe9a4f

Browse files
committed
acquire header mutex while copying trailers
1 parent 0ebea3e commit 2fe9a4f

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

internal/transport/handler_server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
277277
if err == nil { // transport has not been closed
278278
// Note: The trailer fields are compressed with hpack after this call returns.
279279
// No WireLength field is set here.
280+
s.hdrMu.Lock()
280281
for _, sh := range ht.stats {
281282
sh.HandleRPC(s.Context(), &stats.OutTrailer{
282283
Trailer: s.trailer.Copy(),
283284
})
284285
}
286+
s.hdrMu.Unlock()
285287
}
286288
ht.Close(errors.New("finished writing status"))
287289
return err

internal/transport/handler_server_test.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
"google.golang.org/grpc/codes"
3636
"google.golang.org/grpc/mem"
3737
"google.golang.org/grpc/metadata"
38+
"google.golang.org/grpc/stats"
3839
"google.golang.org/grpc/status"
3940
"google.golang.org/protobuf/proto"
4041
"google.golang.org/protobuf/protoadapt"
@@ -246,6 +247,22 @@ type handleStreamTest struct {
246247
ht *serverHandlerTransport
247248
}
248249

250+
type mockStatsHandler struct{}
251+
252+
func (h *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
253+
return ctx
254+
}
255+
256+
func (h *mockStatsHandler) HandleRPC(context.Context, stats.RPCStats) {
257+
}
258+
259+
func (h *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
260+
return ctx
261+
}
262+
263+
func (h *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
264+
}
265+
249266
func newHandleStreamTest(t *testing.T) *handleStreamTest {
250267
bodyr, bodyw := io.Pipe()
251268
req := &http.Request{
@@ -260,7 +277,12 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
260277
Body: bodyr,
261278
}
262279
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
263-
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
280+
// Add mock stats handlers to exercise the stats handler code path.
281+
statsHandlers := make([]stats.Handler, 0, 5)
282+
for range 5 {
283+
statsHandlers = append(statsHandlers, &mockStatsHandler{})
284+
}
285+
ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
264286
if err != nil {
265287
t.Fatal(err)
266288
}
@@ -485,6 +507,12 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
485507

486508
hst := newHandleStreamTest(t)
487509
handleStream := func(s *ServerStream) {
510+
if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
511+
t.Error(err)
512+
}
513+
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
514+
t.Error(err)
515+
}
488516
s.WriteStatus(st)
489517
}
490518
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
@@ -501,6 +529,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
501529
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
502530
"Grpc-Message": {encodeGrpcMessage(msg)},
503531
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
532+
"Custom-Trailer": []string{"Custom trailer value"},
504533
}
505534

506535
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)

0 commit comments

Comments
 (0)