Skip to content

Commit 1b9f80d

Browse files
authored
Merge pull request #3204 from dolthub/elian/9812
dolthub/dolt#9812: Coalesce `IN` and `=` operator logic
2 parents 0ca241f + 8efd2f7 commit 1b9f80d

File tree

3 files changed

+126
-35
lines changed

3 files changed

+126
-35
lines changed

enginetest/queries/script_queries.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,99 @@ type ScriptTestAssertion struct {
120120
// Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of
121121
// the tests.
122122
var ScriptTests = []ScriptTest{
123+
{
124+
// https://github.com/dolthub/dolt/issues/9836
125+
Skip: true,
126+
Name: "Ordering by pk does not change the order of results",
127+
SetUpScript: []string{
128+
"CREATE TABLE test(pk VARCHAR(50) PRIMARY KEY)",
129+
"INSERT INTO test VALUES (' 3 12 4'), ('3. 12 4'), ('3.2 12 4'), ('-3.1234'), ('-3.1a'), ('-5+8'), ('+3.1234')",
130+
},
131+
Assertions: []ScriptTestAssertion{
132+
{
133+
Query: "SELECT pk FROM test ORDER BY pk",
134+
Expected: []sql.Row{{" 3 12 4"}, {"-3.1234"}, {"-3.1a"}, {"-5+8"}, {"+3.1234"}, {"3. 12 4"}, {"3.2 12 4"}},
135+
},
136+
},
137+
},
138+
{
139+
// https://github.com/dolthub/dolt/issues/9812
140+
Name: "String-to-number comparison operators should behave consistently",
141+
Assertions: []ScriptTestAssertion{
142+
{
143+
Dialect: "mysql",
144+
Query: "SELECT ('A') = (0)",
145+
Expected: []sql.Row{{true}},
146+
//ExpectedWarningsCount: 1,
147+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
148+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
149+
},
150+
{
151+
Dialect: "mysql",
152+
Query: "SELECT ('A') IN (0)",
153+
Expected: []sql.Row{{true}},
154+
//ExpectedWarningsCount: 1,
155+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
156+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
157+
},
158+
{
159+
Dialect: "mysql",
160+
Query: "SELECT ('A') != (0)",
161+
Expected: []sql.Row{{false}},
162+
//ExpectedWarningsCount: 1,
163+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
164+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
165+
},
166+
{
167+
Dialect: "mysql",
168+
Query: "SELECT ('A') <> (0)",
169+
Expected: []sql.Row{{false}},
170+
//ExpectedWarningsCount: 1,
171+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
172+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
173+
},
174+
{
175+
Dialect: "mysql",
176+
Query: "SELECT ('A') < (0)",
177+
Expected: []sql.Row{{false}},
178+
//ExpectedWarningsCount: 1,
179+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
180+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
181+
},
182+
{
183+
Dialect: "mysql",
184+
Query: "SELECT ('A') <= (0)",
185+
Expected: []sql.Row{{true}},
186+
//ExpectedWarningsCount: 1,
187+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
188+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
189+
},
190+
{
191+
Dialect: "mysql",
192+
Query: "SELECT ('A') > (0)",
193+
Expected: []sql.Row{{false}},
194+
//ExpectedWarningsCount: 1,
195+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
196+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
197+
},
198+
{
199+
Dialect: "mysql",
200+
Query: "SELECT ('A') >= (0)",
201+
Expected: []sql.Row{{true}},
202+
//ExpectedWarningsCount: 1,
203+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
204+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
205+
},
206+
{
207+
Dialect: "mysql",
208+
Query: "SELECT ('A') NOT IN (0)",
209+
Expected: []sql.Row{{false}},
210+
//ExpectedWarningsCount: 1,
211+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
212+
//ExpectedWarningMessageSubstring: "Truncated incorrect double value: A",
213+
},
214+
},
215+
},
123216
{
124217
// https://github.com/dolthub/dolt/issues/9794
125218
Name: "UPDATE with TRIM function on TEXT column",
@@ -11662,6 +11755,8 @@ select * from t1 except (
1166211755
{"5.932887e7abc", float32(5.932887e+07)},
1166311756
{"a1a1", float32(0)},
1166411757
},
11758+
//ExpectedWarningsCount: 12,
11759+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1166511760
},
1166611761
{
1166711762
Dialect: "mysql",
@@ -11686,6 +11781,8 @@ select * from t1 except (
1168611781
{"5.932887e7abc", 5.932887e+07},
1168711782
{"a1a1", 0.0},
1168811783
},
11784+
//ExpectedWarningsCount: 12,
11785+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1168911786
},
1169011787
{
1169111788
Dialect: "mysql",
@@ -11710,6 +11807,8 @@ select * from t1 except (
1171011807
{"5.932887e7abc", 5},
1171111808
{"a1a1", 0},
1171211809
},
11810+
//ExpectedWarningsCount: 12,
11811+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1171311812
},
1171411813
{
1171511814
Dialect: "mysql",
@@ -11734,6 +11833,8 @@ select * from t1 except (
1173411833
{"5.932887e7abc", uint64(5)},
1173511834
{"a1a1", uint64(0)},
1173611835
},
11836+
//ExpectedWarningsCount: 19,
11837+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1173711838
},
1173811839
{
1173911840
Dialect: "mysql",
@@ -11758,10 +11859,13 @@ select * from t1 except (
1175811859
{"5.932887e7abc", "59328870.000"},
1175911860
{"a1a1", "0.000"},
1176011861
},
11862+
//ExpectedWarningsCount: 13,
11863+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1176111864
},
1176211865
{
1176311866
Query: "select * from test01 where pk in ('11')",
1176411867
Expected: []sql.Row{{"11"}},
11868+
//ExpectedWarningsCount: 0,
1176511869
},
1176611870
{
1176711871
// https://github.com/dolthub/dolt/issues/9739
@@ -11774,6 +11878,8 @@ select * from t1 except (
1177411878
{"11d"},
1177511879
{"11wha?"},
1177611880
},
11881+
//ExpectedWarningsCount: 12,
11882+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1177711883
},
1177811884
{
1177911885
// https://github.com/dolthub/dolt/issues/9739
@@ -11785,6 +11891,8 @@ select * from t1 except (
1178511891
{" 3. 12 4"},
1178611892
{"3. 12 4"},
1178711893
},
11894+
//ExpectedWarningsCount: 12,
11895+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1178811896
},
1178911897
{
1179011898
// https://github.com/dolthub/dolt/issues/9739
@@ -11798,20 +11906,26 @@ select * from t1 except (
1179811906
{"+3.1234"},
1179911907
{"3. 12 4"},
1180011908
},
11909+
//ExpectedWarningsCount: 20,
11910+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1180111911
},
1180211912
{
1180311913
// https://github.com/dolthub/dolt/issues/9739
1180411914
Skip: true,
1180511915
Dialect: "mysql",
1180611916
Query: "select * from test02 where pk in ('11asdf')",
1180711917
Expected: []sql.Row{{"11"}},
11918+
//ExpectedWarningsCount: 1,
11919+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1180811920
},
1180911921
{
1181011922
// https://github.com/dolthub/dolt/issues/9739
1181111923
Skip: true,
1181211924
Dialect: "mysql",
1181311925
Query: "select * from test02 where pk='11.12asdf'",
1181411926
Expected: []sql.Row{},
11927+
//ExpectedWarningsCount: 1,
11928+
//ExpectedWarning: mysql.ERTruncatedWrongValue,
1181511929
},
1181611930
},
1181711931
},

sql/expression/comparison.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) {
141141
return c.Left().Type().Compare(ctx, left, right)
142142
}
143143

144-
l, r, compareType, err := c.castLeftAndRight(ctx, left, right)
144+
l, r, compareType, err := c.CastLeftAndRight(ctx, left, right)
145145
if err != nil {
146146
return 0, err
147147
}
@@ -171,7 +171,7 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{
171171
return left, right, nil
172172
}
173173

174-
func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
174+
func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) {
175175
leftType := c.Left().Type()
176176
rightType := c.Right().Type()
177177

@@ -452,7 +452,7 @@ func (e *NullSafeEquals) Compare(ctx *sql.Context, row sql.Row) (int, error) {
452452
}
453453

454454
var compareType sql.Type
455-
left, right, compareType, err = e.castLeftAndRight(ctx, left, right)
455+
left, right, compareType, err = e.CastLeftAndRight(ctx, left, right)
456456
if err != nil {
457457
return 0, err
458458
}

sql/expression/in.go

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ func NewInTuple(left sql.Expression, right sql.Expression) *InTuple {
6161

6262
// Eval implements the Expression interface.
6363
func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
64-
typ := in.Left().Type().Promote()
65-
leftElems := types.NumColumns(typ)
64+
leftElems := types.NumColumns(in.Left().Type())
6665
originalLeft, err := in.Left().Eval(ctx, row)
6766
if err != nil {
6867
return nil, err
@@ -78,11 +77,6 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
7877
// also if no match is found in the list and one of the expressions in the list is NULL.
7978
rightNull := false
8079

81-
left, _, err := typ.Convert(ctx, originalLeft)
82-
if err != nil {
83-
return nil, err
84-
}
85-
8680
switch right := in.Right().(type) {
8781
case Tuple:
8882
for _, el := range right {
@@ -102,31 +96,14 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
10296
continue
10397
}
10498

105-
var cmp int
106-
elType := el.Type()
107-
if types.IsDecimal(elType) || types.IsFloat(elType) {
108-
rtyp := el.Type().Promote()
109-
left, err := types.ConvertOrTruncate(ctx, left, rtyp)
110-
if err != nil {
111-
return nil, err
112-
}
113-
right, err := types.ConvertOrTruncate(ctx, originalRight, rtyp)
114-
if err != nil {
115-
return nil, err
116-
}
117-
cmp, err = rtyp.Compare(ctx, left, right)
118-
if err != nil {
119-
return nil, err
120-
}
121-
} else {
122-
right, err := types.ConvertOrTruncate(ctx, originalRight, typ)
123-
if err != nil {
124-
return nil, err
125-
}
126-
cmp, err = typ.Compare(ctx, left, right)
127-
if err != nil {
128-
return nil, err
129-
}
99+
comp := newComparison(NewLiteral(originalLeft, in.Left().Type()), NewLiteral(originalRight, el.Type()))
100+
l, r, compareType, err := comp.CastLeftAndRight(ctx, originalLeft, originalRight)
101+
if err != nil {
102+
return nil, err
103+
}
104+
cmp, err := compareType.Compare(ctx, l, r)
105+
if err != nil {
106+
return nil, err
130107
}
131108

132109
if cmp == 0 {

0 commit comments

Comments
 (0)