Skip to content

Commit c7ea81a

Browse files
djshow832xhebox
andauthored
backend, net: optimize read/write connection by forwarding packets (#391)
Co-authored-by: xhe <[email protected]>
1 parent 5da388a commit c7ea81a

File tree

12 files changed

+753
-50
lines changed

12 files changed

+753
-50
lines changed

pkg/proxy/backend/authenticator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func (auth *Authenticator) readInitialHandshake(backendIO *pnet.PacketIO) (serve
311311
if serverPkt, err = backendIO.ReadPacket(); err != nil {
312312
return
313313
}
314-
if pnet.IsErrorPacket(serverPkt) {
314+
if pnet.IsErrorPacket(serverPkt[0]) {
315315
err = pnet.ParseErrorPacket(serverPkt)
316316
return
317317
}

pkg/proxy/backend/cmd_processor_exec.go

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,28 +98,32 @@ func forwardOnePacket(destIO, srcIO *pnet.PacketIO, flush bool) (data []byte, er
9898
}
9999

100100
func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO, request []byte) (uint16, error) {
101-
for {
102-
response, err := forwardOnePacket(clientIO, backendIO, false)
103-
if err != nil {
104-
return 0, err
101+
var serverStatus uint16
102+
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, length int) bool {
103+
switch {
104+
case pnet.IsErrorPacket(firstByte):
105+
return true
106+
case cp.capability&pnet.ClientDeprecateEOF == 0:
107+
return pnet.IsEOFPacket(firstByte, length)
108+
default:
109+
return pnet.IsResultSetOKPacket(firstByte, length)
105110
}
106-
if pnet.IsErrorPacket(response) {
111+
}, func(response []byte) error {
112+
switch {
113+
case pnet.IsErrorPacket(response[0]):
107114
if err := clientIO.Flush(); err != nil {
108-
return 0, err
109-
}
110-
return 0, cp.handleErrorPacket(response)
111-
}
112-
if cp.capability&pnet.ClientDeprecateEOF == 0 {
113-
if pnet.IsEOFPacket(response) {
114-
return cp.handleEOFPacket(request, response), clientIO.Flush()
115-
}
116-
} else {
117-
if pnet.IsResultSetOKPacket(response) {
118-
rs := cp.handleOKPacket(request, response)
119-
return rs.Status, clientIO.Flush()
115+
return err
120116
}
117+
return cp.handleErrorPacket(response)
118+
case cp.capability&pnet.ClientDeprecateEOF == 0:
119+
serverStatus = cp.handleEOFPacket(request, response)
120+
return clientIO.Flush()
121+
default:
122+
serverStatus = cp.handleOKPacket(request, response).Status
123+
return clientIO.Flush()
121124
}
122-
}
125+
})
126+
return serverStatus, err
123127
}
124128

125129
func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) error {
@@ -241,7 +245,7 @@ func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, req
241245
if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil {
242246
return 0, err
243247
}
244-
if pnet.IsEOFPacket(response) {
248+
if pnet.IsEOFPacket(response[0], len(response)) {
245249
break
246250
}
247251
}

pkg/proxy/backend/cmd_processor_query.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func (cp *CmdProcessor) readResultColumns(packetIO *pnet.PacketIO, result *gomys
7373
if data, err = packetIO.ReadPacket(); err != nil {
7474
return err
7575
}
76-
if !pnet.IsEOFPacket(data) {
76+
if !pnet.IsEOFPacket(data[0], len(data)) {
7777
return errors.WithStack(mysql.ErrMalformPacket)
7878
}
7979
result.Status = binary.LittleEndian.Uint16(data[3:])
@@ -103,19 +103,19 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql.
103103
return err
104104
}
105105
if cp.capability&pnet.ClientDeprecateEOF == 0 {
106-
if pnet.IsEOFPacket(data) {
106+
if pnet.IsEOFPacket(data[0], len(data)) {
107107
result.Status = binary.LittleEndian.Uint16(data[3:])
108108
break
109109
}
110110
} else {
111-
if pnet.IsResultSetOKPacket(data) {
111+
if pnet.IsResultSetOKPacket(data[0], len(data)) {
112112
rs := pnet.ParseOKPacket(data)
113113
result.Status = rs.Status
114114
break
115115
}
116116
}
117117
// An error may occur when the backend writes rows.
118-
if pnet.IsErrorPacket(data) {
118+
if pnet.IsErrorPacket(data[0]) {
119119
return cp.handleErrorPacket(data)
120120
}
121121
result.RowDatas = append(result.RowDatas, data)

pkg/proxy/backend/mock_client_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,11 @@ func (mc *mockClient) readUntilResultEnd(packetIO *pnet.PacketIO) (pkt []byte, e
279279
return
280280
}
281281
if mc.capability&pnet.ClientDeprecateEOF == 0 {
282-
if pnet.IsEOFPacket(pkt) {
282+
if pnet.IsEOFPacket(pkt[0], len(pkt)) {
283283
break
284284
}
285285
} else {
286-
if pnet.IsResultSetOKPacket(pkt) {
286+
if pnet.IsResultSetOKPacket(pkt[0], len(pkt)) {
287287
break
288288
}
289289
}

pkg/proxy/net/compress.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,24 @@ func (crw *compressedReadWriter) Read(p []byte) (n int, err error) {
108108
return
109109
}
110110

111+
func (crw *compressedReadWriter) ReadFrom(r io.Reader) (n int64, err error) {
112+
// TODO: copy compressed data directly.
113+
buf := make([]byte, DefaultConnBufferSize)
114+
nn := 0
115+
for {
116+
nn, err = r.Read(buf)
117+
if (err == nil || err == io.EOF) && nn > 0 {
118+
_, err = crw.Write(buf[:nn])
119+
n += int64(nn)
120+
}
121+
if err == io.EOF {
122+
return n, nil
123+
} else if err != nil {
124+
return n, err
125+
}
126+
}
127+
}
128+
111129
// Read and uncompress the data into readBuffer.
112130
// The format of the protocol: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression_packet.html
113131
func (crw *compressedReadWriter) readFromConn() error {
@@ -227,9 +245,7 @@ func (crw *compressedReadWriter) Peek(n int) (data []byte, err error) {
227245
return
228246
}
229247
}
230-
data = make([]byte, 0, n)
231-
copy(data, crw.readBuffer.Bytes())
232-
return
248+
return crw.readBuffer.Bytes()[:n], nil
233249
}
234250

235251
// Discard won't be used.

pkg/proxy/net/error.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ import (
88
)
99

1010
var (
11-
ErrExpectSSLRequest = errors.New("expect a SSLRequest packet")
12-
ErrReadConn = errors.New("failed to read the connection")
13-
ErrWriteConn = errors.New("failed to write the connection")
14-
ErrFlushConn = errors.New("failed to flush the connection")
15-
ErrCloseConn = errors.New("failed to close the connection")
16-
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
11+
ErrReadConn = errors.New("failed to read the connection")
12+
ErrWriteConn = errors.New("failed to write the connection")
13+
ErrRelayConn = errors.New("failed to relay the connection")
14+
ErrFlushConn = errors.New("failed to flush the connection")
15+
ErrCloseConn = errors.New("failed to close the connection")
16+
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
1717
)
1818

1919
// UserError is returned to the client.

pkg/proxy/net/mysql.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -412,24 +412,24 @@ func ParseErrorPacket(data []byte) error {
412412
}
413413

414414
// IsOKPacket returns true if it's an OK packet (but not ResultSet OK).
415-
func IsOKPacket(data []byte) bool {
416-
return data[0] == OKHeader.Byte()
415+
func IsOKPacket(firstByte byte) bool {
416+
return firstByte == OKHeader.Byte()
417417
}
418418

419419
// IsEOFPacket returns true if it's an EOF packet.
420-
func IsEOFPacket(data []byte) bool {
421-
return data[0] == EOFHeader.Byte() && len(data) <= 5
420+
func IsEOFPacket(firstByte byte, length int) bool {
421+
return firstByte == EOFHeader.Byte() && length <= 5
422422
}
423423

424424
// IsResultSetOKPacket returns true if it's an OK packet after the result set when CLIENT_DEPRECATE_EOF is enabled.
425425
// A row packet may also begin with 0xfe, so we need to judge it with the packet length.
426426
// See https://mariadb.com/kb/en/result-set-packets/
427-
func IsResultSetOKPacket(data []byte) bool {
427+
func IsResultSetOKPacket(firstByte byte, length int) bool {
428428
// With CLIENT_PROTOCOL_41 enabled, the least length is 7.
429-
return data[0] == EOFHeader.Byte() && len(data) >= 7 && len(data) < 0xFFFFFF
429+
return firstByte == EOFHeader.Byte() && length >= 7 && length < 0xFFFFFF
430430
}
431431

432432
// IsErrorPacket returns true if it's an error packet.
433-
func IsErrorPacket(data []byte) bool {
434-
return data[0] == ErrHeader.Byte()
433+
func IsErrorPacket(firstByte byte) bool {
434+
return firstByte == ErrHeader.Byte()
435435
}

pkg/proxy/net/packetio.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
package net
2626

2727
import (
28-
"bufio"
2928
"crypto/tls"
3029
"io"
3130
"net"
@@ -37,6 +36,7 @@ import (
3736
"github.com/pingcap/tiproxy/lib/util/errors"
3837
"github.com/pingcap/tiproxy/pkg/proxy/keepalive"
3938
"github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol"
39+
"github.com/pingcap/tiproxy/pkg/util/bufio"
4040
"go.uber.org/zap"
4141
)
4242

@@ -64,6 +64,7 @@ type packetReadWriter interface {
6464
Discard(n int) (int, error)
6565
Flush() error
6666
DirectWrite(p []byte) (int, error)
67+
ReadFrom(r io.Reader) (int64, error)
6768
Proxy() *proxyprotocol.Proxy
6869
TLSConnectionState() tls.ConnectionState
6970
InBytes() uint64
@@ -110,6 +111,12 @@ func (brw *basicReadWriter) Write(p []byte) (int, error) {
110111
return n, errors.WithStack(err)
111112
}
112113

114+
func (brw *basicReadWriter) ReadFrom(r io.Reader) (int64, error) {
115+
n, err := brw.ReadWriter.ReadFrom(r)
116+
brw.outBytes += uint64(n)
117+
return n, errors.WithStack(err)
118+
}
119+
113120
func (brw *basicReadWriter) DirectWrite(p []byte) (int, error) {
114121
n, err := brw.Conn.Write(p)
115122
brw.outBytes += uint64(n)
@@ -187,7 +194,8 @@ type PacketIO struct {
187194
lastKeepAlive config.KeepAlive
188195
rawConn net.Conn
189196
readWriter packetReadWriter
190-
header []byte
197+
limitReader io.LimitedReader // reuse memory to reduce allocation
198+
header []byte // reuse memory to reduce allocation
191199
logger *zap.Logger
192200
remoteAddr net.Addr
193201
wrap error
@@ -317,6 +325,42 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) {
317325
return nil
318326
}
319327

328+
func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, firstPktLen int) bool, process func(response []byte) error) error {
329+
p.readWriter.BeginRW(rwRead)
330+
dest.readWriter.BeginRW(rwWrite)
331+
p.limitReader.R = p.readWriter
332+
for {
333+
header, err := p.readWriter.Peek(5)
334+
if err != nil {
335+
return errors.Wrap(ErrReadConn, err)
336+
}
337+
length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
338+
if isEnd(header[4], length) {
339+
// TODO: allocate a buffer from pool and return the buffer after `process`.
340+
data, err := p.ReadPacket()
341+
if err != nil {
342+
return errors.Wrap(ErrReadConn, err)
343+
}
344+
if err := dest.WritePacket(data, false); err != nil {
345+
return errors.Wrap(ErrWriteConn, err)
346+
}
347+
return process(data)
348+
} else {
349+
sequence, pktSequence := header[3], p.readWriter.Sequence()
350+
if sequence != pktSequence {
351+
return ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)
352+
}
353+
p.readWriter.SetSequence(sequence + 1)
354+
// Sequence may be different (e.g. with compression) so we can't just copy the data to the destination.
355+
dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1)
356+
p.limitReader.N = int64(length + 4)
357+
if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil {
358+
return errors.Wrap(ErrRelayConn, err)
359+
}
360+
}
361+
}
362+
}
363+
320364
func (p *PacketIO) InBytes() uint64 {
321365
return p.readWriter.InBytes()
322366
}

0 commit comments

Comments
 (0)