Skip to content

Commit 37c4701

Browse files
djshow832xhebox
andauthored
backend, net: reduce memory allocation in forwarding packets (#394)
Co-authored-by: xhe <[email protected]>
1 parent 9dac468 commit 37c4701

File tree

10 files changed

+220
-93
lines changed

10 files changed

+220
-93
lines changed

pkg/proxy/backend/backend_conn_mgr.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ func (mgr *BackendConnManager) initSessionStates(backendIO *pnet.PacketIO, sessi
353353

354354
func (mgr *BackendConnManager) querySessionStates(backendIO *pnet.PacketIO) (sessionStates, sessionToken string, err error) {
355355
// Do not lock here because the caller already locks.
356-
var result *gomysql.Result
356+
var result *gomysql.Resultset
357357
if result, _, err = mgr.cmdProcessor.query(backendIO, sqlQueryState); err != nil {
358358
return
359359
}

pkg/proxy/backend/cmd_processor.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ func NewCmdProcessor(logger *zap.Logger) *CmdProcessor {
3737
}
3838
}
3939

40-
func (cp *CmdProcessor) handleOKPacket(request, response []byte) *gomysql.Result {
41-
r := pnet.ParseOKPacket(response)
42-
cp.updateServerStatus(request, r.Status)
43-
return r
40+
func (cp *CmdProcessor) handleOKPacket(request, response []byte) uint16 {
41+
status := pnet.ParseOKPacket(response)
42+
cp.updateServerStatus(request, status)
43+
return status
4444
}
4545

4646
func (cp *CmdProcessor) handleErrorPacket(data []byte) error {

pkg/proxy/backend/cmd_processor_exec.go

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ func forwardOnePacket(destIO, srcIO *pnet.PacketIO, flush bool) (data []byte, er
9999

100100
func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO, request []byte) (uint16, error) {
101101
var serverStatus uint16
102-
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, length int) bool {
102+
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, length int) (end, needData bool) {
103103
switch {
104104
case pnet.IsErrorPacket(firstByte):
105-
return true
105+
return true, true
106106
case cp.capability&pnet.ClientDeprecateEOF == 0:
107-
return pnet.IsEOFPacket(firstByte, length)
107+
return pnet.IsEOFPacket(firstByte, length), true
108108
default:
109-
return pnet.IsResultSetOKPacket(firstByte, length)
109+
return pnet.IsResultSetOKPacket(firstByte, length), true
110110
}
111111
}, func(response []byte) error {
112112
switch {
@@ -119,7 +119,7 @@ func (cp *CmdProcessor) forwardUntilResultEnd(clientIO, backendIO *pnet.PacketIO
119119
serverStatus = cp.handleEOFPacket(request, response)
120120
return clientIO.Flush()
121121
default:
122-
serverStatus = cp.handleOKPacket(request, response).Status
122+
serverStatus = cp.handleOKPacket(request, response)
123123
return clientIO.Flush()
124124
}
125125
})
@@ -146,9 +146,14 @@ func (cp *CmdProcessor) forwardPrepareCmd(clientIO, backendIO *pnet.PacketIO) er
146146
expectedPackets++
147147
}
148148
}
149-
for i := 0; i < expectedPackets; i++ {
150-
// Ignore this status because PREPARE doesn't affect status.
151-
if _, err = forwardOnePacket(clientIO, backendIO, false); err != nil {
149+
// Ignore this status because PREPARE doesn't affect status.
150+
if expectedPackets > 0 {
151+
i := 0
152+
err = backendIO.ForwardUntil(clientIO, func(firstByte byte, firstPktLen int) (end, needData bool) {
153+
i++
154+
return i >= expectedPackets, false
155+
}, nil)
156+
if err != nil {
152157
return err
153158
}
154159
}
@@ -175,26 +180,35 @@ func (cp *CmdProcessor) forwardFieldListCmd(clientIO, backendIO *pnet.PacketIO,
175180

176181
func (cp *CmdProcessor) forwardQueryCmd(clientIO, backendIO *pnet.PacketIO, request []byte) error {
177182
for {
178-
response, err := forwardOnePacket(clientIO, backendIO, false)
179-
if err != nil {
180-
return err
181-
}
182183
var serverStatus uint16
183-
switch response[0] {
184-
case mysql.OKHeader:
185-
rs := cp.handleOKPacket(request, response)
186-
serverStatus, err = rs.Status, clientIO.Flush()
187-
case mysql.ErrHeader:
188-
if err := clientIO.Flush(); err != nil {
189-
return err
184+
var first byte
185+
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, _ int) (end, needData bool) {
186+
first = firstByte
187+
switch firstByte {
188+
case mysql.OKHeader, mysql.ErrHeader:
189+
return true, true
190+
default:
191+
return true, false
190192
}
191-
// Subsequent statements won't be executed even if it's a multi-statement.
192-
return cp.handleErrorPacket(response)
193-
case mysql.LocalInFileHeader:
194-
serverStatus, err = cp.forwardLoadInFile(clientIO, backendIO, request)
195-
default:
196-
serverStatus, err = cp.forwardResultSet(clientIO, backendIO, request)
197-
}
193+
}, func(response []byte) error {
194+
var err error
195+
switch first {
196+
case mysql.OKHeader:
197+
status := cp.handleOKPacket(request, response)
198+
serverStatus, err = status, clientIO.Flush()
199+
case mysql.ErrHeader:
200+
if err = clientIO.Flush(); err != nil {
201+
return err
202+
}
203+
// Subsequent statements won't be executed even if it's a multi-statement.
204+
return cp.handleErrorPacket(response)
205+
case mysql.LocalInFileHeader:
206+
serverStatus, err = cp.forwardLoadInFile(clientIO, backendIO, request)
207+
default:
208+
serverStatus, err = cp.forwardResultSet(clientIO, backendIO, request)
209+
}
210+
return err
211+
})
198212
if err != nil {
199213
return err
200214
}
@@ -213,11 +227,14 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re
213227
// The client sends file data until an empty packet.
214228
for {
215229
var data []byte
216-
// The file may be large, so always flush it.
217-
if data, err = forwardOnePacket(backendIO, clientIO, true); err != nil {
230+
// Do not call PacketIO.ForwardUntil. It peeks 5 bytes but there may be only 4 bytes here.
231+
if data, err = forwardOnePacket(backendIO, clientIO, false); err != nil {
218232
return
219233
}
220234
if len(data) == 0 {
235+
if err := backendIO.Flush(); err != nil {
236+
return 0, err
237+
}
221238
break
222239
}
223240
}
@@ -227,8 +244,7 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re
227244
}
228245
switch response[0] {
229246
case mysql.OKHeader:
230-
rs := cp.handleOKPacket(request, response)
231-
return rs.Status, nil
247+
return cp.handleOKPacket(request, response), nil
232248
case mysql.ErrHeader:
233249
return serverStatus, cp.handleErrorPacket(response)
234250
}
@@ -238,22 +254,22 @@ func (cp *CmdProcessor) forwardLoadInFile(clientIO, backendIO *pnet.PacketIO, re
238254

239255
func (cp *CmdProcessor) forwardResultSet(clientIO, backendIO *pnet.PacketIO, request []byte) (uint16, error) {
240256
if cp.capability&pnet.ClientDeprecateEOF == 0 {
241-
var response []byte
257+
var serverStatus uint16
242258
// read columns
243-
for {
244-
var err error
245-
if response, err = forwardOnePacket(clientIO, backendIO, false); err != nil {
246-
return 0, err
247-
}
248-
if pnet.IsEOFPacket(response[0], len(response)) {
249-
break
259+
err := backendIO.ForwardUntil(clientIO, func(firstByte byte, firstPktLen int) (end, needData bool) {
260+
return pnet.IsEOFPacket(firstByte, firstPktLen), true
261+
}, func(response []byte) error {
262+
serverStatus = binary.LittleEndian.Uint16(response[3:])
263+
// If a cursor exists, only columns are sent this time. The client will then send COM_STMT_FETCH to fetch rows.
264+
// Otherwise, columns and rows are both sent once.
265+
if serverStatus&mysql.ServerStatusCursorExists > 0 {
266+
serverStatus = cp.handleEOFPacket(request, response)
267+
return clientIO.Flush()
250268
}
251-
}
252-
serverStatus := binary.LittleEndian.Uint16(response[3:])
253-
// If a cursor exists, only columns are sent this time. The client will then send COM_STMT_FETCH to fetch rows.
254-
// Otherwise, columns and rows are both sent once.
255-
if serverStatus&mysql.ServerStatusCursorExists > 0 {
256-
return cp.handleEOFPacket(request, response), clientIO.Flush()
269+
return nil
270+
})
271+
if err != nil {
272+
return serverStatus, err
257273
}
258274
}
259275
// Deprecate EOF or no cursor.

pkg/proxy/backend/cmd_processor_query.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
// query is called when the proxy sends requests to the backend by itself,
1717
// such as querying session states, committing the current transaction.
1818
// It only supports limited cases, excluding loading file, cursor fetch, multi-statements, etc.
19-
func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomysql.Result, response []byte, err error) {
19+
func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomysql.Resultset, response []byte, err error) {
2020
// send request
2121
packetIO.ResetSequence()
2222
data := hack.Slice(sql)
@@ -33,13 +33,15 @@ func (cp *CmdProcessor) query(packetIO *pnet.PacketIO, sql string) (result *gomy
3333
}
3434
switch response[0] {
3535
case mysql.OKHeader:
36-
result = cp.handleOKPacket(request, response)
36+
cp.handleOKPacket(request, response)
3737
case mysql.ErrHeader:
3838
err = cp.handleErrorPacket(response)
3939
case mysql.LocalInFileHeader:
4040
err = errors.WithStack(mysql.ErrMalformPacket)
4141
default:
42-
result, err = cp.readResultSet(packetIO, response)
42+
var rs *gomysql.Result
43+
rs, err = cp.readResultSet(packetIO, response)
44+
result = rs.Resultset
4345
}
4446
return
4547
}
@@ -109,8 +111,7 @@ func (cp *CmdProcessor) readResultRows(packetIO *pnet.PacketIO, result *gomysql.
109111
}
110112
} else {
111113
if pnet.IsResultSetOKPacket(data[0], len(data)) {
112-
rs := pnet.ParseOKPacket(data)
113-
result.Status = rs.Status
114+
result.Status = pnet.ParseOKPacket(data)
114115
break
115116
}
116117
}

pkg/proxy/backend/mock_client_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,7 @@ func (mc *mockClient) readResultSet(packetIO *pnet.PacketIO) error {
361361
if mc.capability&pnet.ClientDeprecateEOF == 0 {
362362
serverStatus = binary.LittleEndian.Uint16(pkt[3:])
363363
} else {
364-
rs := pnet.ParseOKPacket(pkt)
365-
serverStatus = rs.Status
364+
serverStatus = pnet.ParseOKPacket(pkt)
366365
}
367366
}
368367
if serverStatus&mysql.ServerMoreResultsExists == 0 {

pkg/proxy/backend/mock_proxy_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type mockProxy struct {
3939

4040
*proxyConfig
4141
// outputs that received from the server.
42-
rs *gomysql.Result
42+
rs *gomysql.Resultset
4343
// execution results
4444
err error
4545
logger *zap.Logger

pkg/proxy/net/mysql.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -385,17 +385,15 @@ func WriteServerVersion(conn net.Conn, serverVersion string) error {
385385
return c.WritePacket(data)
386386
}
387387

388-
// ParseOKPacket transforms an OK packet into a Result object.
389-
func ParseOKPacket(data []byte) *gomysql.Result {
390-
var n int
388+
// ParseOKPacket parses an OK packet and only returns server status.
389+
func ParseOKPacket(data []byte) uint16 {
391390
var pos = 1
392-
r := new(gomysql.Result)
393-
r.AffectedRows, _, n = ParseLengthEncodedInt(data[pos:])
394-
pos += n
395-
r.InsertId, _, n = ParseLengthEncodedInt(data[pos:])
396-
pos += n
397-
r.Status = binary.LittleEndian.Uint16(data[pos:])
398-
return r
391+
// skip affected rows
392+
pos += SkipLengthEncodedInt(data[pos:])
393+
// skip insert id
394+
pos += SkipLengthEncodedInt(data[pos:])
395+
// return status
396+
return binary.LittleEndian.Uint16(data[pos:])
399397
}
400398

401399
// ParseErrorPacket transforms an error packet into a MyError object.

pkg/proxy/net/packetio.go

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,8 @@ func (p *PacketIO) WritePacket(data []byte, flush bool) (err error) {
325325
return nil
326326
}
327327

328-
func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, firstPktLen int) bool, process func(response []byte) error) error {
328+
func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, firstPktLen int) (end, needData bool),
329+
process func(response []byte) error) error {
329330
p.readWriter.BeginRW(rwRead)
330331
dest.readWriter.BeginRW(rwWrite)
331332
p.limitReader.R = p.readWriter
@@ -335,28 +336,48 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first
335336
return errors.Wrap(ErrReadConn, err)
336337
}
337338
length := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
338-
if isEnd(header[4], length) {
339+
end, needData := isEnd(header[4], length)
340+
var data []byte
341+
// Just call ReadFrom if the caller doesn't need the data, even if it's the last packet.
342+
if end && needData {
339343
// TODO: allocate a buffer from pool and return the buffer after `process`.
340-
data, err := p.ReadPacket()
344+
data, err = p.ReadPacket()
341345
if err != nil {
342346
return errors.Wrap(ErrReadConn, err)
343347
}
344348
if err := dest.WritePacket(data, false); err != nil {
345349
return errors.Wrap(ErrWriteConn, err)
346350
}
347-
return process(data)
348351
} else {
349-
sequence, pktSequence := header[3], p.readWriter.Sequence()
350-
if sequence != pktSequence {
351-
return ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)
352+
for {
353+
sequence, pktSequence := header[3], p.readWriter.Sequence()
354+
if sequence != pktSequence {
355+
return ErrInvalidSequence.GenWithStack("invalid sequence, expected %d, actual %d", pktSequence, sequence)
356+
}
357+
p.readWriter.SetSequence(sequence + 1)
358+
// Sequence may be different (e.g. with compression) so we can't just copy the data to the destination.
359+
dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1)
360+
p.limitReader.N = int64(length + 4)
361+
if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil {
362+
return errors.Wrap(ErrRelayConn, err)
363+
}
364+
// For large packets, continue.
365+
if length < MaxPayloadLen {
366+
break
367+
}
368+
if header, err = p.readWriter.Peek(4); err != nil {
369+
return errors.Wrap(ErrReadConn, err)
370+
}
371+
length = int(header[0]) | int(header[1])<<8 | int(header[2])<<16
352372
}
353-
p.readWriter.SetSequence(sequence + 1)
354-
// Sequence may be different (e.g. with compression) so we can't just copy the data to the destination.
355-
dest.readWriter.SetSequence(dest.readWriter.Sequence() + 1)
356-
p.limitReader.N = int64(length + 4)
357-
if _, err := dest.readWriter.ReadFrom(&p.limitReader); err != nil {
358-
return errors.Wrap(ErrRelayConn, err)
373+
}
374+
375+
if end {
376+
if process != nil {
377+
// data == nil iff needData == false
378+
return process(data)
359379
}
380+
return nil
360381
}
361382
}
362383
}

0 commit comments

Comments
 (0)