Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status
if err == nil { // transport has not been closed
// Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here.
s.hdrMu.Lock()
Copy link
Contributor Author

@arjan-bal arjan-bal Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're holding two locks here, this and ht.writeStatusMu (acquired at line 229). ht.writeStatusMu is only referenced in this method, so there shouldn't be a chance of deadlocks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that hdrMu is also already taken on line 249, although I'm not sure if that closure is run in the current goroutine or another one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The callback is executed in an event loop in a separate goroutine:

func (ht *serverHandlerTransport) runStream() {
for {
select {
case fn := <-ht.writes:
fn()
case <-ht.closedCh:
return
}
}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about calling the stats handler in the callback above, but I noticed that the http2_server transport also schedules the network write in the background and calls the stats handlers.

t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true)
for _, sh := range t.stats {
// Note: The trailer fields are compressed with hpack after this call returns.
// No WireLength field is set here.
sh.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
}

for _, sh := range ht.stats {
sh.HandleRPC(s.Context(), &stats.OutTrailer{
Trailer: s.trailer.Copy(),
})
}
s.hdrMu.Unlock()
}
ht.Close(errors.New("finished writing status"))
return err
Expand Down
104 changes: 97 additions & 7 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/mem"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/stats"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/protoadapt"
Expand Down Expand Up @@ -246,7 +247,26 @@ type handleStreamTest struct {
ht *serverHandlerTransport
}

func newHandleStreamTest(t *testing.T) *handleStreamTest {
// mockStatsHandler is a stats.Handler test double that forwards every
// RPC-level stats event it receives onto rpcStatsCh, letting tests assert
// on the sequence of events emitted by the transport.
type mockStatsHandler struct {
	rpcStatsCh chan stats.RPCStats
}

// Compile-time check that mockStatsHandler satisfies stats.Handler.
var _ stats.Handler = (*mockStatsHandler)(nil)

// TagRPC returns the context unchanged; no per-RPC tagging is needed in tests.
func (m *mockStatsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context {
	return ctx
}

// HandleRPC publishes the received RPC stats event to rpcStatsCh.
func (m *mockStatsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
	m.rpcStatsCh <- s
}

// TagConn returns the context unchanged; no per-connection tagging is needed.
func (m *mockStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context {
	return ctx
}

// HandleConn ignores connection-level stats events.
func (m *mockStatsHandler) HandleConn(context.Context, stats.ConnStats) {
}

func newHandleStreamTest(t *testing.T, statsHandlers []stats.Handler) *handleStreamTest {
bodyr, bodyw := io.Pipe()
req := &http.Request{
ProtoMajor: 2,
Expand All @@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
Body: bodyr,
}
rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
ht, err := NewServerHandlerTransport(rw, req, statsHandlers, mem.DefaultBufferPool())
if err != nil {
t.Fatal(err)
}
Expand All @@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
}

func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st := newHandleStreamTest(t)
st := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) {
if want := "/service/foo.bar"; s.method != want {
t.Errorf("stream method = %q; want %q", s.method, want)
Expand Down Expand Up @@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
}

func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
st := newHandleStreamTest(t)
st := newHandleStreamTest(t, nil)

handleStream := func(s *ServerStream) {
s.WriteStatus(status.New(statusCode, msg))
Expand Down Expand Up @@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
}

func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
st := newHandleStreamTest(t)
st := newHandleStreamTest(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
t.Cleanup(cancel)
st.ht.HandleStreams(
Expand Down Expand Up @@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
t.Fatal(err)
}

hst := newHandleStreamTest(t)
hst := newHandleStreamTest(t, nil)
handleStream := func(s *ServerStream) {
s.WriteStatus(st)
}
Expand All @@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
}

// Tests the use of stats handlers and ensures there are no data races while
// accessing trailers.
func (s) TestHandlerTransport_HandleStreams_StatsHandlers(t *testing.T) {
errDetails := []protoadapt.MessageV1{
&epb.RetryInfo{
RetryDelay: &durationpb.Duration{Seconds: 60},
},
&epb.ResourceInfo{
ResourceType: "foo bar",
ResourceName: "service.foo.bar",
Owner: "User",
},
}

statusCode := codes.ResourceExhausted
msg := "you are being throttled"
st, err := status.New(statusCode, msg).WithDetails(errDetails...)
if err != nil {
t.Fatal(err)
}

stBytes, err := proto.Marshal(st.Proto())
if err != nil {
t.Fatal(err)
}
// Add mock stats handlers to exercise the stats handler code path.
statsHandler := &mockStatsHandler{
rpcStatsCh: make(chan stats.RPCStats, 2),
}
hst := newHandleStreamTest(t, []stats.Handler{statsHandler})
handleStream := func(s *ServerStream) {
if err := s.SendHeader(metadata.New(map[string]string{})); err != nil {
t.Error(err)
}
if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
t.Error(err)
}
Comment on lines +563 to +565
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How hard would it be to test this in a new test instead of in an existing test that's intended for testing error details?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to use a new test. My thought process was that modifying existing tests would give us better coverage for interactions b/w different features.

s.WriteStatus(st)
}
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
hst.ht.HandleStreams(
ctx, func(s *ServerStream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
"Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
}
wantTrailer := http.Header{
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {encodeGrpcMessage(msg)},
"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
"Custom-Trailer": []string{"Custom trailer value"},
}

checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
wantStatTypes := []stats.RPCStats{&stats.OutHeader{}, &stats.OutTrailer{}}
for _, wantType := range wantStatTypes {
select {
case <-ctx.Done():
t.Fatal("Context timed out waiting for statsHandler.HandleRPC() to be called.")
case s := <-statsHandler.rpcStatsCh:
if reflect.TypeOf(s) != reflect.TypeOf(wantType) {
t.Fatalf("Received RPCStats of type %T, want %T", s, wantType)
}
}
}
}

// TestHandlerTransport_Drain verifies that Drain() is not implemented
// by `serverHandlerTransport`: the call is expected to panic, which the
// deferred recover() absorbs so the test passes; reaching the final line
// means no panic occurred and the test fails.
func (s) TestHandlerTransport_Drain(t *testing.T) {
	defer func() { recover() }()
	st := newHandleStreamTest(t, nil)
	st.ht.Drain("whatever")
	t.Errorf("serverHandlerTransport.Drain() should have panicked")
}
Expand Down