Skip to content

Commit aa14837

Browse files
authored
fix string to boolean comparison for HashInTuple expressions (#3237)
1 parent fb75d24 commit aa14837

File tree

8 files changed

+304
-141
lines changed

8 files changed

+304
-141
lines changed

enginetest/queries/script_queries.go

Lines changed: 123 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,8 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
403403
},
404404
},
405405
{
406-
Name: "string to number comparison correctly truncates",
407406
Dialect: "mysql",
407+
Name: "string to number comparison correctly truncates",
408408
Assertions: []ScriptTestAssertion{
409409
{
410410
Query: "SELECT 'A' = 0;",
@@ -517,12 +517,10 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
517517
Expected: []sql.Row{{true}},
518518
},
519519
{
520-
Skip: true,
521520
Query: "SELECT '1.9a' = 1.9;",
522521
Expected: []sql.Row{{true}},
523522
},
524523
{
525-
Skip: true,
526524
Query: "SELECT 1 where '1.9a' = 1.9;",
527525
Expected: []sql.Row{{1}},
528526
},
@@ -590,11 +588,25 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
590588
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A123",
591589
},
592590
{
593-
Query: "SELECT 0 in ('A123');",
591+
Query: "SELECT '123abc' in ('string', 1, 2, 123);",
592+
Expected: []sql.Row{{true}},
593+
ExpectedWarningsCount: 3, // MySQL only throws 1 warning
594+
ExpectedWarning: mysql.ERTruncatedWrongValue,
595+
ExpectedWarningMessageSubstring: "Truncated incorrect double value",
596+
},
597+
{
598+
Query: "SELECT 123 in ('string', 1, 2, '123abc');",
599+
Expected: []sql.Row{{true}},
600+
ExpectedWarningsCount: 2,
601+
ExpectedWarning: mysql.ERTruncatedWrongValue,
602+
ExpectedWarningMessageSubstring: "Truncated incorrect double value",
603+
},
604+
{
605+
Query: "SELECT '123A' in (123);",
594606
Expected: []sql.Row{{true}},
595607
ExpectedWarningsCount: 1,
596608
ExpectedWarning: mysql.ERTruncatedWrongValue,
597-
ExpectedWarningMessageSubstring: "Truncated incorrect double value: A123",
609+
ExpectedWarningMessageSubstring: "Truncated incorrect double value: 123A",
598610
},
599611
{
600612
Query: "SELECT '123.456' in (123);",
@@ -605,7 +617,26 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
605617
Expected: []sql.Row{{true}},
606618
},
607619
{
608-
// TODO: 123.456 is converted to a DECIMAL by Builder.ConvertVal, when it should be a DOUBLE
620+
Query: "SELECT 123.456 in (123.456);",
621+
Expected: []sql.Row{{true}},
622+
},
623+
{
624+
Query: "SELECT 123.45 in (123.4);",
625+
Expected: []sql.Row{{false}},
626+
},
627+
{
628+
Query: "SELECT 123.45 in (123.5);",
629+
Expected: []sql.Row{{false}},
630+
},
631+
{
632+
Query: "SELECT '123.45a' in (123.5);",
633+
Expected: []sql.Row{{false}},
634+
},
635+
{
636+
Query: "SELECT '123.45a' in (123.4);",
637+
Expected: []sql.Row{{false}},
638+
},
639+
{
609640
SkipResultCheckOnServerEngine: true, // TODO: warnings do not make it to server engine
610641
Query: "SELECT '123.456ABC' in (123.456);",
611642
Expected: []sql.Row{{true}},
@@ -1229,6 +1260,59 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
12291260
},
12301261
},
12311262
},
1263+
{
1264+
Dialect: "mysql",
1265+
Name: "complicated string to numeric conversion",
1266+
SetUpScript: []string{
1267+
"CREATE TABLE t0(c INT);",
1268+
"INSERT INTO t0 VALUES (1);",
1269+
"CREATE TABLE t1(c VARCHAR(500));",
1270+
"INSERT INTO t1 VALUES ('1a');",
1271+
"CREATE TABLE t2(c0 INT , c1 BOOLEAN , c2 BOOLEAN , c3 INT , placeholder0 INT , placeholder1 VARCHAR(500) , placeholder2 VARCHAR(500) , PRIMARY KEY(placeholder0));",
1272+
"CREATE TABLE t3(c0 INT , c1 VARCHAR(500) , c2 BOOLEAN , c3 VARCHAR(500) , placeholder0 BOOLEAN , placeholder1 INT , placeholder2 VARCHAR(500));",
1273+
"INSERT INTO t3 VALUES (7, '0y4', TRUE, '5y', TRUE, 5, 'p9c');",
1274+
"INSERT INTO t3 VALUES (1, '4', TRUE, '4H', FALSE, 9, 'Zy4');",
1275+
"INSERT INTO t3 VALUES (10, '1a', FALSE, 'pYE', FALSE, 3, '0awX');",
1276+
"INSERT INTO t3 VALUES (8, 'J', TRUE, 'LE', TRUE, 9, 'YEqQ');",
1277+
"INSERT INTO t2 VALUES (10, FALSE, TRUE, 2, 2, 'nfxF', 'xvC');",
1278+
"INSERT INTO t2 VALUES (10, TRUE, TRUE, 10, 1, 'rlQT', 'W');",
1279+
},
1280+
Assertions: []ScriptTestAssertion{
1281+
{
1282+
Query: "SELECT * FROM t0, t1 WHERE (t1.c IN (true));",
1283+
Expected: []sql.Row{
1284+
{1, "1a"},
1285+
},
1286+
},
1287+
{
1288+
Query: "SELECT * FROM t3 INNER JOIN t2 ON ((((t3.c0) = ((EXTRACT(YEAR FROM DATE_ADD(DATE '2000-01-01', INTERVAL ( BIT_LENGTH(( MOD(t2.c3 + ( t2.c3 + ( BIT_COUNT(t2.c3) ) * 3 - CAST(( NOT (t2.c0 XOR t2.c2) ) AS SIGNED) ) * 2, 100 + t2.c3) ) ^ t2.c3) ) DAY)) % (t2.c3 + 1))))) >= (((t3.c2) < ((((((('Bs./')OR('wZ')) IN ((('1066274936')OR('')))))OR((((t3.c1 IN (true)))<>(((t3.c0)OR(( COALESCE(NULLIF(t3.c3, ''), t3.c1) ))))))))))));",
1289+
Expected: []sql.Row{
1290+
{7, "0y4", 1, "5y", 1, 5, "p9c", 10, 1, 1, 10, 1, "rlQT", "W"},
1291+
{1, "4", 1, "4H", 0, 9, "Zy4", 10, 1, 1, 10, 1, "rlQT", "W"},
1292+
{10, "1a", 0, "pYE", 0, 3, "0awX", 10, 1, 1, 10, 1, "rlQT", "W"},
1293+
{8, "J", 1, "LE", 1, 9, "YEqQ", 10, 1, 1, 10, 1, "rlQT", "W"},
1294+
{7, "0y4", 1, "5y", 1, 5, "p9c", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1295+
{1, "4", 1, "4H", 0, 9, "Zy4", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1296+
{10, "1a", 0, "pYE", 0, 3, "0awX", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1297+
{8, "J", 1, "LE", 1, 9, "YEqQ", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1298+
},
1299+
},
1300+
{
1301+
Query: "SELECT * FROM t3 CROSS JOIN t2 WHERE ((((t3.c0) = ((EXTRACT(YEAR FROM DATE_ADD(DATE '2000-01-01', INTERVAL ( BIT_LENGTH(( MOD(t2.c3 + ( t2.c3 + ( BIT_COUNT(t2.c3) ) * 3 - CAST(( NOT (t2.c0 XOR t2.c2) ) AS SIGNED) ) * 2, 100 + t2.c3) ) ^ t2.c3) ) DAY)) % (t2.c3 + 1))))) >= (((t3.c2) < ((((((('Bs./')OR('wZ')) IN ((('1066274936')OR('')))))OR((((t3.c1 IN (true)))<>(((t3.c0)OR(( COALESCE(NULLIF(t3.c3, ''), t3.c1) ))))))))))));",
1302+
Expected: []sql.Row{
1303+
{7, "0y4", 1, "5y", 1, 5, "p9c", 10, 1, 1, 10, 1, "rlQT", "W"},
1304+
{1, "4", 1, "4H", 0, 9, "Zy4", 10, 1, 1, 10, 1, "rlQT", "W"},
1305+
{10, "1a", 0, "pYE", 0, 3, "0awX", 10, 1, 1, 10, 1, "rlQT", "W"},
1306+
{8, "J", 1, "LE", 1, 9, "YEqQ", 10, 1, 1, 10, 1, "rlQT", "W"},
1307+
{7, "0y4", 1, "5y", 1, 5, "p9c", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1308+
{1, "4", 1, "4H", 0, 9, "Zy4", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1309+
{10, "1a", 0, "pYE", 0, 3, "0awX", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1310+
{8, "J", 1, "LE", 1, 9, "YEqQ", 10, 0, 1, 2, 2, "nfxF", "xvC"},
1311+
},
1312+
},
1313+
},
1314+
},
1315+
12321316
{
12331317
// https://github.com/dolthub/dolt/issues/9794
12341318
Name: "UPDATE with TRIM function on TEXT column",
@@ -7277,6 +7361,13 @@ CREATE TABLE tab3 (
72777361
Query: "select * from t where (f in (null, 0.8));",
72787362
Expected: []sql.Row{},
72797363
},
7364+
{
7365+
// This actually matches MySQL behavior
7366+
Query: "select count(*) from t where (f in (null, 0.8));",
7367+
Expected: []sql.Row{
7368+
{0},
7369+
},
7370+
},
72807371
{
72817372
// select count to avoid floating point comparison
72827373
Query: "select count(*) from t where (f in (null, cast(0.8 as float)));",
@@ -7323,6 +7414,32 @@ CREATE TABLE tab3 (
73237414
},
73247415
},
73257416
},
7417+
{
7418+
Name: "hash in tuple picks correct type and skips mixed types",
7419+
Dialect: "mysql",
7420+
SetUpScript: []string{
7421+
"create table t (v varchar(10));",
7422+
"insert into t values ('abc'), ('def'), ('ghi');",
7423+
},
7424+
Assertions: []ScriptTestAssertion{
7425+
{
7426+
Query: "select * from t where (v in ('xyz')) order by v;",
7427+
Expected: []sql.Row{},
7428+
},
7429+
{
7430+
Query: "select * from t where (v in (0, 'xyz')) order by v;",
7431+
Expected: []sql.Row{
7432+
{"abc"},
7433+
{"def"},
7434+
{"ghi"},
7435+
},
7436+
},
7437+
{
7438+
Query: "select * from t where (v in (1, 'xyz')) order by v;",
7439+
Expected: []sql.Row{},
7440+
},
7441+
},
7442+
},
73267443
{
73277444
Name: "strings in tuple are properly hashed",
73287445
Dialect: "mysql",

sql/analyzer/apply_hash_in.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/dolthub/go-mysql-server/sql/expression"
2020
"github.com/dolthub/go-mysql-server/sql/plan"
2121
"github.com/dolthub/go-mysql-server/sql/transform"
22+
"github.com/dolthub/go-mysql-server/sql/types"
2223
)
2324

2425
func applyHashIn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
@@ -29,9 +30,7 @@ func applyHashIn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, s
2930
}
3031

3132
e, same, err := transform.Expr(filter.Expression, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
32-
if e, ok := expr.(*expression.InTuple); ok &&
33-
hasSingleOutput(e.Left()) &&
34-
isStatic(e.Right()) {
33+
if e, ok := expr.(*expression.InTuple); ok && hasSingleOutput(e.Left()) && isStatic(e.Right()) && isConsistentType(e.Right()) {
3534
newe, err := expression.NewHashInTuple(ctx, e.Left(), e.Right())
3635
if err != nil {
3736
return nil, transform.SameTree, err
@@ -77,3 +76,24 @@ func isStatic(e sql.Expression) bool {
7776
}
7877
})
7978
}
79+
80+
func isConsistentType(expr sql.Expression) bool {
81+
tup, isTup := expr.(expression.Tuple)
82+
if !isTup {
83+
return true
84+
}
85+
var hasNumeric, hasString, hasTime bool
86+
for _, elem := range tup {
87+
eType := elem.Type()
88+
if types.IsNumber(eType) {
89+
hasNumeric = true
90+
} else if types.IsText(eType) {
91+
hasString = true
92+
} else if types.IsTime(eType) {
93+
hasTime = true
94+
}
95+
}
96+
// if there is a mixture of types, we cannot use hash
97+
// must have exactly one true
98+
return !((hasNumeric && hasString) || (hasNumeric && hasTime) || (hasString && hasTime))
99+
}

sql/expression/function/ceil_round_floor.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
179179
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
180180
}
181181
}
182-
183182
// if it's number type and not float value, it does not need ceil-ing
184183
switch num := child.(type) {
185184
case float32:

0 commit comments

Comments
 (0)