Skip to content

Commit 8c2ec97

Browse files
authored
Merge pull request #3064 from dolthub/fulghum/expr
Abstract `IsNull` and `IsNotNull` expression logic
2 parents e58117e + cddc0b6 commit 8c2ec97

File tree

11 files changed

+102
-54
lines changed

11 files changed

+102
-54
lines changed

sql/analyzer/costed_index_scan.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ func (c *indexCoster) getConstAndNullFilters(filters sql.FastIntSet) (sql.FastIn
652652
switch e.(type) {
653653
case *expression.Equals:
654654
isConst.Add(i)
655-
case *expression.IsNull:
655+
case sql.IsNullExpression:
656656
isNull.Add(i)
657657
case *expression.NullSafeEquals:
658658
isConst.Add(i)
@@ -1513,14 +1513,20 @@ func IndexLeafChildren(e sql.Expression) (IndexScanOp, sql.Expression, sql.Expre
15131513
left = e.Left()
15141514
right = e.Right()
15151515
op = IndexScanOpLte
1516-
case *expression.IsNull:
1517-
left = e.Child
1516+
case sql.IsNullExpression:
1517+
left = e.Children()[0]
15181518
op = IndexScanOpIsNull
1519+
case sql.IsNotNullExpression:
1520+
left = e.Children()[0]
1521+
op = IndexScanOpIsNotNull
15191522
case *expression.Not:
15201523
switch e := e.Child.(type) {
1521-
case *expression.IsNull:
1522-
left = e.Child
1524+
case sql.IsNullExpression:
1525+
left = e.Children()[0]
15231526
op = IndexScanOpIsNotNull
1527+
// TODO: In Postgres, Not(IS NULL) is valid, but doesn't necessarily always mean the
1528+
// same thing as IS NOT NULL, particularly for the case of records or composite
1529+
// values.
15241530
case *expression.Equals:
15251531
left = e.Left()
15261532
right = e.Right()

sql/analyzer/indexed_joins.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ func convertAntiToLeftJoin(m *memo.Memo) error {
567567
// drop null projected columns on right table
568568
nullFilters := make([]sql.Expression, len(nullify))
569569
for i, e := range nullify {
570-
nullFilters[i] = expression.NewIsNull(e)
570+
nullFilters[i] = expression.DefaultExpressionFactory.NewIsNull(e)
571571
}
572572

573573
filterGrp := m.MemoizeFilter(nil, joinGrp, nullFilters)
@@ -1412,7 +1412,7 @@ func isWeaklyMonotonic(e sql.Expression) bool {
14121412
}
14131413
return false
14141414
case *expression.Equals, *expression.NullSafeEquals, *expression.Literal, *expression.GetField,
1415-
*expression.Tuple, *expression.IsNull, *expression.BindVar:
1415+
*expression.Tuple, *expression.BindVar, sql.IsNullExpression, sql.IsNotNullExpression:
14161416
return false
14171417
default:
14181418
if e, ok := e.(expression.Equality); ok && e.RepresentsEquality() {

sql/analyzer/optimization_rules.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func expressionSources(expr sql.Expression) (sql.FastIntSet, bool) {
172172
switch e := e.(type) {
173173
case *expression.GetField:
174174
tables.Add(int(e.TableId()))
175-
case *expression.IsNull:
175+
case sql.IsNullExpression, sql.IsNotNullExpression:
176176
nullRejecting = false
177177
case *expression.NullSafeEquals:
178178
nullRejecting = false
@@ -188,7 +188,7 @@ func expressionSources(expr sql.Expression) (sql.FastIntSet, bool) {
188188
switch e := innerExpr.(type) {
189189
case *expression.GetField:
190190
tables.Add(int(e.TableId()))
191-
case *expression.IsNull:
191+
case sql.IsNullExpression, sql.IsNotNullExpression:
192192
nullRejecting = false
193193
case *expression.NullSafeEquals:
194194
nullRejecting = false

sql/core.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,19 @@ type NonDeterministicExpression interface {
7575
IsNonDeterministic() bool
7676
}
7777

78+
// IsNullExpression indicates that this expression tests for IS NULL.
79+
type IsNullExpression interface {
80+
Expression
81+
IsNullExpression() bool
82+
}
83+
84+
// IsNotNullExpression indicates that this expression tests for IS NOT NULL. Note that in some cases in some
85+
// database engines, such as records in Postgres, IS NOT NULL is not identical to NOT(IS NULL).
86+
type IsNotNullExpression interface {
87+
Expression
88+
IsNotNullExpression() bool
89+
}
90+
7891
// Node is a node in the execution plan tree.
7992
type Node interface {
8093
Resolvable

sql/expression/expr-factory.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package expression
16+
17+
import "github.com/dolthub/go-mysql-server/sql"
18+
19+
// ExpressionFactory allows integrators to provide custom implementations of
20+
// expressions, such as IS NULL and IS NOT NULL.
21+
type ExpressionFactory interface {
22+
// NewIsNull returns a sql.Expression implementation that handles
23+
// the IS NULL expression.
24+
NewIsNull(e sql.Expression) sql.Expression
25+
// NewIsNotNull returns a sql.Expression implementation that handles
26+
// the IS NOT NULL expression.
27+
NewIsNotNull(e sql.Expression) sql.Expression
28+
}
29+
30+
// DefaultExpressionFactory is the ExpressionFactory used when the analyzer
31+
// needs to create new expressions during analysis, such as IS NULL or
32+
// IS NOT NULL. Integrators can swap in their own implementation if they need
33+
// to customize the existing logic for these expressions.
34+
var DefaultExpressionFactory ExpressionFactory = MySqlExpressionFactory{}
35+
36+
// MySqlExpressionFactory is the ExpressionFactory that creates expressions
37+
// that follow MySQL's logic.
38+
type MySqlExpressionFactory struct{}
39+
40+
var _ ExpressionFactory = (*MySqlExpressionFactory)(nil)
41+
42+
// NewIsNull implements the ExpressionFactory interface.
43+
func (m MySqlExpressionFactory) NewIsNull(e sql.Expression) sql.Expression {
44+
return NewIsNull(e)
45+
}
46+
47+
// NewIsNotNull implements the ExpressionFactory interface.
48+
func (m MySqlExpressionFactory) NewIsNotNull(e sql.Expression) sql.Expression {
49+
return NewNot(NewIsNull(e))
50+
}

sql/expression/filter-range.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,24 +48,24 @@ func NewRangeFilterExpr(exprs []sql.Expression, ranges []sql.MySQLRange) (sql.Ex
4848
case sql.RangeType_All:
4949
rangeColumnExpr = NewEquals(NewLiteral(1, types.Int8), NewLiteral(1, types.Int8))
5050
case sql.RangeType_EqualNull:
51-
rangeColumnExpr = NewIsNull(exprs[i])
51+
rangeColumnExpr = DefaultExpressionFactory.NewIsNull(exprs[i])
5252
case sql.RangeType_GreaterThan:
5353
if sql.MySQLRangeCutIsBinding(rce.LowerBound) {
5454
rangeColumnExpr = NewGreaterThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote()))
5555
} else {
56-
rangeColumnExpr = NewNot(NewIsNull(exprs[i]))
56+
rangeColumnExpr = DefaultExpressionFactory.NewIsNotNull(exprs[i])
5757
}
5858
case sql.RangeType_GreaterOrEqual:
5959
rangeColumnExpr = NewGreaterThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.LowerBound), rce.Typ.Promote()))
6060
case sql.RangeType_LessThanOrNull:
6161
rangeColumnExpr = JoinOr(
6262
NewLessThan(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
63-
NewIsNull(exprs[i]),
63+
DefaultExpressionFactory.NewIsNull(exprs[i]),
6464
)
6565
case sql.RangeType_LessOrEqualOrNull:
6666
rangeColumnExpr = JoinOr(
6767
NewLessThanOrEqual(exprs[i], NewLiteral(sql.GetMySQLRangeCutKey(rce.UpperBound), rce.Typ.Promote())),
68-
NewIsNull(exprs[i]),
68+
DefaultExpressionFactory.NewIsNull(exprs[i]),
6969
)
7070
case sql.RangeType_ClosedClosed:
7171
rangeColumnExpr = JoinAnd(

sql/expression/isnull.go

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,19 @@ type IsNull struct {
2626

2727
var _ sql.Expression = (*IsNull)(nil)
2828
var _ sql.CollationCoercible = (*IsNull)(nil)
29+
var _ sql.IsNullExpression = (*IsNull)(nil)
2930

3031
// NewIsNull creates a new IsNull expression.
3132
func NewIsNull(child sql.Expression) *IsNull {
3233
return &IsNull{UnaryExpression{child}}
3334
}
3435

36+
// IsNullExpression implements the sql.IsNullExpression interface. This function exsists primarily
37+
// to ensure the IsNullExpression interface has a unique signature.
38+
func (e *IsNull) IsNullExpression() bool {
39+
return true
40+
}
41+
3542
// Type implements the Expression interface.
3643
func (e *IsNull) Type() sql.Type {
3744
return types.Boolean
@@ -53,18 +60,6 @@ func (e *IsNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
5360
if err != nil {
5461
return nil, err
5562
}
56-
57-
// Slices of typed values (e.g. Record and Composite types in Postgres) evaluate
58-
// to NULL if all of their entries are NULL.
59-
if tupleValue, ok := v.([]types.TupleValue); ok {
60-
for _, typedValue := range tupleValue {
61-
if typedValue.Value != nil {
62-
return false, nil
63-
}
64-
}
65-
return true, nil
66-
}
67-
6863
return v == nil, nil
6964
}
7065

sql/memo/rel_props.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,13 +285,18 @@ func (p *relProps) populateFds() {
285285
}
286286
}
287287
case *expression.Not:
288-
child, ok := f.Child.(*expression.IsNull)
288+
child, ok := f.Child.(sql.IsNullExpression)
289289
if ok {
290-
col, ok := child.Child.(*expression.GetField)
290+
col, ok := child.Children()[0].(*expression.GetField)
291291
if ok {
292292
notNull.Add(col.Id())
293293
}
294294
}
295+
case sql.IsNotNullExpression:
296+
col, ok := f.Children()[0].(*expression.GetField)
297+
if ok {
298+
notNull.Add(col.Id())
299+
}
295300
}
296301
}
297302
fds = sql.NewFilterFDs(rel.Child.RelProps.FuncDeps(), notNull, constant, equiv)

sql/plan/join.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,9 +531,12 @@ func NewSemiJoin(left, right sql.Node, cond sql.Expression) *JoinNode {
531531
// IsNullRejecting returns whether the expression always returns false for
532532
// nil inputs.
533533
func IsNullRejecting(e sql.Expression) bool {
534+
// Note that InspectExpr will stop inspecting expressions in the
535+
// expression tree when true is returned, so we invert that return
536+
// value from InspectExpr to return the correct null rejecting value.
534537
return !transform.InspectExpr(e, func(e sql.Expression) bool {
535538
switch e.(type) {
536-
case *expression.NullSafeEquals, *expression.IsNull:
539+
case sql.IsNullExpression, sql.IsNotNullExpression, *expression.NullSafeEquals:
537540
return true
538541
default:
539542
return false

sql/planbuilder/scalar.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,10 +747,10 @@ func (b *Builder) buildIsExprToExpression(inScope *scope, c *ast.IsExpr) sql.Exp
747747
e := b.buildScalar(inScope, c.Expr)
748748
switch strings.ToLower(c.Operator) {
749749
case ast.IsNullStr:
750-
return expression.NewIsNull(e)
750+
return expression.DefaultExpressionFactory.NewIsNull(e)
751751
case ast.IsNotNullStr:
752752
b.qFlags.Set(sql.QFlgNotExpr)
753-
return expression.NewNot(expression.NewIsNull(e))
753+
return expression.DefaultExpressionFactory.NewIsNotNull(e)
754754
case ast.IsTrueStr:
755755
return expression.NewIsTrue(e)
756756
case ast.IsFalseStr:

0 commit comments

Comments
 (0)