Skip to content

Commit b4366f3

Browse files
authored
Merge pull request #3258 from dolthub/elian/9935
dolthub/dolt#9935: Add fix for boolean evaluation in analyzer for EXISTS
2 parents 82453ba + 4224fe0 commit b4366f3

File tree

2 files changed

+83
-28
lines changed

2 files changed

+83
-28
lines changed

enginetest/queries/script_queries.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,53 @@ type ScriptTestAssertion struct {
122122
// Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of
123123
// the tests.
124124
var ScriptTests = []ScriptTest{
125+
{
126+
// https://github.com/dolthub/dolt/issues/9935
127+
Dialect: "mysql",
128+
Name: "Incorrect use of negation in AntiJoinIncludingNulls",
129+
SetUpScript: []string{
130+
"CREATE TABLE t0(c0 INT);",
131+
"INSERT INTO t0(c0) VALUES(1);",
132+
},
133+
Assertions: []ScriptTestAssertion{
134+
{
135+
Query: "SELECT * FROM t0 WHERE (! (1 || (EXISTS (SELECT 1))));",
136+
Expected: []sql.Row{},
137+
},
138+
{
139+
Query: "SELECT * FROM t0 WHERE (! (0 || (EXISTS (SELECT 1))));",
140+
Expected: []sql.Row{},
141+
},
142+
{
143+
Query: "SELECT * FROM t0 WHERE (! ((EXISTS (SELECT 1)) || 0));",
144+
Expected: []sql.Row{},
145+
},
146+
{
147+
Query: "SELECT * FROM t0 WHERE (! ((EXISTS (SELECT 1)) || 1));",
148+
Expected: []sql.Row{},
149+
},
150+
{
151+
Query: "SELECT * FROM t0 WHERE (! (1 && (EXISTS (SELECT 1))));",
152+
Expected: []sql.Row{},
153+
},
154+
{
155+
Query: "SELECT * FROM t0 WHERE (! (0 && (EXISTS (SELECT 1))));",
156+
Expected: []sql.Row{{1}},
157+
},
158+
{
159+
Query: "SELECT * FROM t0 WHERE (! (0 || (EXISTS (SELECT 1 FROM t0 WHERE c0 = 2))));",
160+
Expected: []sql.Row{{1}},
161+
},
162+
{
163+
Query: "SELECT * FROM t0 WHERE (! (1 || (EXISTS (SELECT 1 FROM t0 WHERE c0 = 1))));",
164+
Expected: []sql.Row{},
165+
},
166+
{
167+
Query: "SELECT * FROM t0 WHERE (! (0 || (EXISTS (SELECT 1 FROM t0 WHERE c0 = 1))));",
168+
Expected: []sql.Row{},
169+
},
170+
},
171+
},
125172
{
126173
// https://github.com/dolthub/go-mysql-server/issues/3259
127174
Dialect: "mysql",

sql/analyzer/optimization_rules.go

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -237,37 +237,37 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
237237
expression.NewLessThanOrEqual(e.Val, e.Upper),
238238
), transform.NewTree, nil
239239
case *expression.Or:
240-
if isTrue(e.LeftChild) {
240+
if isTrue(ctx, e.LeftChild) {
241241
return e.LeftChild, transform.NewTree, nil
242242
}
243243

244-
if isTrue(e.RightChild) {
244+
if isTrue(ctx, e.RightChild) {
245245
return e.RightChild, transform.NewTree, nil
246246
}
247247

248-
if isFalse(e.LeftChild) && types.IsBoolean(e.RightChild.Type()) {
248+
if isFalse(ctx, e.LeftChild) && types.IsBoolean(e.RightChild.Type()) {
249249
return e.RightChild, transform.NewTree, nil
250250
}
251251

252-
if isFalse(e.RightChild) && types.IsBoolean(e.LeftChild.Type()) {
252+
if isFalse(ctx, e.RightChild) && types.IsBoolean(e.LeftChild.Type()) {
253253
return e.LeftChild, transform.NewTree, nil
254254
}
255255

256256
return e, transform.SameTree, nil
257257
case *expression.And:
258-
if isFalse(e.LeftChild) {
258+
if isFalse(ctx, e.LeftChild) {
259259
return e.LeftChild, transform.NewTree, nil
260260
}
261261

262-
if isFalse(e.RightChild) {
262+
if isFalse(ctx, e.RightChild) {
263263
return e.RightChild, transform.NewTree, nil
264264
}
265265

266-
if isTrue(e.LeftChild) && types.IsBoolean(e.RightChild.Type()) {
266+
if isTrue(ctx, e.LeftChild) && types.IsBoolean(e.RightChild.Type()) {
267267
return e.RightChild, transform.NewTree, nil
268268
}
269269

270-
if isTrue(e.RightChild) && types.IsBoolean(e.LeftChild.Type()) {
270+
if isTrue(ctx, e.RightChild) && types.IsBoolean(e.LeftChild.Type()) {
271271
return e.LeftChild, transform.NewTree, nil
272272
}
273273

@@ -326,6 +326,16 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
326326
newRightUpper := expression.NewLiteral(valStr, e.RightChild.Type())
327327
newExpr := expression.NewAnd(expression.NewGreaterThanOrEqual(e.LeftChild, newRightLower), expression.NewLessThanOrEqual(e.LeftChild, newRightUpper))
328328
return newExpr, transform.NewTree, nil
329+
case *expression.Not:
330+
if lit, ok := e.Child.(*expression.Literal); ok {
331+
val, err := sql.ConvertToBool(ctx, lit.Value())
332+
if err != nil {
333+
// error while converting, keep as is
334+
return e, transform.SameTree, nil
335+
}
336+
return expression.NewLiteral(!val, e.Type()), transform.NewTree, nil
337+
}
338+
return e, transform.SameTree, nil
329339
case *expression.Literal, expression.Tuple, *expression.Interval, *expression.CollatedExpression, *expression.MatchAgainst:
330340
return e, transform.SameTree, nil
331341
default:
@@ -350,12 +360,12 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
350360
return nil, transform.SameTree, err
351361
}
352362

353-
if isFalse(e) {
363+
if isFalse(ctx, e) {
354364
emptyTable := plan.NewEmptyTableWithSchema(filter.Schema())
355365
return emptyTable, transform.NewTree, nil
356366
}
357367

358-
if isTrue(e) {
368+
if isTrue(ctx, e) {
359369
return filter.Child, transform.NewTree, nil
360370
}
361371

@@ -366,30 +376,28 @@ func simplifyFilters(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S
366376
})
367377
}
368378

369-
func isFalse(e sql.Expression) bool {
379+
func isFalse(ctx *sql.Context, e sql.Expression) bool {
370380
lit, ok := e.(*expression.Literal)
371-
if ok && lit != nil && lit.Type() == types.Boolean && lit.Value() != nil {
372-
switch v := lit.Value().(type) {
373-
case bool:
374-
return !v
375-
case int8:
376-
return v == sql.False
377-
}
381+
if !ok || lit == nil || lit.Value() == nil {
382+
return false
378383
}
379-
return false
384+
val, err := sql.ConvertToBool(ctx, lit.Value())
385+
if err != nil {
386+
return false
387+
}
388+
return !val
380389
}
381390

382-
func isTrue(e sql.Expression) bool {
391+
func isTrue(ctx *sql.Context, e sql.Expression) bool {
383392
lit, ok := e.(*expression.Literal)
384-
if ok && lit != nil && lit.Type() == types.Boolean && lit.Value() != nil {
385-
switch v := lit.Value().(type) {
386-
case bool:
387-
return v
388-
case int8:
389-
return v != sql.False
390-
}
393+
if !ok || lit == nil || lit.Value() == nil {
394+
return false
395+
}
396+
val, err := sql.ConvertToBool(ctx, lit.Value())
397+
if err != nil {
398+
return false
391399
}
392-
return false
400+
return val
393401
}
394402

395403
// pushNotFilters applies De'Morgan's laws to push NOT expressions as low

0 commit comments

Comments
 (0)