Skip to content

Commit 1276357

Browse files
committed
Abstracted the null type check
1 parent 59c9bc9 commit 1276357

File tree

6 files changed

+22
-4
lines changed

6 files changed

+22
-4
lines changed

sql/expression/case.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression
4545

4646
// Type implements the sql.Expression interface.
4747
func (c *Case) Type() sql.Type {
48-
curr := types.Null
48+
var curr sql.Type
49+
curr = types.Null
4950
for _, b := range c.Branches {
5051
curr = types.GeneralizeTypes(curr, b.Value.Type())
5152
}

sql/expression/function/coalesce.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ func (c *Coalesce) Type() sql.Type {
5858
if c.typ != nil {
5959
return c.typ
6060
}
61-
retType := types.Null
61+
62+
var retType sql.Type
63+
retType = types.Null
6264
for i, arg := range c.args {
6365
if arg == nil {
6466
continue

sql/type.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ type Type interface {
104104
// NullType represents the type of NULL values
105105
type NullType interface {
106106
Type
107+
108+
// IsNullType is a marker interface for types that represent NULL values.
109+
IsNullType() bool
107110
}
108111

109112
// DeferredType is a placeholder for prepared statements

sql/types/conversion.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,10 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
642642
return a
643643
}
644644

645-
if a == Null {
645+
if IsNullType(a) {
646646
return b
647647
}
648-
if b == Null {
648+
if IsNullType(b) {
649649
return a
650650
}
651651

sql/types/null.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ var (
3434

3535
type nullType struct{}
3636

37+
func (t nullType) IsNullType() bool {
38+
return true
39+
}
40+
3741
// Compare implements Type interface. Note that while this returns 0 (equals)
3842
// for ordering purposes, in SQL NULL != NULL.
3943
func (t nullType) Compare(s context.Context, a interface{}, b interface{}) (int, error) {

sql/types/typecheck.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,14 @@ func IsNumber(t sql.Type) bool {
106106
}
107107
}
108108

109+
func IsNullType(t sql.Type) bool {
110+
nt, ok := t.(sql.NullType)
111+
if !ok {
112+
return false
113+
}
114+
return nt.IsNullType()
115+
}
116+
109117
// IsSigned checks if t is a signed type.
110118
func IsSigned(t sql.Type) bool {
111119
if svt, ok := t.(sql.SystemVariableType); ok {

0 commit comments

Comments
 (0)