Skip to content

Commit 1f0225d

Browse files
author
James Cor
committed
fix and test
1 parent c6c4147 commit 1f0225d

File tree

3 files changed

+95
-41
lines changed

3 files changed

+95
-41
lines changed

enginetest/queries/script_queries.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
590590
{
591591
Query: "SELECT '123abc' in ('string', 1, 2, 123);",
592592
Expected: []sql.Row{{true}},
593-
ExpectedWarningsCount: 2, // MySQL only throws 1 warning
593+
ExpectedWarningsCount: 3, // MySQL only throws 1 warning
594594
ExpectedWarning: mysql.ERTruncatedWrongValue,
595595
ExpectedWarningMessageSubstring: "Truncated incorrect double value",
596596
},
@@ -642,7 +642,7 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN
642642
Expected: []sql.Row{{true}},
643643
ExpectedWarningsCount: 1,
644644
ExpectedWarning: mysql.ERTruncatedWrongValue,
645-
ExpectedWarningMessageSubstring: "Truncated incorrect double value: 123.456ABC",
645+
ExpectedWarningMessageSubstring: "Truncated incorrect decimal(65,30) value: 123.456ABC",
646646
},
647647
{
648648
Query: "SELECT '123.456e2' in (12345.6);",
@@ -7356,9 +7356,15 @@ CREATE TABLE tab3 (
73567356
},
73577357
},
73587358
{
7359+
// This actually matches MySQL behavior
7360+
Query: "select * from t where (f in (null, 0.8));",
7361+
Expected: []sql.Row{},
7362+
},
7363+
{
7364+
// This actually matches MySQL behavior
73597365
Query: "select count(*) from t where (f in (null, 0.8));",
73607366
Expected: []sql.Row{
7361-
{1},
7367+
{0},
73627368
},
73637369
},
73647370
{

sql/analyzer/apply_hash_in.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/dolthub/go-mysql-server/sql/expression"
2020
"github.com/dolthub/go-mysql-server/sql/plan"
2121
"github.com/dolthub/go-mysql-server/sql/transform"
22+
"github.com/dolthub/go-mysql-server/sql/types"
2223
)
2324

2425
func applyHashIn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
@@ -29,9 +30,7 @@ func applyHashIn(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, s
2930
}
3031

3132
e, same, err := transform.Expr(filter.Expression, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
32-
if e, ok := expr.(*expression.InTuple); ok &&
33-
hasSingleOutput(e.Left()) &&
34-
isStatic(e.Right()) {
33+
if e, ok := expr.(*expression.InTuple); ok && hasSingleOutput(e.Left()) && isStatic(e.Right()) && isConsistentType(e.Right()) {
3534
newe, err := expression.NewHashInTuple(ctx, e.Left(), e.Right())
3635
if err != nil {
3736
return nil, transform.SameTree, err
@@ -77,3 +76,24 @@ func isStatic(e sql.Expression) bool {
7776
}
7877
})
7978
}
79+
80+
func isConsistentType(expr sql.Expression) bool {
81+
tup, isTup := expr.(expression.Tuple)
82+
if !isTup {
83+
return true
84+
}
85+
var hasNumeric, hasString, hasTime bool
86+
for _, elem := range tup {
87+
eType := elem.Type()
88+
if types.IsNumber(eType) {
89+
hasNumeric = true
90+
} else if types.IsText(eType) {
91+
hasString = true
92+
} else if types.IsTime(eType) {
93+
hasTime = true
94+
}
95+
}
96+
// if there is a mixture of types, we cannot use hash
97+
// must have exactly one true
98+
return !((hasNumeric && hasString) || (hasNumeric && hasTime) || (hasString && hasTime))
99+
}

sql/expression/in.go

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ package expression
1717
import (
1818
"fmt"
1919

20-
"github.com/dolthub/vitess/go/mysql"
21-
2220
"github.com/dolthub/go-mysql-server/sql"
2321
"github.com/dolthub/go-mysql-server/sql/hash"
2422
"github.com/dolthub/go-mysql-server/sql/types"
@@ -130,58 +128,59 @@ func validateAndEvalRightTuple(ctx *sql.Context, lType sql.Type, right Tuple, ro
130128

131129
// Eval implements the Expression interface.
132130
func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
133-
leftVal, err := in.Left().Eval(ctx, row)
131+
lVal, err := in.Left().Eval(ctx, row)
134132
if err != nil {
135133
return nil, err
136134
}
137-
if leftVal == nil {
135+
if lVal == nil {
138136
return nil, nil
139137
}
140138

139+
lType := in.Left().Type()
140+
lColCount := types.NumColumns(lType)
141+
lLit := NewLiteral(lVal, lType)
142+
141143
right, isTuple := in.Right().(Tuple)
142144
if !isTuple {
143145
return nil, ErrUnsupportedInOperand.New(right)
144146
}
145147

146-
lType := in.Left().Type()
147-
rVals, cmpType, rHasNull, err := validateAndEvalRightTuple(ctx, lType, right, row)
148-
if err != nil {
149-
return nil, err
150-
}
148+
var rHasNull bool
149+
for _, el := range right {
150+
rType := el.Type()
151+
if rType == types.Null {
152+
rHasNull = true
153+
continue
154+
}
151155

152-
lv, _, lErr := cmpType.Convert(ctx, leftVal)
153-
if lErr != nil {
154-
if sql.ErrTruncatedIncorrect.Is(lErr) {
155-
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", lErr.Error())
156-
} else {
157-
lv = cmpType.Zero()
156+
// Nested tuples must have the same number of columns
157+
rColCount := types.NumColumns(rType)
158+
if rColCount != lColCount {
159+
return nil, sql.ErrInvalidOperandColumns.New(lColCount, rColCount)
158160
}
159-
}
160161

161-
for _, rVal := range rVals {
162+
rVal, rErr := el.Eval(ctx, row)
163+
if rErr != nil {
164+
return nil, rErr
165+
}
162166
if rVal == nil {
167+
rHasNull = true
163168
continue
164169
}
165-
rv, _, rErr := cmpType.Convert(ctx, rVal)
166-
if rErr != nil {
167-
if sql.ErrTruncatedIncorrect.Is(rErr) {
168-
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", rErr.Error())
169-
} else {
170-
rv = cmpType.Zero()
171-
}
172-
}
173-
cmp, cErr := cmpType.Compare(ctx, lv, rv)
170+
171+
cmpExpr := newComparison(lLit, NewLiteral(rVal, rType))
172+
res, cErr := cmpExpr.Compare(ctx, nil)
174173
if cErr != nil {
175-
continue
174+
return nil, cErr
176175
}
177-
if cmp == 0 {
176+
if res == 0 {
178177
return true, nil
179178
}
180179
}
180+
181181
if rHasNull {
182182
return nil, nil
183183
}
184-
185184
return false, nil
186185
}
187186

@@ -258,16 +257,45 @@ func newInMap(ctx *sql.Context, lType sql.Type, right Tuple) (map[uint64]struct{
258257
if lType == types.Null {
259258
return nil, nil, true, nil
260259
}
260+
lColCount := types.NumColumns(lType)
261261
if len(right) == 0 {
262262
return nil, nil, false, nil
263263
}
264-
rVals, cmpType, rHasNull, err := validateAndEvalRightTuple(ctx, lType, right, nil)
265-
if err != nil {
266-
return nil, nil, false, err
264+
// only non-nil elements are included
265+
rVals := make([]any, 0, len(right))
266+
var rHasNull bool
267+
for _, el := range right {
268+
rType := el.Type()
269+
rColCount := types.NumColumns(rType)
270+
if lColCount != rColCount {
271+
return nil, nil, false, sql.ErrInvalidOperandColumns.New(lColCount, rColCount)
272+
}
273+
if rType == types.Null {
274+
rHasNull = true
275+
continue
276+
}
277+
rVal, err := el.Eval(ctx, nil)
278+
if err != nil {
279+
return nil, nil, false, err
280+
}
281+
if rVal == nil {
282+
rHasNull = true
283+
continue
284+
}
285+
rVals = append(rVals, rVal)
267286
}
268-
elements := make(map[uint64]struct{})
269-
for _, v := range rVals {
270-
key, hErr := hash.HashOfSimple(ctx, v, cmpType)
287+
288+
var cmpType sql.Type
289+
if types.IsEnum(lType) || types.IsSet(lType) {
290+
cmpType = lType
291+
} else {
292+
// If we've made it this far, we are guaranteed that the right Tuple has a consistent set of types
293+
// (all numeric, string, or time), so it is enough to just compare against the first element of the right Tuple
294+
cmpType = types.GetCompareType(lType, right[0].Type())
295+
}
296+
elements := map[uint64]struct{}{}
297+
for _, rVal := range rVals {
298+
key, hErr := hash.HashOfSimple(ctx, rVal, cmpType)
271299
if hErr != nil {
272300
return nil, nil, false, hErr
273301
}

0 commit comments

Comments
 (0)