Skip to content

Commit 2e73239

Browse files
committed
fix bitmask conv
1 parent 17bebbc commit 2e73239

File tree

7 files changed

+100
-58
lines changed

7 files changed

+100
-58
lines changed

enginetest/queries/binlog_queries.go

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"strings"
2020

2121
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/types"
2223
)
2324

2425
var (
@@ -30,10 +31,33 @@ var (
3031
binlogNoFormatDescStmts = readBinlogTestFile("./testdata/binlog_no_format_desc.txt")
3132
)
3233

33-
// BinlogScripts contains test cases for the BINLOG statement. Test data is generated from real MariaDB binlog events by
34-
// running `bats binlog_maker.bats` in enginetest/testdata.
35-
// To add tests: add a @test to binlog_maker.bats, generate the .dat file, then add a test case here.
34+
// BinlogScripts contains test cases for the BINLOG statement. To add tests: add a @test to binlog_maker.bats, generate
35+
// the .txt file with BINLOG statements, then add a test case here with the corresponding setup.
3636
var BinlogScripts = []ScriptTest{
37+
{
38+
Name: "SET sql_mode with numeric bitmask from binlog",
39+
Assertions: []ScriptTestAssertion{
40+
{Query: "SET @@session.sql_mode=1411383296", Expected: []sql.Row{{types.OkResult{}}}},
41+
{Query: "SELECT @@session.sql_mode", Expected: []sql.Row{{"ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION,STRICT_TRANS_TABLES"}}},
42+
},
43+
},
44+
{
45+
Name: "SET collation variables with numeric IDs from binlog",
46+
Assertions: []ScriptTestAssertion{
47+
{Query: "SET @@session.collation_connection=33", Expected: []sql.Row{{types.OkResult{}}}},
48+
{Query: "SELECT @@session.collation_connection", Expected: []sql.Row{{"utf8mb3_general_ci"}}},
49+
{Query: "SELECT @@session.character_set_connection", Expected: []sql.Row{{"utf8mb3"}}},
50+
{Query: "SET @@session.collation_server=8", Expected: []sql.Row{{types.OkResult{}}}},
51+
{Query: "SELECT @@session.collation_server", Expected: []sql.Row{{"latin1_swedish_ci"}}},
52+
{Query: "SELECT @@session.character_set_server", Expected: []sql.Row{{"latin1"}}},
53+
// collation_database always returns the current database's collation. See sql/core.go:729-735
54+
{Query: "SET @@session.collation_database=33", Expected: []sql.Row{{types.OkResult{}}}},
55+
{Query: "SELECT @@session.collation_database", Expected: []sql.Row{{"utf8mb4_0900_bin"}}},
56+
// TODO: lc_time_names no-op
57+
{Query: "SET @@session.lc_time_names=0", Expected: []sql.Row{{types.OkResult{}}}},
58+
{Query: "SELECT @@session.lc_time_names", Expected: []sql.Row{{"0"}}},
59+
},
60+
},
3761
{
3862
Name: "BINLOG requires FORMAT_DESCRIPTION_EVENT first",
3963
SetUpScript: []string{
@@ -159,5 +183,13 @@ func readBinlogTestFile(path string) []string {
159183
return nil
160184
}
161185

162-
return strings.Split(content, "BINLOG '")
186+
parts := strings.Split(content, "BINLOG '")
187+
var stmts []string
188+
for i, part := range parts {
189+
if i == 0 && part == "" {
190+
continue
191+
}
192+
stmts = append(stmts, "BINLOG '"+part)
193+
}
194+
return stmts
163195
}

enginetest/queries/queries.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5031,6 +5031,8 @@ SELECT * FROM cte WHERE d = 2;`,
50315031
{"gtid_next", "AUTOMATIC"},
50325032
{"gtid_owned", ""},
50335033
{"gtid_purged", ""},
5034+
{"gtid_domain_id", 0},
5035+
{"gtid_seq_no", 0},
50345036
},
50355037
},
50365038
{

enginetest/testdata/binlog_maker.bats

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#!/usr/bin/env bats
2-
# Starts MariaDB with binlog enabled, executes operations, and extracts BINLOG statements to .txt files.
3-
# Run: bats binlog_maker.bats
2+
# Starts MariaDB with binlog enabled, executes operations, and extracts BINLOG statements to .txt files. `mariadb` and
3+
# `mariadb-binlog` should be available to use. This is intended to run in a unix environment, (if you're on Windows, run
4+
# in WSL), the script will find the relative directory to `testdata`. Tests are constructed as follows: query -> flush
5+
# -> extract_binlog_to_file.
46

57
definePORT() {
68
local base_port=$((2048 + ($$ % 4096)))

sql/collations.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package sql
1717
import (
1818
"fmt"
1919
"io"
20+
"strconv"
2021
"strings"
2122
"sync"
2223
"unicode/utf8"
@@ -974,13 +975,21 @@ type TypeWithCollation interface {
974975
}
975976

976977
// ConvertCollationID converts numeric collation IDs to their string names.
977-
func ConvertCollationID(val any) (any, error) {
978-
if _, ok := val.(string); ok {
979-
return val, nil
980-
}
981-
978+
func ConvertCollationID(val any) (string, error) {
982979
var collationID uint64
983980
switch v := val.(type) {
981+
case []byte:
982+
if n, err := strconv.ParseUint(string(v), 10, 64); err == nil {
983+
collationID = n
984+
} else {
985+
return string(v), nil
986+
}
987+
case string:
988+
if n, err := strconv.ParseUint(v, 10, 64); err == nil {
989+
collationID = n
990+
} else {
991+
return v, nil
992+
}
984993
case int8:
985994
collationID = uint64(v)
986995
case int16:
@@ -1002,10 +1011,9 @@ func ConvertCollationID(val any) (any, error) {
10021011
case uint64:
10031012
collationID = v
10041013
default:
1005-
return val, nil
1014+
return fmt.Sprintf("%v", val), nil
10061015
}
10071016

1008-
// Convert numeric ID to collation name
10091017
collation := CollationID(collationID).Collation()
10101018
return collation.Name, nil
10111019
}

sql/planbuilder/set.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,27 @@ func (b *Builder) setExprsToExpressions(inScope *scope, e ast.SetVarExprs) []sql
151151
}
152152
}
153153

154+
if sysVar, ok := setVar.(*expression.SystemVar); ok {
155+
if sqlVal, ok := setExpr.Expr.(*ast.SQLVal); ok && sqlVal.Type == ast.IntVal {
156+
switch strings.ToLower(sysVar.Name) {
157+
case "sql_mode":
158+
converted, err := sql.ConvertSqlModeBitmask(sqlVal.Val)
159+
if err != nil {
160+
b.handleErr(err)
161+
}
162+
setExpr.Expr = ast.NewStrVal([]byte(converted))
163+
case "collation_database", "collation_connection", "collation_server":
164+
converted, err := sql.ConvertCollationID(sqlVal.Val)
165+
if err != nil {
166+
b.handleErr(err)
167+
}
168+
setExpr.Expr = ast.NewStrVal([]byte(converted))
169+
case "lc_time_names":
170+
setExpr.Expr = ast.NewStrVal(sqlVal.Val)
171+
}
172+
}
173+
}
174+
154175
sysVarType, _ := setVar.Type().(sql.SystemVariableType)
155176
innerExpr, ok := b.simplifySetExpr(setExpr.Name, setScope, setExpr.Expr, sysVarType)
156177
if !ok {

sql/rowexec/rel_iters.go

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package rowexec
1616

1717
import (
1818
"errors"
19-
"fmt"
2019
"io"
2120
"strings"
2221

@@ -393,21 +392,18 @@ func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expr
393392
if err != nil {
394393
return err
395394
}
396-
397395
err = validateSystemVariableValue(sysVar.Name, val)
398396
if err != nil {
399397
return err
400398
}
399+
err = sysVar.Scope.SetValue(ctx, sysVar.Name, val)
400+
if err != nil {
401+
return err
402+
}
401403

402404
// Setting `character_set_connection` and `collation_connection` will set the corresponding variable
403405
// Setting `character_set_server` and `collation_server` will set the corresponding variable
404406
switch strings.ToLower(sysVar.Name) {
405-
case "sql_mode":
406-
val, err = sql.ConvertSqlModeBitmask(val)
407-
if err != nil {
408-
return err
409-
}
410-
return sysVar.Scope.SetValue(ctx, sysVar.Name, val)
411407
case "character_set_connection":
412408
if val == nil {
413409
return sysVar.Scope.SetValue(ctx, "collation_connection", val)
@@ -424,16 +420,8 @@ func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expr
424420
collationName := charset.DefaultCollation().Name()
425421
return sysVar.Scope.SetValue(ctx, "collation_connection", collationName)
426422
case "collation_connection":
427-
val, err = sql.ConvertCollationID(val)
428-
if err != nil {
429-
return err
430-
}
431-
err = sysVar.Scope.SetValue(ctx, sysVar.Name, val)
432-
if err != nil {
433-
return err
434-
}
435423
if val == nil {
436-
return sysVar.Scope.SetValue(ctx, "character_set_connection", nil)
424+
return sysVar.Scope.SetValue(ctx, "character_set_connection", val)
437425
}
438426
valStr, ok := val.(string)
439427
if !ok {
@@ -462,16 +450,8 @@ func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expr
462450
collationName := charset.DefaultCollation().Name()
463451
return sysVar.Scope.SetValue(ctx, "collation_server", collationName)
464452
case "collation_server":
465-
val, err = sql.ConvertCollationID(val)
466-
if err != nil {
467-
return err
468-
}
469-
err = sysVar.Scope.SetValue(ctx, sysVar.Name, val)
470-
if err != nil {
471-
return err
472-
}
473453
if val == nil {
474-
return sysVar.Scope.SetValue(ctx, "character_set_server", nil)
454+
return sysVar.Scope.SetValue(ctx, "character_set_server", val)
475455
}
476456
valStr, ok := val.(string)
477457
if !ok {
@@ -484,19 +464,6 @@ func setSystemVar(ctx *sql.Context, sysVar *expression.SystemVar, right sql.Expr
484464
}
485465
charsetName := collation.CharacterSet().Name()
486466
return sysVar.Scope.SetValue(ctx, "character_set_server", charsetName)
487-
case "collation_database":
488-
val, err = sql.ConvertCollationID(val)
489-
if err != nil {
490-
return err
491-
}
492-
return sysVar.Scope.SetValue(ctx, sysVar.Name, val)
493-
case "lc_time_names":
494-
// TODO: convert numeric locale ID to locale name
495-
switch val.(type) {
496-
case int8, int16, int, int32, int64, uint8, uint16, uint, uint32, uint64:
497-
val = fmt.Sprintf("%v", val)
498-
}
499-
return sysVar.Scope.SetValue(ctx, sysVar.Name, val)
500467
}
501468
return nil
502469
}

sql/sql_mode.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
package sql
1616

1717
import (
18+
"fmt"
1819
"sort"
20+
"strconv"
1921
"strings"
2022

2123
"github.com/dolthub/vitess/go/vt/sqlparser"
@@ -225,13 +227,21 @@ func (s *SqlMode) String() string {
225227
}
226228

227229
// ConvertSqlModeBitmask converts sql_mode values to their string representation.
228-
func ConvertSqlModeBitmask(val any) (any, error) {
229-
if _, ok := val.(string); ok {
230-
return val, nil
231-
}
232-
230+
func ConvertSqlModeBitmask(val any) (string, error) {
233231
var bitmask uint64
234232
switch v := val.(type) {
233+
case []byte:
234+
if n, err := strconv.ParseUint(string(v), 10, 64); err == nil {
235+
bitmask = n
236+
} else {
237+
return string(v), nil
238+
}
239+
case string:
240+
if n, err := strconv.ParseUint(v, 10, 64); err == nil {
241+
bitmask = n
242+
} else {
243+
return v, nil
244+
}
235245
case int8:
236246
bitmask = uint64(v)
237247
case int16:
@@ -253,7 +263,7 @@ func ConvertSqlModeBitmask(val any) (any, error) {
253263
case uint64:
254264
bitmask = v
255265
default:
256-
return val, nil
266+
return fmt.Sprintf("%v", val), nil
257267
}
258268

259269
var modes []string

0 commit comments

Comments
 (0)