Skip to content

Commit a4d2de9

Browse files
authored
net: fix packet sequence when sequence overflows (#378)
1 parent 419e26d commit a4d2de9

File tree

4 files changed

+84
-23
lines changed

4 files changed

+84
-23
lines changed

pkg/proxy/net/compress.go

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,11 @@ const (
2525
CompressionZstd
2626
)
2727

28-
type rwStatus int
29-
30-
const (
31-
rwNone rwStatus = iota
32-
rwRead
33-
rwWrite
34-
)
35-
3628
const (
3729
// maxCompressedSize is the max uncompressed data size for a compressed packet.
3830
// Packets bigger than maxCompressedSize will be split into multiple compressed packets.
3931
// MySQL has 16K for the first packet. The rest packets and MySQL Connector/J are 16M.
40-
// Two restrictions for the length:
41-
// - it should be smaller than 16M so that the length can fit in the 3 byte field in the header.
42-
// - it should be larger than 4M so that the compressed sequence can fit in the 3 byte field when max_allowed_packet is 1G.
32+
// It should be smaller than 16M so that the length can fit in the 3 byte field in the header.
4333
maxCompressedSize = 1<<24 - 1
4434
// minCompressSize is the min uncompressed data size for compressed data.
4535
// Packets smaller than minCompressSize won't be compressed.
@@ -83,26 +73,25 @@ func newCompressedReadWriter(rw packetReadWriter, algorithm CompressAlgorithm, z
8373
}
8474
}
8575

86-
func (crw *compressedReadWriter) SetSequence(sequence uint8) {
87-
crw.packetReadWriter.SetSequence(sequence)
76+
func (crw *compressedReadWriter) ResetSequence() {
77+
crw.packetReadWriter.ResetSequence()
8878
// Reset the compressed sequence before the next command.
89-
if sequence == 0 {
90-
crw.sequence = 0
91-
crw.rwStatus = rwNone
92-
}
79+
// Sequence wraps around once it hits 0xFF, so we need ResetSequence() to know that it's reset instead of overflow.
80+
crw.sequence = 0
81+
crw.rwStatus = rwNone
9382
}
9483

84+
// BeginRW implements packetReadWriter.BeginRW.
9585
// Uncompressed sequence of MySQL doesn't follow the spec: it's set to the compressed sequence when
9686
// the client/server begins reading or writing.
97-
func (crw *compressedReadWriter) beginRW(status rwStatus) {
87+
func (crw *compressedReadWriter) BeginRW(status rwStatus) {
9888
if crw.rwStatus != status {
9989
crw.packetReadWriter.SetSequence(crw.sequence)
10090
crw.rwStatus = status
10191
}
10292
}
10393

10494
func (crw *compressedReadWriter) Read(p []byte) (n int, err error) {
105-
crw.beginRW(rwRead)
10695
// Read from the connection to fill the buffer if the buffer is empty.
10796
if crw.readBuffer.Len() == 0 {
10897
if err = crw.readFromConn(); err != nil {
@@ -156,7 +145,6 @@ func (crw *compressedReadWriter) readFromConn() error {
156145
}
157146

158147
func (crw *compressedReadWriter) Write(data []byte) (n int, err error) {
159-
crw.beginRW(rwWrite)
160148
for {
161149
remainingLen := maxCompressedSize - crw.writeBuffer.Len()
162150
if len(data) <= remainingLen {
@@ -234,7 +222,6 @@ func (crw *compressedReadWriter) DirectWrite(data []byte) (n int, err error) {
234222
// Peek won't be used.
235223
// Notice: the peeked data may be discarded if an error is returned.
236224
func (crw *compressedReadWriter) Peek(n int) (data []byte, err error) {
237-
crw.beginRW(rwRead)
238225
for crw.readBuffer.Len() < n {
239226
if err = crw.readFromConn(); err != nil {
240227
return
@@ -247,7 +234,6 @@ func (crw *compressedReadWriter) Peek(n int) (data []byte, err error) {
247234

248235
// Discard won't be used.
249236
func (crw *compressedReadWriter) Discard(n int) (d int, err error) {
250-
crw.beginRW(rwRead)
251237
for crw.readBuffer.Len() < n {
252238
if err = crw.readFromConn(); err != nil {
253239
return

pkg/proxy/net/compress_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ func TestReadWriteError(t *testing.T) {
235235
}
236236

237237
func fillAndWrite(t *testing.T, crw *compressedReadWriter, b byte, length int) {
238+
crw.BeginRW(rwWrite)
238239
data := fillData(b, length)
239240
_, err := crw.Write(data)
240241
require.NoError(t, err)
@@ -250,6 +251,7 @@ func fillData(b byte, length int) []byte {
250251
}
251252

252253
func readAndCheck(t *testing.T, crw *compressedReadWriter, b byte, length int) {
254+
crw.BeginRW(rwRead)
253255
data := make([]byte, length)
254256
_, err := io.ReadFull(crw, data)
255257
require.NoError(t, err)

pkg/proxy/net/packetio.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ const (
5050
defaultReaderSize = 16 * 1024
5151
)
5252

53+
type rwStatus int
54+
55+
const (
56+
rwNone rwStatus = iota
57+
rwRead
58+
rwWrite
59+
)
60+
5361
// packetReadWriter acts like a net.Conn with read and write buffer.
5462
type packetReadWriter interface {
5563
net.Conn
@@ -65,6 +73,10 @@ type packetReadWriter interface {
6573
IsPeerActive() bool
6674
SetSequence(uint8)
6775
Sequence() uint8
76+
// ResetSequence is called before executing a command.
77+
ResetSequence()
78+
// BeginRW is called before reading or writing packets.
79+
BeginRW(status rwStatus)
6880
}
6981

7082
var _ packetReadWriter = (*basicReadWriter)(nil)
@@ -123,6 +135,13 @@ func (brw *basicReadWriter) OutBytes() uint64 {
123135
return brw.outBytes
124136
}
125137

138+
func (brw *basicReadWriter) BeginRW(rwStatus) {
139+
}
140+
141+
func (brw *basicReadWriter) ResetSequence() {
142+
brw.sequence = 0
143+
}
144+
126145
func (brw *basicReadWriter) TLSConnectionState() tls.ConnectionState {
127146
return tls.ConnectionState{}
128147
}
@@ -190,7 +209,7 @@ func (p *PacketIO) RemoteAddr() net.Addr {
190209
}
191210

192211
func (p *PacketIO) ResetSequence() {
193-
p.readWriter.SetSequence(0)
212+
p.readWriter.ResetSequence()
194213
}
195214

196215
// GetSequence is used in tests to assert that the sequences on the client and server are equal.
@@ -219,6 +238,7 @@ func (p *PacketIO) readOnePacket() ([]byte, bool, error) {
219238

220239
// ReadPacket reads data and removes the header
221240
func (p *PacketIO) ReadPacket() (data []byte, err error) {
241+
p.readWriter.BeginRW(rwRead)
222242
for more := true; more; {
223243
var buf []byte
224244
buf, more, err = p.readOnePacket()
@@ -262,6 +282,7 @@ func (p *PacketIO) writeOnePacket(data []byte) (int, bool, error) {
262282

263283
// WritePacket writes data without a header
264284
func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) {
285+
p.readWriter.BeginRW(rwWrite)
265286
for more := true; more; {
266287
var n int
267288
n, more, err = p.writeOnePacket(data)

pkg/proxy/net/packetio_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,3 +362,55 @@ func TestProxyTLSCompress(t *testing.T) {
362362
}
363363
}
364364
}
365+
366+
// Test the sequence is correct with the compression protocol.
367+
func TestPacketSequence(t *testing.T) {
368+
write := func(p *PacketIO, flush bool) {
369+
require.NoError(t, p.WritePacket([]byte{0}, flush))
370+
}
371+
read := func(p *PacketIO) {
372+
_, err := p.ReadPacket()
373+
require.NoError(t, err)
374+
}
375+
loops := 1024
376+
testTCPConn(t,
377+
func(t *testing.T, cli *PacketIO) {
378+
require.NoError(t, cli.SetCompressionAlgorithm(CompressionZlib, 0))
379+
read(cli)
380+
// uncompressed sequence = compressed sequence
381+
write(cli, false)
382+
write(cli, true)
383+
// uncompressed sequence wraps around (1000 writes + 1 flush)
384+
for i := 0; i < loops; i++ {
385+
write(cli, false)
386+
}
387+
require.NoError(t, cli.Flush())
388+
// compressed sequence wraps around (1000 writes + 1000 flushes)
389+
for i := 0; i < loops; i++ {
390+
write(cli, true)
391+
}
392+
// reset sequence
393+
cli.ResetSequence()
394+
write(cli, true)
395+
},
396+
func(t *testing.T, srv *PacketIO) {
397+
require.NoError(t, srv.SetCompressionAlgorithm(CompressionZlib, 0))
398+
write(srv, true)
399+
// uncompressed sequence = compressed sequence
400+
read(srv)
401+
read(srv)
402+
// uncompressed sequence wraps around
403+
for i := 0; i < loops; i++ {
404+
read(srv)
405+
}
406+
// compressed sequence wraps around
407+
for i := 0; i < loops; i++ {
408+
read(srv)
409+
}
410+
// reset sequence
411+
srv.ResetSequence()
412+
read(srv)
413+
},
414+
1,
415+
)
416+
}

0 commit comments

Comments
 (0)