diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 5467fe9715a3..c31ed87b1676 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1499,13 +1499,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { case "grpc-message": grpcMessage = decodeGrpcMessage(hf.Value) case ":status": - if hf.Value == "200" { - httpStatusErr = "" - statusCode := 200 - httpStatusCode = &statusCode - break - } - c, err := strconv.ParseInt(hf.Value, 10, 32) if err != nil { se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err)) @@ -1513,7 +1506,19 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { return } statusCode := int(c) + if statusCode >= 100 && statusCode < 200 { + if endStream { + se := status.New(codes.Internal, fmt.Sprintf( + "protocol error: informational header with status code %d must not have END_STREAM set", statusCode)) + t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) + } + return + } httpStatusCode = &statusCode + if statusCode == 200 { + httpStatusErr = "" + break + } httpStatusErr = fmt.Sprintf( "unexpected HTTP status code received from server: %d (%s)", diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index b8f97c3c9464..d0c8a88d6abf 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -3125,3 +3125,82 @@ func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) { t.Fatalf("RST frame received earlier than expected by duration: %v", want-got) } } + +// TestClientTransport_Handle1xxHeaders validates that 1xx HTTP status headers +// are ignored and treated as a protocol error if END_STREAM is set. +func (s) TestClientTransport_Handle1xxHeaders(t *testing.T) { + testStream := func() *ClientStream { + return &ClientStream{ + Stream: &Stream{ + buf: &recvBuffer{ + c: make(chan recvMsg), + mu: sync.Mutex{}, + }, + }, + done: make(chan struct{}), + headerChan: make(chan struct{}), + } + } + + testClient := func(ts *ClientStream) *http2Client { + return &http2Client{ + mu: sync.Mutex{}, + activeStreams: map[uint32]*ClientStream{ + 0: ts, + }, + controlBuf: newControlBuffer(make(<-chan struct{})), + } + } + + for _, test := range []struct { + name string + metaHeaderFrame *http2.MetaHeadersFrame + httpFlags http2.Flags + wantStatus *status.Status + }{ + { + name: "1xx with END_STREAM is error", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: ":status", Value: "100"}, + }, + }, + httpFlags: http2.FlagHeadersEndStream, + wantStatus: status.New( + codes.Internal, + "protocol error: informational header with status code 100 must not have END_STREAM set", + ), + }, + { + name: "1xx without END_STREAM is ignored", + metaHeaderFrame: &http2.MetaHeadersFrame{ + Fields: []hpack.HeaderField{ + {Name: ":status", Value: "100"}, + }, + }, + httpFlags: 0, + wantStatus: nil, + }, + } { + t.Run(test.name, func(t *testing.T) { + ts := testStream() + s := testClient(ts) + + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, + Flags: test.httpFlags, + }, + } + + s.operateHeaders(test.metaHeaderFrame) + + got := ts.status + want := test.wantStatus + + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = %v, want %v", test.metaHeaderFrame, got, want) + } + }) + } +}