Skip to content

Commit eed2800

Browse files
authored
make bit type columns with default values round trippable (#2972)
1 parent 7ef3cc8 commit eed2800

File tree

3 files changed

+59
-14
lines changed

3 files changed

+59
-14
lines changed

enginetest/queries/create_table_queries.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,12 @@ var CreateTableQueries = []WriteQueryTest{
270270
SelectQuery: `SHOW CREATE TABLE t1`,
271271
ExpectedSelect: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `pk` varbinary(10) NOT NULL,\n PRIMARY KEY (`pk`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
272272
},
273+
{
274+
WriteQuery: `create table t1 (pk bit(2) default 2)`,
275+
ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
276+
SelectQuery: `SHOW CREATE TABLE t1`,
277+
ExpectedSelect: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `pk` bit(2) DEFAULT b'10'\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
278+
},
273279
}
274280

275281
var CreateTableScriptTests = []ScriptTest{

enginetest/queries/script_queries.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8410,6 +8410,32 @@ where
84108410
},
84118411
},
84128412
},
8413+
{
8414+
Name: "bit default value",
8415+
Dialect: "mysql",
8416+
SetUpScript: []string{
8417+
"create table t (i int primary key, b bit(2) default 2);",
8418+
"insert into t(i) values (1);",
8419+
"create table tt (b bit(2) default 2 primary key);",
8420+
"insert into tt values ();",
8421+
},
8422+
Assertions: []ScriptTestAssertion{
8423+
{
8424+
Skip: true, // this fails on server engine, even when skipped
8425+
Query: "select * from t;",
8426+
Expected: []sql.Row{
8427+
{1, uint8(2)},
8428+
},
8429+
},
8430+
{
8431+
Skip: true, // this fails on server engine, even when skipped
8432+
Query: "select * from tt;",
8433+
Expected: []sql.Row{
8434+
{uint8(2)},
8435+
},
8436+
},
8437+
},
8438+
},
84138439
}
84148440

84158441
var SpatialScriptTests = []ScriptTest{

sql/rowexec/show_iters.go

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,25 @@ type NameAndSchema interface {
382382
Schema() sql.Schema
383383
}
384384

385+
func convertColumnDefaultToString(ctx *sql.Context, def *sql.ColumnDefaultValue) (string, error) {
386+
// TODO : string literals should have character set introducer
387+
colDefaultStr := def.String()
388+
defType := def.Type()
389+
390+
// These types do not need to be quoted
391+
if !def.IsLiteral() || colDefaultStr == "NULL" || types.IsTime(defType) || types.IsText(defType) {
392+
return colDefaultStr, nil
393+
}
394+
v, err := def.Eval(ctx, nil)
395+
if err != nil {
396+
return "", err
397+
}
398+
if types.IsBit(def.OutType) {
399+
return fmt.Sprintf("b'%b'", v), nil
400+
}
401+
return fmt.Sprintf("'%v'", v), nil
402+
}
403+
385404
func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, table sql.Table, schema sql.Schema, pkSchema sql.PrimaryKeySchema) (string, error) {
386405
colStmts := make([]string, len(schema))
387406
var primaryKeyCols []string
@@ -395,26 +414,20 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab
395414
tableCollation := table.Collation()
396415
for i, col := range schema {
397416
var colDefaultStr string
417+
var err error
398418
if col.Default != nil && col.Generated == nil {
399419
// TODO : string literals should have character set introducer
400-
colDefaultStr = col.Default.String()
401-
if colDefaultStr != "NULL" && col.Default.IsLiteral() && !types.IsTime(col.Default.Type()) && !types.IsText(col.Default.Type()) {
402-
v, err := col.Default.Eval(ctx, nil)
403-
if err != nil {
404-
return "", err
405-
}
406-
colDefaultStr = fmt.Sprintf("'%v'", v)
420+
colDefaultStr, err = convertColumnDefaultToString(ctx, col.Default)
421+
if err != nil {
422+
return "", err
407423
}
408424
}
425+
409426
var onUpdateStr string
410427
if col.OnUpdate != nil {
411-
onUpdateStr = col.OnUpdate.String()
412-
if onUpdateStr != "NULL" && col.OnUpdate.IsLiteral() && !types.IsTime(col.OnUpdate.Type()) && !types.IsText(col.OnUpdate.Type()) {
413-
v, err := col.OnUpdate.Eval(ctx, nil)
414-
if err != nil {
415-
return "", err
416-
}
417-
onUpdateStr = fmt.Sprintf("'%v'", v)
428+
onUpdateStr, err = convertColumnDefaultToString(ctx, col.OnUpdate)
429+
if err != nil {
430+
return "", err
418431
}
419432
}
420433

0 commit comments

Comments
 (0)