Skip to content

Commit 443e9b5

Browse files
authored
backend, lex: use a simpler lexer to parse begin statements (#758)
1 parent 9bd6154 commit 443e9b5

File tree

4 files changed

+66
-63
lines changed

4 files changed

+66
-63
lines changed

pkg/proxy/backend/cmd_processor_exec.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@ package backend
55

66
import (
77
"encoding/binary"
8-
"strings"
98

109
"github.com/go-mysql-org/go-mysql/mysql"
11-
"github.com/pingcap/tidb/pkg/parser"
1210
"github.com/pingcap/tiproxy/lib/util/errors"
1311
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
12+
"github.com/pingcap/tiproxy/pkg/util/lex"
1413
"github.com/siddontang/go/hack"
1514
"go.uber.org/zap"
1615
)
@@ -361,10 +360,5 @@ func (cp *CmdProcessor) needHoldRequest(request []byte) bool {
361360
data = data[:len(data)-1]
362361
}
363362
query := hack.String(data)
364-
return isBeginStmt(query)
365-
}
366-
367-
func isBeginStmt(query string) bool {
368-
normalized := parser.Normalize(query, "ON")
369-
return strings.HasPrefix(normalized, "begin") || strings.HasPrefix(normalized, "start transaction")
363+
return lex.IsStartTxn(query)
370364
}

pkg/proxy/backend/cmd_processor_test.go

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -918,61 +918,6 @@ func TestHoldRequest(t *testing.T) {
918918
}
919919
}
920920

921-
func TestBeginStmt(t *testing.T) {
922-
tests := []struct {
923-
stmt string
924-
isBegin bool
925-
}{
926-
{
927-
stmt: "begin",
928-
isBegin: true,
929-
},
930-
{
931-
stmt: "BEGIN",
932-
isBegin: true,
933-
},
934-
{
935-
stmt: "begin optimistic as of timestamp now()",
936-
isBegin: true,
937-
},
938-
{
939-
stmt: " begin",
940-
isBegin: true,
941-
},
942-
{
943-
stmt: "start transaction",
944-
isBegin: true,
945-
},
946-
{
947-
stmt: "START transaction",
948-
isBegin: true,
949-
},
950-
{
951-
stmt: "start transaction with consistent snapshot",
952-
isBegin: true,
953-
},
954-
{
955-
stmt: "begin; select 1",
956-
isBegin: true,
957-
},
958-
{
959-
stmt: "/*+ some_hint */begin",
960-
isBegin: true,
961-
},
962-
{
963-
stmt: "commit",
964-
isBegin: false,
965-
},
966-
{
967-
stmt: "select 1; begin",
968-
isBegin: false,
969-
},
970-
}
971-
for _, test := range tests {
972-
require.Equal(t, test.isBegin, isBeginStmt(test.stmt), test.stmt)
973-
}
974-
}
975-
976921
// Test forwarding multi-statements works well.
977922
func TestMultiStmt(t *testing.T) {
978923
tc := newTCPConnSuite(t)

pkg/util/lex/filter.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,12 @@ var readOnlyKeywords = [][]string{
6666
func IsReadOnly(sql string) bool {
6767
return startsWithKeyword(sql, readOnlyKeywords)
6868
}
69+
70+
var startTxnKeywords = [][]string{
71+
{"START", "TRANSACTION"},
72+
{"BEGIN"},
73+
}
74+
75+
func IsStartTxn(sql string) bool {
76+
return startsWithKeyword(sql, startTxnKeywords)
77+
}

pkg/util/lex/filter_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,58 @@ func TestReadOnlySQL(t *testing.T) {
6060
require.Equal(t, test.readOnly, IsReadOnly(test.sql), test.sql)
6161
}
6262
}
63+
64+
func TestStartTxn(t *testing.T) {
65+
tests := []struct {
66+
stmt string
67+
isBegin bool
68+
}{
69+
{
70+
stmt: "begin",
71+
isBegin: true,
72+
},
73+
{
74+
stmt: "BEGIN",
75+
isBegin: true,
76+
},
77+
{
78+
stmt: "begin optimistic as of timestamp now()",
79+
isBegin: true,
80+
},
81+
{
82+
stmt: " begin",
83+
isBegin: true,
84+
},
85+
{
86+
stmt: "start transaction",
87+
isBegin: true,
88+
},
89+
{
90+
stmt: "START transaction",
91+
isBegin: true,
92+
},
93+
{
94+
stmt: "start transaction with consistent snapshot",
95+
isBegin: true,
96+
},
97+
{
98+
stmt: "begin; select 1",
99+
isBegin: true,
100+
},
101+
{
102+
stmt: "/*+ some_hint */begin",
103+
isBegin: true,
104+
},
105+
{
106+
stmt: "commit",
107+
isBegin: false,
108+
},
109+
{
110+
stmt: "select 1; begin",
111+
isBegin: false,
112+
},
113+
}
114+
for _, test := range tests {
115+
require.Equal(t, test.isBegin, IsStartTxn(test.stmt), test.stmt)
116+
}
117+
}

0 commit comments

Comments
 (0)