Skip to content

Commit 4e2dca5

Browse files
authored
capture: do not capture sensitive commands (#682)
1 parent bc576f7 commit 4e2dca5

File tree

8 files changed

+289
-10
lines changed

8 files changed

+289
-10
lines changed

pkg/sqlreplay/capture/capture.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,15 @@ func (c *capture) putCommand(command *cmd.Command) bool {
313313
return false
314314
}
315315
case pnet.ComChangeUser:
316-
// COM_CHANGE_USER sends auth data, so ignore it.
317-
return false
316+
// COM_CHANGE_USER sends auth data, change it to COM_RESET_CONNECTION.
317+
command.Type = pnet.ComResetConnection
318+
command.Payload = []byte{pnet.ComResetConnection.Byte()}
319+
case pnet.ComQuery:
320+
// Avoid password leakage.
321+
if IsSensitiveSQL(hack.String(command.Payload[1:])) {
322+
c.filteredCmds++
323+
return false
324+
}
318325
}
319326
select {
320327
case c.cmdCh <- command:

pkg/sqlreplay/capture/capture_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,58 @@ func TestQuit(t *testing.T) {
265265
require.Equal(t, 1, strings.Count(data, "# Cmd_type: Quit"))
266266
require.Equal(t, uint64(3), cpt.capturedCmds)
267267
}
268+
269+
func TestFilterCmds(t *testing.T) {
270+
tests := []struct {
271+
packet []byte
272+
want string
273+
notWant string
274+
}{
275+
{
276+
packet: pnet.MakeChangeUser(&pnet.ChangeUserReq{
277+
User: "root",
278+
DB: "test",
279+
}, 0),
280+
want: pnet.ComResetConnection.String(),
281+
notWant: pnet.ComChangeUser.String(),
282+
},
283+
{
284+
packet: append([]byte{pnet.ComQuery.Byte()}, []byte("CREATE USER u1 IDENTIFIED BY '123456'")...),
285+
notWant: "123456",
286+
},
287+
{
288+
packet: append([]byte{pnet.ComQuery.Byte()}, []byte("select 1")...),
289+
want: "select 1",
290+
},
291+
}
292+
293+
cfg := CaptureConfig{
294+
Output: t.TempDir(),
295+
Duration: 10 * time.Second,
296+
}
297+
for i, test := range tests {
298+
cpt := NewCapture(zap.NewNop())
299+
writer := newMockWriter(store.WriterCfg{})
300+
cfg.cmdLogger = writer
301+
require.NoError(t, cpt.Start(cfg))
302+
cpt.Capture(test.packet, time.Now(), 100, func() (string, error) {
303+
return "init session 100", nil
304+
})
305+
cpt.Stop(nil)
306+
307+
data := string(writer.getData())
308+
if len(test.want) > 0 {
309+
require.Equal(t, 1, strings.Count(data, test.want), "case %d", i)
310+
require.Equal(t, uint64(2), cpt.capturedCmds, "case %d", i)
311+
require.Equal(t, uint64(0), cpt.filteredCmds, "case %d", i)
312+
} else {
313+
require.Equal(t, uint64(1), cpt.capturedCmds, "case %d", i)
314+
require.Equal(t, uint64(1), cpt.filteredCmds, "case %d", i)
315+
}
316+
if len(test.notWant) > 0 {
317+
require.Equal(t, 0, strings.Count(data, test.notWant), "case %d", i)
318+
}
319+
320+
cpt.Close()
321+
}
322+
}

pkg/sqlreplay/capture/filter.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package capture
5+
6+
import (
7+
"github.com/pingcap/tiproxy/pkg/util/lex"
8+
)
9+
10+
var sensitiveKeywords = [][]string{
11+
// contain passwords
12+
{
13+
"CREATE", "USER",
14+
},
15+
{
16+
"ALTER", "USER",
17+
},
18+
{
19+
"SET", "PASSWORD",
20+
},
21+
{
22+
"GRANT",
23+
},
24+
// contain cloud storage url
25+
{
26+
"BACKUP",
27+
},
28+
{
29+
"RESTORE",
30+
},
31+
{
32+
"IMPORT",
33+
},
34+
// not supported yet
35+
{
36+
"LOAD", "DATA",
37+
},
38+
}
39+
40+
func IsSensitiveSQL(sql string) bool {
41+
lexer := lex.NewLexer(sql)
42+
keyword := lexer.NextToken()
43+
if len(keyword) == 0 {
44+
return false
45+
}
46+
for _, kw := range sensitiveKeywords {
47+
if keyword != kw[0] {
48+
continue
49+
}
50+
if len(kw) <= 1 {
51+
return true
52+
}
53+
keyword = lexer.NextToken()
54+
if keyword == kw[1] {
55+
return true
56+
}
57+
}
58+
return false
59+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package capture
5+
6+
import (
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestSenstiveSQL(t *testing.T) {
13+
tests := []struct {
14+
sql string
15+
sensitive bool
16+
}{
17+
{`SELECT * FROM table_name`, false},
18+
{`grant ALL PRIVILEGES ON database_name.* TO 'username'@'localhost' IDENTIFIED BY 'password'`, true},
19+
{`CREATE USER 'new_user'@'localhost' IDENTIFIED BY 'secure_password';`, true},
20+
{` ALTER USER 'existing_user'@'localhost' IDENTIFIED BY 'new_password'`, true},
21+
{`/*hello */set PASSWORD FOR 'username'@'localhost' = PASSWORD('new_password');`, true},
22+
{`set global anything = 'hello' `, false},
23+
{``, false},
24+
{`set`, false},
25+
}
26+
27+
for i, test := range tests {
28+
require.Equal(t, test.sensitive, IsSensitiveSQL(test.sql), "case %d", i)
29+
}
30+
}

pkg/sqlreplay/replay/replay.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/pingcap/tiproxy/lib/util/errors"
1616
"github.com/pingcap/tiproxy/lib/util/waitgroup"
1717
"github.com/pingcap/tiproxy/pkg/proxy/backend"
18-
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
1918
"github.com/pingcap/tiproxy/pkg/sqlreplay/cmd"
2019
"github.com/pingcap/tiproxy/pkg/sqlreplay/conn"
2120
"github.com/pingcap/tiproxy/pkg/sqlreplay/report"
@@ -169,12 +168,6 @@ func (r *replay) readCommands(ctx context.Context) {
169168
r.Stop(err)
170169
break
171170
}
172-
// Replayer always uses the same username. It has no passwords for other users.
173-
// TODO: clear the session states.
174-
if command.Type == pnet.ComChangeUser {
175-
r.filteredCmds++
176-
continue
177-
}
178171
if captureStartTs.IsZero() {
179172
// first command
180173
captureStartTs = command.StartTs

pkg/sqlreplay/replay/replay_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func TestProgress(t *testing.T) {
168168
}
169169
defer loader.Close()
170170

171-
cmdCh := make(chan *cmd.Command, 10)
171+
cmdCh := make(chan *cmd.Command)
172172
replay := NewReplay(zap.NewNop())
173173
defer replay.Close()
174174
cfg := ReplayConfig{

pkg/util/lex/lex.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package lex
5+
6+
type Lexer struct {
7+
sql string
8+
curToken []byte
9+
curIdx int
10+
}
11+
12+
func NewLexer(sql string) *Lexer {
13+
return &Lexer{
14+
sql: sql,
15+
curToken: make([]byte, 0, 100),
16+
}
17+
}
18+
19+
// It only returns uppercased identifiers and keywords. It's used to search for some specified keywords.
20+
// It doesn't need to strict but it needs to be fast enough.
21+
func (l *Lexer) NextToken() string {
22+
l.curToken = l.curToken[:0]
23+
inSingleLineComment, inMultiLineComment, inSingleQuote, inDoubleQuote := false, false, false, false
24+
for ; l.curIdx < len(l.sql); l.curIdx++ {
25+
char := l.sql[l.curIdx]
26+
switch {
27+
case inSingleLineComment:
28+
if char == '\n' {
29+
inSingleLineComment = false
30+
}
31+
case inMultiLineComment:
32+
if char == '*' {
33+
if l.curIdx+1 < len(l.sql) && l.sql[l.curIdx+1] == '/' {
34+
inMultiLineComment = false
35+
l.curIdx++
36+
}
37+
}
38+
case inSingleQuote:
39+
if char == '\\' {
40+
l.curIdx++
41+
} else if char == '\'' {
42+
inSingleQuote = false
43+
}
44+
case inDoubleQuote:
45+
if char == '\\' {
46+
l.curIdx++
47+
} else if char == '"' {
48+
inDoubleQuote = false
49+
}
50+
case char == '-' && l.curIdx+1 < len(l.sql) && l.sql[l.curIdx+1] == '-':
51+
l.curIdx++
52+
inSingleLineComment = true
53+
case char == '/' && l.curIdx+1 < len(l.sql) && l.sql[l.curIdx+1] == '*':
54+
l.curIdx++
55+
inMultiLineComment = true
56+
case char == '\'':
57+
inSingleQuote = true
58+
case char == '"':
59+
inDoubleQuote = true
60+
case char >= 'a' && char <= 'z':
61+
l.curToken = append(l.curToken, char-'a'+'A')
62+
case char >= 'A' && char <= 'Z' || char == '_':
63+
l.curToken = append(l.curToken, char)
64+
default:
65+
if len(l.curToken) > 0 {
66+
l.curIdx++
67+
return string(l.curToken)
68+
}
69+
}
70+
}
71+
72+
if len(l.curToken) > 0 {
73+
return string(l.curToken)
74+
}
75+
return ""
76+
}

pkg/util/lex/lex_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2024 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package lex
5+
6+
import (
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestNextToken(t *testing.T) {
13+
tests := []struct {
14+
sql string
15+
tokens []string
16+
}{
17+
{
18+
sql: `SELECT * FROM table_name`,
19+
tokens: []string{"SELECT", "FROM", "TABLE_NAME"},
20+
},
21+
{
22+
sql: `-- comment
23+
/* comment * / "
24+
*/ SELECT
25+
-- comment
26+
(*)
27+
FROM table_name`,
28+
tokens: []string{"SELECT", "FROM", "TABLE_NAME"},
29+
},
30+
{
31+
sql: ` SELECT
32+
"string /* */
33+
" * 'string'
34+
FROM '"' "'" '\\' '\'' (table_name)`,
35+
tokens: []string{"SELECT", "FROM", "TABLE_NAME"},
36+
},
37+
{
38+
sql: ` select 123.4e-5 / (1 - 0.9) + @@hello_world 中文`,
39+
tokens: []string{"SELECT", "E", "HELLO_WORLD"},
40+
},
41+
{
42+
sql: `sEleCt ** from; t5ble_name`,
43+
tokens: []string{"SELECT", "FROM", "T", "BLE_NAME"},
44+
},
45+
}
46+
47+
for i, test := range tests {
48+
l := NewLexer(test.sql)
49+
tokens := make([]string, 0, len(test.tokens))
50+
for {
51+
token := l.NextToken()
52+
if len(token) == 0 {
53+
break
54+
}
55+
tokens = append(tokens, token)
56+
}
57+
require.Equal(t, test.tokens, tokens, "case %d", i)
58+
}
59+
}

0 commit comments

Comments
 (0)