diff --git a/internal/transport/client_stream.go b/internal/transport/client_stream.go index b73a9d2285fe..980452519ea7 100644 --- a/internal/transport/client_stream.go +++ b/internal/transport/client_stream.go @@ -144,3 +144,11 @@ func (s *ClientStream) TrailersOnly() bool { func (s *ClientStream) Status() *status.Status { return s.status } + +func (s *ClientStream) requestRead(n int) { + s.ct.adjustWindow(s, uint32(n)) +} + +func (s *ClientStream) updateWindow(n int) { + s.ct.updateWindow(s, uint32(n)) +} diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index aadc5d81c0ea..80ef2d00fbe7 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -387,6 +387,12 @@ func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) e return err } +func (ht *serverHandlerTransport) adjustWindow(*ServerStream, uint32) { +} + +func (ht *serverHandlerTransport) updateWindow(*ServerStream, uint32) { +} + func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) { // With this transport type there will be exactly 1 stream: this HTTP request. var cancel context.CancelFunc @@ -414,7 +420,6 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream Stream: Stream{ id: 0, // irrelevant ctx: ctx, - requestRead: func(int) {}, method: req.URL.Path, recvCompress: req.Header.Get("grpc-encoding"), contentSubtype: ht.contentSubtype, @@ -424,9 +429,10 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream headerWireLength: 0, // won't have access to header wire length until golang/go#18997. } s.Stream.buf.init() + s.readRequester = s s.trReader = transportReader{ reader: recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: &s.buf}, - windowHandler: func(int) {}, + windowHandler: s, } // readerDone is closed when the Body.Read-ing goroutine exits. diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 386801e00123..911d7e1ea4e3 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -493,9 +493,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientSt } s.Stream.buf.init() s.Stream.wq.init(defaultWriteQuota, s.done) - s.requestRead = func(n int) { - t.adjustWindow(s, uint32(n)) - } + s.readRequester = s // The client side stream context should have exactly the same life cycle with the user provided context. // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. // So we use the original context here instead of creating a copy. @@ -509,9 +507,7 @@ func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientSt s.Close(err) }, }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) - }, + windowHandler: s, } return s } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index b7fa73f06dd9..bcedac32fed5 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -640,9 +640,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade t.channelz.SocketMetrics.StreamsStarted.Add(1) t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano()) } - s.requestRead = func(n int) { - t.adjustWindow(s, uint32(n)) - } + s.readRequester = s s.ctxDone = s.ctx.Done() s.Stream.wq.init(defaultWriteQuota, s.ctxDone) s.trReader = transportReader{ @@ -651,9 +649,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade ctxDone: s.ctxDone, recv: &s.buf, }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) - }, + windowHandler: s, } // Register the stream with loopy. t.controlBuf.put(®isterStream{ diff --git a/internal/transport/server_stream.go b/internal/transport/server_stream.go index b203568fc349..ed6a13b7501a 100644 --- a/internal/transport/server_stream.go +++ b/internal/transport/server_stream.go @@ -179,3 +179,11 @@ func (s *ServerStream) SetTrailer(md metadata.MD) error { s.hdrMu.Unlock() return nil } + +func (s *ServerStream) requestRead(n int) { + s.st.adjustWindow(s, uint32(n)) +} + +func (s *ServerStream) updateWindow(n int) { + s.st.updateWindow(s, uint32(n)) +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index 617f5ae04bcd..9565b400cd62 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -291,9 +291,7 @@ type Stream struct { recvCompress string sendCompress string - // Callback to state application's intentions to read data. This - // is used to adjust flow control, if needed. - requestRead func(int) + readRequester readRequester // contentSubtype is the content-subtype for requests. // this must be lowercase or the behavior is undefined. @@ -310,6 +308,12 @@ type Stream struct { wq writeQuota } +// readRequester is used to state application's intentions to read data. This +// is used to adjust flow control, if needed. +type readRequester interface { + requestRead(int) +} + func (s *Stream) swapState(st streamState) streamState { return streamState(atomic.SwapUint32((*uint32)(&s.state), uint32(st))) } @@ -357,7 +361,7 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) { if er := s.trReader.er; er != nil { return er } - s.requestRead(len(header)) + s.readRequester.requestRead(len(header)) for len(header) != 0 { n, err := s.trReader.ReadMessageHeader(header) header = header[n:] @@ -380,7 +384,7 @@ func (s *Stream) read(n int) (data mem.BufferSlice, err error) { if er := s.trReader.er; er != nil { return nil, er } - s.requestRead(n) + s.readRequester.requestRead(n) for n != 0 { buf, err := s.trReader.Read(n) var bufLen int @@ -422,18 +426,24 @@ type transportReader struct { _ noCopy // The handler to control the window update procedure for both this // particular stream and the associated transport. - windowHandler func(int) + windowHandler windowHandler er error reader recvBufferReader } +// The handler to control the window update procedure for both this +// particular stream and the associated transport. +type windowHandler interface { + updateWindow(int) +} + func (t *transportReader) ReadMessageHeader(header []byte) (int, error) { n, err := t.reader.ReadMessageHeader(header) if err != nil { t.er = err return 0, err } - t.windowHandler(n) + t.windowHandler.updateWindow(n) return n, nil } @@ -443,7 +453,7 @@ func (t *transportReader) Read(n int) (mem.Buffer, error) { t.er = err return buf, err } - t.windowHandler(buf.Len()) + t.windowHandler.updateWindow(buf.Len()) return buf, nil } @@ -629,6 +639,8 @@ type internalServerTransport interface { write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error writeStatus(s *ServerStream, st *status.Status) error incrMsgRecv() + adjustWindow(s *ServerStream, n uint32) + updateWindow(s *ServerStream, n uint32) } // connectionErrorf creates an ConnectionError with the specified error description. diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index 4e46e267afe8..d704ab8fa0a7 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -1856,8 +1856,8 @@ func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() s := &Stream{ - ctx: ctx, - requestRead: func(int) {}, + ctx: ctx, + readRequester: &fakeReadRequester{}, } s.buf.init() s.trReader = transportReader{ @@ -1866,7 +1866,9 @@ func (s) TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { ctxDone: s.ctx.Done(), recv: &s.buf, }, - windowHandler: func(int) {}, + windowHandler: &mockWindowUpdater{ + f: func(int) {}, + }, } testData := make([]byte, 1) testData[0] = 5 @@ -3163,7 +3165,7 @@ func (s) TestReadMessageHeaderMultipleBuffers(t *testing.T) { headerLen := 5 bytesRead := 0 s := Stream{ - requestRead: func(int) {}, + readRequester: &fakeReadRequester{}, } s.buf.init() recvBuffer := &s.buf @@ -3171,8 +3173,10 @@ func (s) TestReadMessageHeaderMultipleBuffers(t *testing.T) { reader: recvBufferReader{ recv: recvBuffer, }, - windowHandler: func(i int) { - bytesRead += i + windowHandler: &mockWindowUpdater{ + f: func(i int) { + bytesRead += i + }, }, } @@ -3476,3 +3480,16 @@ func (s) TestDeleteStreamMetricsIncrementedOnlyOnce(t *testing.T) { }) } } + +type fakeReadRequester struct { +} + +func (f *fakeReadRequester) requestRead(int) {} + +type mockWindowUpdater struct { + f func(int) +} + +func (m *mockWindowUpdater) updateWindow(n int) { + m.f(n) +}