Skip to content

Commit 9bd6154

Browse files
authored
sqlreplay, lex: fix replaying with readonly causes inconsistent prepared statement IDs (#757)
1 parent dbd13a2 commit 9bd6154

File tree

7 files changed

+156
-140
lines changed

7 files changed

+156
-140
lines changed

lib/util/waitgroup/waitgroup.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func recoverFromErr(wg *sync.WaitGroup, recoverFn func(r interface{}), logger *z
4747
}()
4848
if r != nil && logger != nil {
4949
logger.Error("panic in the recoverable goroutine",
50-
zap.Reflect("r", r),
50+
zap.Any("err", r),
5151
zap.Stack("stack trace"))
5252
}
5353
// Call Done() before recoverFn because recoverFn normally calls `Close()`, which may call `wg.Wait()`.

pkg/sqlreplay/cmd/cmd.go

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"github.com/pingcap/tidb/pkg/parser"
1515
"github.com/pingcap/tiproxy/lib/util/errors"
1616
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
17-
"github.com/pingcap/tiproxy/pkg/util/lex"
1817
"github.com/siddontang/go/hack"
1918
)
2019

@@ -222,18 +221,6 @@ func (c *Command) QueryText() string {
222221
return ""
223222
}
224223

225-
func (c *Command) ReadOnly() bool {
226-
switch c.Type {
227-
case pnet.ComQuery, pnet.ComStmtPrepare:
228-
return lex.IsReadOnly(c.QueryText())
229-
case pnet.ComStmtExecute, pnet.ComStmtClose, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch:
230-
return lex.IsReadOnly(c.PreparedStmt)
231-
case pnet.ComCreateDB, pnet.ComDropDB, pnet.ComDelayedInsert:
232-
return false
233-
}
234-
return true
235-
}
236-
237224
func writeString(key, value string, writer *bytes.Buffer) error {
238225
var err error
239226
if _, err = writer.WriteString(key); err != nil {

pkg/sqlreplay/cmd/cmd_test.go

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -152,58 +152,3 @@ func TestDigest(t *testing.T) {
152152
require.Equal(t, cmd1.Digest(), cmd6.Digest())
153153
require.Equal(t, "select ?", cmd6.QueryText())
154154
}
155-
156-
func TestReadOnly(t *testing.T) {
157-
tests := []struct {
158-
cmd pnet.Command
159-
payload []byte
160-
prepareStmt string
161-
readOnly bool
162-
}{
163-
{
164-
cmd: pnet.ComQuery,
165-
payload: []byte("select 1"),
166-
readOnly: true,
167-
},
168-
{
169-
cmd: pnet.ComQuery,
170-
payload: []byte("insert into t value(1)"),
171-
readOnly: false,
172-
},
173-
{
174-
cmd: pnet.ComStmtPrepare,
175-
payload: []byte("select ?"),
176-
readOnly: true,
177-
},
178-
{
179-
cmd: pnet.ComStmtExecute,
180-
prepareStmt: "select ?",
181-
readOnly: true,
182-
},
183-
{
184-
cmd: pnet.ComStmtExecute,
185-
prepareStmt: "insert into t value(?)",
186-
readOnly: false,
187-
},
188-
{
189-
cmd: pnet.ComStmtExecute,
190-
readOnly: false,
191-
},
192-
{
193-
cmd: pnet.ComStmtClose,
194-
prepareStmt: "select ?",
195-
readOnly: true,
196-
},
197-
{
198-
cmd: pnet.ComQuit,
199-
readOnly: true,
200-
},
201-
}
202-
203-
for i, test := range tests {
204-
packet := append([]byte{byte(test.cmd)}, test.payload...)
205-
cmd := NewCommand(packet, time.Time{}, 100)
206-
cmd.PreparedStmt = test.prepareStmt
207-
require.Equal(t, test.readOnly, cmd.ReadOnly(), "case %d", i)
208-
}
209-
}

pkg/sqlreplay/conn/conn.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515
"github.com/pingcap/tiproxy/pkg/proxy/backend"
1616
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
1717
"github.com/pingcap/tiproxy/pkg/sqlreplay/cmd"
18+
"github.com/pingcap/tiproxy/pkg/util/lex"
19+
"github.com/siddontang/go/hack"
1820
"go.uber.org/zap"
1921
)
2022

@@ -107,8 +109,7 @@ func (c *conn) Run(ctx context.Context) {
107109
break
108110
}
109111
if c.readonly {
110-
c.updateCmdForExecuteStmt(command.Value)
111-
if !command.Value.ReadOnly() {
112+
if !c.isReadOnly(command.Value) {
112113
c.replayStats.FilteredCmds.Add(1)
113114
continue
114115
}
@@ -131,8 +132,30 @@ func (c *conn) Run(ctx context.Context) {
131132
}
132133
}
133134

134-
// updateCmdForExecuteStmt may be called multiple times, avoid duplicated works.
135+
func (c *conn) isReadOnly(command *cmd.Command) bool {
136+
switch command.Type {
137+
case pnet.ComQuery:
138+
return lex.IsReadOnly(hack.String(command.Payload[1:]))
139+
case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch:
140+
stmtID := binary.LittleEndian.Uint32(command.Payload[1:5])
141+
text, _, _ := c.backendConn.GetPreparedStmt(stmtID)
142+
if len(text) == 0 {
143+
c.lg.Error("prepared stmt not found", zap.Uint32("stmt_id", stmtID), zap.Stringer("cmd_type", command.Type))
144+
return false
145+
}
146+
return lex.IsReadOnly(text)
147+
case pnet.ComCreateDB, pnet.ComDropDB, pnet.ComDelayedInsert:
148+
return false
149+
}
150+
// Treat ComStmtPrepare and ComStmtClose as read-only to make prepared stmt IDs in capture and replay phases the same.
151+
// The problem is that it still requires write privilege. Better solutions are much more complex:
152+
// - Replace all prepared DML statements with `SELECT 1`, including ComStmtPrepare and `SET SESSION_STATES`.
153+
// - Remove all prepared DML statements and map catpure prepared stmt ID to replay prepared stmt ID, including ComStmtPrepare and `SET SESSION_STATES`.
154+
return true
155+
}
156+
135157
func (c *conn) updateCmdForExecuteStmt(command *cmd.Command) bool {
158+
// updated before
136159
if command.PreparedStmt != "" {
137160
return true
138161
}
@@ -147,7 +170,9 @@ func (c *conn) updateCmdForExecuteStmt(command *cmd.Command) bool {
147170
if command.Type == pnet.ComStmtExecute {
148171
_, args, _, err := pnet.ParseExecuteStmtRequest(command.Payload, paramNum, paramTypes)
149172
if err != nil {
150-
c.lg.Error("parsing ComExecuteStmt request failed", zap.Uint32("stmt_id", stmtID), zap.Error(err))
173+
// Failing to parse the request is not critical, so don't return false.
174+
c.lg.Error("parsing ComExecuteStmt request failed", zap.Uint32("stmt_id", stmtID), zap.String("sql", text),
175+
zap.Int("param_num", paramNum), zap.ByteString("param_types", paramTypes), zap.Error(err))
151176
}
152177
command.Params = args
153178
}

pkg/sqlreplay/conn/conn_test.go

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ func TestSkipReadOnly(t *testing.T) {
178178
},
179179
{
180180
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)},
181-
readonly: false,
181+
readonly: true,
182182
},
183183
{
184184
cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}},
@@ -221,3 +221,71 @@ func TestSkipReadOnly(t *testing.T) {
221221
cancel()
222222
wg.Wait()
223223
}
224+
225+
func TestReadOnly(t *testing.T) {
226+
tests := []struct {
227+
cmd pnet.Command
228+
stmt string
229+
readOnly bool
230+
}{
231+
{
232+
cmd: pnet.ComQuery,
233+
stmt: "select 1",
234+
readOnly: true,
235+
},
236+
{
237+
cmd: pnet.ComQuery,
238+
stmt: "insert into t value(1)",
239+
readOnly: false,
240+
},
241+
{
242+
cmd: pnet.ComStmtPrepare,
243+
stmt: "select ?",
244+
readOnly: true,
245+
},
246+
{
247+
cmd: pnet.ComStmtPrepare,
248+
stmt: "insert into t value(?)",
249+
readOnly: true,
250+
},
251+
{
252+
cmd: pnet.ComStmtExecute,
253+
stmt: "select ?",
254+
readOnly: true,
255+
},
256+
{
257+
cmd: pnet.ComStmtExecute,
258+
stmt: "insert into t value(?)",
259+
readOnly: false,
260+
},
261+
{
262+
cmd: pnet.ComStmtClose,
263+
stmt: "insert into t value(?)",
264+
readOnly: true,
265+
},
266+
{
267+
cmd: pnet.ComQuit,
268+
readOnly: true,
269+
},
270+
{
271+
cmd: pnet.ComCreateDB,
272+
readOnly: false,
273+
},
274+
}
275+
276+
conn := &conn{}
277+
backendConn := newMockBackendConn()
278+
conn.backendConn = backendConn
279+
for i, test := range tests {
280+
var payload []byte
281+
switch test.cmd {
282+
case pnet.ComQuery:
283+
payload = append([]byte{test.cmd.Byte()}, []byte(test.stmt)...)
284+
default:
285+
backendConn.prepared[1] = &preparedStmt{text: test.stmt}
286+
payload = []byte{test.cmd.Byte(), 1, 0, 0, 0}
287+
}
288+
command := cmd.NewCommand(payload, time.Time{}, 100)
289+
require.Equal(t, test.readOnly, conn.isReadOnly(command), "case %d", i)
290+
}
291+
}

pkg/util/lex/filter.go

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,75 +3,66 @@
33

44
package lex
55

6+
func startsWithKeyword(sql string, keywords [][]string) bool {
7+
lexer := NewLexer(sql)
8+
tokens := make([]string, 0, 2)
9+
for _, kw := range keywords {
10+
match := true
11+
for i, t := range kw {
12+
if len(tokens) <= i {
13+
tokens = append(tokens, lexer.NextToken())
14+
}
15+
if tokens[i] != t {
16+
match = false
17+
break
18+
}
19+
}
20+
if match {
21+
return true
22+
}
23+
}
24+
return false
25+
}
26+
627
var sensitiveKeywords = [][]string{
728
// contain passwords
8-
{
9-
"CREATE", "USER",
10-
},
11-
{
12-
"ALTER", "USER",
13-
},
14-
{
15-
"SET", "PASSWORD",
16-
},
17-
{
18-
"GRANT",
19-
},
29+
{"CREATE", "USER"},
30+
{"ALTER", "USER"},
31+
{"SET", "PASSWORD"},
32+
{"GRANT"},
2033
// contain cloud storage url
21-
{
22-
"BACKUP",
23-
},
24-
{
25-
"RESTORE",
26-
},
27-
{
28-
"IMPORT",
29-
},
34+
{"BACKUP"},
35+
{"RESTORE"},
36+
{"IMPORT"},
3037
// not supported yet
31-
{
32-
"LOAD", "DATA",
33-
},
38+
{"LOAD", "DATA"},
3439
}
3540

36-
// ignore prepared statements in text-protocol
37-
// ignore EXPLAIN (including EXPLAIN ANALYZE) and TRACE
38-
// include SELECT FOR UPDATE because it releases locks immediately in auto-commit transactions
39-
// include SET because SET SESSION_STATES and SET session variables should be executed
40-
var readOnlyKeywords = []string{
41-
"SELECT", "SHOW", "WITH", "SET", "USE", "DESC", "DESCRIBE", "TABLE", "DO",
41+
func IsSensitiveSQL(sql string) bool {
42+
return startsWithKeyword(sql, sensitiveKeywords)
4243
}
4344

44-
func IsSensitiveSQL(sql string) bool {
45-
lexer := NewLexer(sql)
46-
keyword := lexer.NextToken()
47-
if len(keyword) == 0 {
48-
return false
49-
}
50-
for _, kw := range sensitiveKeywords {
51-
if keyword != kw[0] {
52-
continue
53-
}
54-
if len(kw) <= 1 {
55-
return true
56-
}
57-
keyword = lexer.NextToken()
58-
if keyword == kw[1] {
59-
return true
60-
}
61-
}
62-
return false
45+
// ignore prepared statements in text-protocol
46+
// ignore EXPLAIN, EXPLAIN ANALYZE, and TRACE
47+
// include SELECT FOR UPDATE because it doesn't require write privilege
48+
// include SET because SET SESSION_STATES and SET session variables should be executed
49+
// include BEGIN / COMMIT in case the user sets autocommit to false, either in SET SESSION_STATES or SET @@autocommit
50+
var readOnlyKeywords = [][]string{
51+
{"SELECT"},
52+
{"SHOW"},
53+
{"WITH"},
54+
{"SET"},
55+
{"USE"},
56+
{"DESC"},
57+
{"DESCRIBE"},
58+
{"TABLE"},
59+
{"DO"},
60+
{"BEGIN"},
61+
{"COMMIT"},
62+
{"ROLLBACK"},
63+
{"START", "TRANSACTION"},
6364
}
6465

6566
func IsReadOnly(sql string) bool {
66-
lexer := NewLexer(sql)
67-
keyword := lexer.NextToken()
68-
if len(keyword) == 0 {
69-
return false
70-
}
71-
for _, kw := range readOnlyKeywords {
72-
if keyword == kw {
73-
return true
74-
}
75-
}
76-
return false
67+
return startsWithKeyword(sql, readOnlyKeywords)
7768
}

pkg/util/lex/filter_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ func TestSenstiveSQL(t *testing.T) {
2424
{`set`, false},
2525
}
2626

27-
for i, test := range tests {
28-
require.Equal(t, test.sensitive, IsSensitiveSQL(test.sql), "case %d", i)
27+
for _, test := range tests {
28+
require.Equal(t, test.sensitive, IsSensitiveSQL(test.sql), test.sql)
2929
}
3030
}
3131

@@ -52,11 +52,11 @@ func TestReadOnlySQL(t *testing.T) {
5252
{`do 1`, true},
5353
{`/*hello */select 1`, true},
5454
{` select 1`, true},
55-
{`/**/ start transaction`, false},
56-
{` COMMIT`, false},
55+
{`/**/ start transaction`, true},
56+
{` COMMIT`, true},
5757
}
5858

59-
for i, test := range tests {
60-
require.Equal(t, test.readOnly, IsReadOnly(test.sql), "case %d", i)
59+
for _, test := range tests {
60+
require.Equal(t, test.readOnly, IsReadOnly(test.sql), test.sql)
6161
}
6262
}

0 commit comments

Comments
 (0)