Skip to content

Commit bc576f7

Browse files
authored
sqlreplay, net: maintain the mapping of prepared statement ID and text (#679)
1 parent 1048614 commit bc576f7

File tree

14 files changed

+491
-28
lines changed

14 files changed

+491
-28
lines changed

pkg/proxy/net/mysql.go

Lines changed: 172 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package net
66
import (
77
"bytes"
88
"encoding/binary"
9+
"encoding/json"
910
"math"
1011
"net"
1112

@@ -583,14 +584,44 @@ func MakeQueryPacket(stmt string) []byte {
583584
return request
584585
}
585586

586-
func MakePrepareStmtPacket(stmt string) []byte {
587+
func MakePrepareStmtRequest(stmt string) []byte {
587588
request := make([]byte, len(stmt)+1)
588589
request[0] = ComStmtPrepare.Byte()
589590
copy(request[1:], hack.Slice(stmt))
590591
return request
591592
}
592593

593-
func MakeExecuteStmtPacket(stmtID uint32, args []any) ([]byte, error) {
594+
// MakePrepareStmtResp creates a prepared statement response.
595+
// The packet is incomplete and it's only used for testing.
596+
func MakePrepareStmtResp(stmtID uint32, paramNum int) []byte {
597+
// header
598+
response := make([]byte, 1+4+2+2)
599+
pos := 0
600+
response[pos] = ComStmtPrepare.Byte()
601+
pos += 1
602+
binary.LittleEndian.PutUint32(response[pos:], stmtID)
603+
pos += 4
604+
// column count
605+
pos += 2
606+
// param count
607+
binary.LittleEndian.PutUint16(response[pos:], uint16(paramNum))
608+
// ignore rest part
609+
return response
610+
}
611+
612+
func ParsePrepareStmtResp(resp []byte) (stmtID uint32, paramNum int) {
613+
// header
614+
pos := 1
615+
stmtID = binary.LittleEndian.Uint32(resp[pos:])
616+
pos += 4
617+
// column count
618+
pos += 2
619+
paramNum = int(binary.LittleEndian.Uint16(resp[pos:]))
620+
// ignore rest part
621+
return
622+
}
623+
624+
func MakeExecuteStmtRequest(stmtID uint32, args []any) ([]byte, error) {
594625
paramNum := len(args)
595626
paramTypes := make([]byte, paramNum*2)
596627
paramValues := make([][]byte, paramNum)
@@ -607,21 +638,64 @@ func MakeExecuteStmtPacket(stmtID uint32, args []any) ([]byte, error) {
607638

608639
newParamBoundFlag = 1
609640
switch v := args[i].(type) {
641+
case int8:
642+
paramTypes[i<<1] = fieldTypeTiny
643+
paramValues[i] = []byte{byte(v)}
644+
case int16:
645+
paramTypes[i<<1] = fieldTypeShort
646+
paramValues[i] = Uint16ToBytes(uint16(v))
647+
case int32:
648+
paramTypes[i<<1] = fieldTypeLong
649+
paramValues[i] = Uint32ToBytes(uint32(v))
650+
case int:
651+
paramTypes[i<<1] = fieldTypeLongLong
652+
paramValues[i] = Uint64ToBytes(uint64(v))
610653
case int64:
611654
paramTypes[i<<1] = fieldTypeLongLong
612655
paramValues[i] = Uint64ToBytes(uint64(v))
656+
case uint8:
657+
paramTypes[i<<1] = fieldTypeTiny
658+
paramTypes[(i<<1)+1] = 0x80
659+
paramValues[i] = []byte{v}
660+
case uint16:
661+
paramTypes[i<<1] = fieldTypeShort
662+
paramTypes[(i<<1)+1] = 0x80
663+
paramValues[i] = Uint16ToBytes(v)
664+
case uint32:
665+
paramTypes[i<<1] = fieldTypeLong
666+
paramTypes[(i<<1)+1] = 0x80
667+
paramValues[i] = Uint32ToBytes(v)
668+
case uint:
669+
paramTypes[i<<1] = fieldTypeLongLong
670+
paramTypes[(i<<1)+1] = 0x80
671+
paramValues[i] = Uint64ToBytes(uint64(v))
613672
case uint64:
614673
paramTypes[i<<1] = fieldTypeLongLong
615674
paramTypes[(i<<1)+1] = 0x80
616675
paramValues[i] = Uint64ToBytes(v)
676+
case bool:
677+
paramTypes[i<<1] = fieldTypeTiny
678+
if v {
679+
paramValues[i] = []byte{1}
680+
} else {
681+
paramValues[i] = []byte{0}
682+
}
683+
case float32:
684+
paramTypes[i<<1] = fieldTypeFloat
685+
paramValues[i] = Uint32ToBytes(math.Float32bits(v))
617686
case float64:
618687
paramTypes[i<<1] = fieldTypeDouble
619688
paramValues[i] = Uint64ToBytes(math.Float64bits(v))
620689
case string:
621690
paramTypes[i<<1] = fieldTypeString
622691
paramValues[i] = DumpLengthEncodedString(nil, hack.Slice(v))
692+
case []byte:
693+
paramTypes[i<<1] = fieldTypeString
694+
paramValues[i] = DumpLengthEncodedString(nil, v)
695+
case json.RawMessage:
696+
paramTypes[i<<1] = fieldTypeString
697+
paramValues[i] = DumpLengthEncodedString(nil, v)
623698
default:
624-
// we don't need other types currently
625699
return nil, errors.WithStack(errors.Errorf("unsupported type %T", v))
626700
}
627701

@@ -656,3 +730,98 @@ func MakeExecuteStmtPacket(stmtID uint32, args []any) ([]byte, error) {
656730
}
657731
return request, nil
658732
}
733+
734+
// ParseExecuteStmtRequest parses ComStmtExecute request.
735+
// NOTICE: the type of returned args may be wrong because it doesn't have the knowledge of real param types.
736+
// E.g. []byte is returned as string, and int is returned as int32.
737+
func ParseExecuteStmtRequest(data []byte, paramNum int) (stmtID uint32, args []any, err error) {
738+
if len(data) < 1+4+1+4+1 {
739+
return 0, nil, errors.WithStack(gomysql.ErrMalformPacket)
740+
}
741+
742+
pos := 1
743+
stmtID = binary.LittleEndian.Uint32(data[pos : pos+4])
744+
// cursor flag and iteration count
745+
pos += 4 + 1 + 4
746+
if len(data) < pos+((paramNum+7)>>3)+1 {
747+
return 0, nil, errors.WithStack(gomysql.ErrMalformPacket)
748+
}
749+
nullBitmap := data[pos : pos+((paramNum+7)>>3)]
750+
pos += len(nullBitmap)
751+
newParamBoundFlag := data[pos]
752+
pos += 1
753+
args = make([]any, paramNum)
754+
if newParamBoundFlag == 0 {
755+
return stmtID, args, nil
756+
}
757+
758+
if len(data) < pos+paramNum*2 {
759+
return 0, nil, errors.WithStack(gomysql.ErrMalformPacket)
760+
}
761+
paramTypes := data[pos : pos+paramNum*2]
762+
pos += paramNum * 2
763+
764+
for i := 0; i < paramNum; i++ {
765+
if nullBitmap[i/8]&(1<<(uint(i)%8)) > 0 {
766+
args[i] = nil
767+
continue
768+
}
769+
switch paramTypes[i<<1] {
770+
case fieldTypeTiny:
771+
if paramTypes[(i<<1)+1] == 0x80 {
772+
args[i] = uint8(data[pos])
773+
} else {
774+
args[i] = int8(data[pos])
775+
}
776+
pos += 1
777+
case fieldTypeShort:
778+
v := binary.LittleEndian.Uint16(data[pos : pos+2])
779+
if paramTypes[(i<<1)+1] == 0x80 {
780+
args[i] = v
781+
} else {
782+
args[i] = int16(v)
783+
}
784+
pos += 2
785+
case fieldTypeLong:
786+
v := binary.LittleEndian.Uint32(data[pos : pos+4])
787+
if paramTypes[(i<<1)+1] == 0x80 {
788+
args[i] = v
789+
} else {
790+
args[i] = int32(v)
791+
}
792+
pos += 4
793+
case fieldTypeLongLong:
794+
v := binary.LittleEndian.Uint64(data[pos : pos+8])
795+
if paramTypes[(i<<1)+1] == 0x80 {
796+
args[i] = v
797+
} else {
798+
args[i] = int64(v)
799+
}
800+
pos += 8
801+
case fieldTypeFloat:
802+
args[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
803+
pos += 4
804+
case fieldTypeDouble:
805+
args[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8]))
806+
pos += 8
807+
case fieldTypeString:
808+
v, _, off, err := ParseLengthEncodedBytes(data[pos:])
809+
if err != nil {
810+
return 0, nil, errors.Wrapf(err, "parse param %d err", i)
811+
}
812+
args[i] = hack.String(v)
813+
pos += off
814+
default:
815+
return 0, nil, errors.Errorf("unsupported type %d", paramTypes[i<<1])
816+
}
817+
}
818+
819+
return stmtID, args, nil
820+
}
821+
822+
func MakeCloseStmtRequest(stmtID uint32) []byte {
823+
request := make([]byte, 1+4)
824+
request[0] = ComStmtClose.Byte()
825+
binary.LittleEndian.PutUint32(request[1:], stmtID)
826+
return request
827+
}

pkg/proxy/net/mysql_test.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,32 @@ func TestCheckSqlPort(t *testing.T) {
113113

114114
func TestPrepareStmts(t *testing.T) {
115115
args := []any{
116+
nil,
116117
"hello",
117-
uint64(1),
118-
int64(1),
119-
float64(1),
118+
byte(10),
119+
int16(-100),
120+
int32(-200),
121+
int64(-300),
122+
uint16(100),
123+
uint32(200),
124+
uint64(300),
125+
float32(1.1),
126+
float64(1.2),
127+
nil,
120128
}
121129

122-
b := MakePrepareStmtPacket("select ?")
130+
b := MakePrepareStmtRequest("select ?")
123131
require.Len(t, b, len("select ?")+1)
124132

125-
_, err := MakeExecuteStmtPacket(1, args)
133+
data1, err := MakeExecuteStmtRequest(1, args)
126134
require.NoError(t, err)
135+
136+
stmtID, pArgs, err := ParseExecuteStmtRequest(data1, len(args))
137+
require.NoError(t, err)
138+
require.Equal(t, uint32(1), stmtID)
139+
require.EqualValues(t, args, pArgs)
140+
141+
data2, err := MakeExecuteStmtRequest(1, pArgs)
142+
require.NoError(t, err)
143+
require.Equal(t, data1, data2)
127144
}

pkg/proxy/net/protocol.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,22 @@ func DumpUint16(buffer []byte, n uint16) []byte {
154154
return buffer
155155
}
156156

157+
func Uint16ToBytes(n uint16) []byte {
158+
return []byte{
159+
byte(n),
160+
byte(n >> 8),
161+
}
162+
}
163+
164+
func Uint32ToBytes(n uint32) []byte {
165+
return []byte{
166+
byte(n),
167+
byte(n >> 8),
168+
byte(n >> 16),
169+
byte(n >> 24),
170+
}
171+
}
172+
157173
func Uint64ToBytes(n uint64) []byte {
158174
return []byte{
159175
byte(n),

pkg/sqlreplay/cmd/cmd.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package cmd
55

66
import (
77
"bytes"
8+
"fmt"
89
"io"
910
"strconv"
1011
"strings"
@@ -33,8 +34,10 @@ type LineReader interface {
3334
}
3435

3536
type Command struct {
37+
PreparedStmt string
38+
Params []any
39+
digest string
3640
// Payload starts with command type so that replay can reuse this byte array.
37-
digest string
3841
Payload []byte
3942
StartTs time.Time
4043
ConnID uint64
@@ -191,23 +194,26 @@ func (c *Command) Decode(reader LineReader) error {
191194
}
192195

193196
func (c *Command) Digest() string {
194-
if c.digest == "" {
195-
// TODO: ComStmtExecute
197+
if len(c.digest) == 0 {
196198
switch c.Type {
197199
case pnet.ComQuery, pnet.ComStmtPrepare:
198200
stmt := hack.String(c.Payload[1:])
199201
_, digest := parser.NormalizeDigest(stmt)
200202
c.digest = digest.String()
203+
case pnet.ComStmtExecute:
204+
_, digest := parser.NormalizeDigest(c.PreparedStmt)
205+
c.digest = digest.String()
201206
}
202207
}
203208
return c.digest
204209
}
205210

206211
func (c *Command) QueryText() string {
207-
// TODO: ComStmtExecute
208212
switch c.Type {
209213
case pnet.ComQuery, pnet.ComStmtPrepare:
210214
return hack.String(c.Payload[1:])
215+
case pnet.ComStmtExecute:
216+
return fmt.Sprintf("%s params=%v", c.PreparedStmt, c.Params)
211217
}
212218
return ""
213219
}

pkg/sqlreplay/cmd/cmd_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,16 @@ func TestDigest(t *testing.T) {
134134

135135
cmd3 := NewCommand(append([]byte{pnet.ComFieldList.Byte()}, []byte("xxx")...), time.Now(), 100)
136136
require.Empty(t, cmd3.Digest())
137+
138+
cmd4 := NewCommand(append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...), time.Now(), 100)
139+
require.Equal(t, cmd1.Digest(), cmd4.Digest())
140+
require.Equal(t, "select ?", cmd4.QueryText())
141+
142+
data, err := pnet.MakeExecuteStmtRequest(1, []any{1})
143+
require.NoError(t, err)
144+
cmd5 := NewCommand(data, time.Now(), 100)
145+
cmd5.PreparedStmt = "select ?"
146+
cmd5.Params = []any{1}
147+
require.Equal(t, cmd1.Digest(), cmd5.Digest())
148+
require.Equal(t, "select ? params=[1]", cmd5.QueryText())
137149
}

0 commit comments

Comments
 (0)