@@ -23,14 +23,13 @@ import (
2323 "io"
2424 "net"
2525 "sync"
26-
27- "google.golang.org/grpc/codes"
28- "google.golang.org/grpc/status"
2926)
3027
3128const (
32- messageHeaderLength = 10
33- messageLengthMax = 4 << 20
29+ messageHeaderLength = 10
30+ MinMessageLengthLimit = 4 << 10
31+ MaxMessageLengthLimit = 4 << 22
32+ DefaultMessageLengthLimit = 4 << 20
3433)
3534
3635type messageType uint8
@@ -96,18 +95,23 @@ func writeMessageHeader(w io.Writer, p []byte, mh messageHeader) error {
9695var buffers sync.Pool
9796
9897type channel struct {
99- conn net.Conn
100- bw * bufio.Writer
101- br * bufio.Reader
102- hrbuf [messageHeaderLength ]byte // avoid alloc when reading header
103- hwbuf [messageHeaderLength ]byte
98+ conn net.Conn
99+ bw * bufio.Writer
100+ br * bufio.Reader
101+ hrbuf [messageHeaderLength ]byte // avoid alloc when reading header
102+ hwbuf [messageHeaderLength ]byte
103+ maxMsgLen int
104104}
105105
106- func newChannel (conn net.Conn ) * channel {
106+ func newChannel (conn net.Conn , maxMsgLen int ) * channel {
107+ if maxMsgLen == 0 {
108+ maxMsgLen = DefaultMessageLengthLimit
109+ }
107110 return & channel {
108- conn : conn ,
109- bw : bufio .NewWriter (conn ),
110- br : bufio .NewReader (conn ),
111+ conn : conn ,
112+ bw : bufio .NewWriter (conn ),
113+ br : bufio .NewReader (conn ),
114+ maxMsgLen : maxMsgLen ,
111115 }
112116}
113117
@@ -123,12 +127,12 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
123127 return messageHeader {}, nil , err
124128 }
125129
126- if mh .Length > uint32 (messageLengthMax ) {
130+ if maxMsgLen := ch . maxMsgLimit ( true ); mh .Length > uint32 (maxMsgLen ) {
127131 if _ , err := ch .br .Discard (int (mh .Length )); err != nil {
128132 return mh , nil , fmt .Errorf ("failed to discard after receiving oversized message: %w" , err )
129133 }
130134
131- return mh , nil , status . Errorf ( codes . ResourceExhausted , "message length %v exceed maximum message size of %v" , mh .Length , messageLengthMax )
135+ return mh , nil , OversizedMessageError ( int ( mh .Length ), maxMsgLen )
132136 }
133137
134138 var p []byte
@@ -143,8 +147,10 @@ func (ch *channel) recv() (messageHeader, []byte, error) {
143147}
144148
145149func (ch * channel ) send (streamID uint32 , t messageType , flags uint8 , p []byte ) error {
146- if len (p ) > messageLengthMax {
147- return OversizedMessageError (len (p ))
150+ if maxMsgLen := ch .maxMsgLimit (false ); maxMsgLen != 0 {
151+ if len (p ) > maxMsgLen {
152+ return OversizedMessageError (len (p ), maxMsgLen )
153+ }
148154 }
149155
150156 if err := writeMessageHeader (ch .bw , ch .hwbuf [:], messageHeader {Length : uint32 (len (p )), StreamID : streamID , Type : t , Flags : flags }); err != nil {
@@ -180,3 +186,22 @@ func (ch *channel) getmbuf(size int) []byte {
180186func (ch * channel ) putmbuf (p []byte ) {
181187 buffers .Put (& p )
182188}
189+
190+ func (ch * channel ) maxMsgLimit (recv bool ) int {
191+ if ch .maxMsgLen == 0 && recv {
192+ return DefaultMessageLengthLimit
193+ }
194+ return ch .maxMsgLen
195+ }
196+
197+ func clampWireMessageLimit (maxMsgLen int ) int {
198+ switch {
199+ case maxMsgLen == 0 :
200+ return 0
201+ case maxMsgLen < MinMessageLengthLimit :
202+ return MinMessageLengthLimit
203+ case maxMsgLen > MaxMessageLengthLimit :
204+ return MaxMessageLengthLimit
205+ }
206+ return maxMsgLen
207+ }
0 commit comments