Skip to content

Commit 54bd6d6

Browse files
authored
compare and convert system types properly (#2700)
1 parent d97de81 commit 54bd6d6

File tree

6 files changed

+245
-68
lines changed

6 files changed

+245
-68
lines changed

enginetest/queries/script_queries.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7499,6 +7499,22 @@ where
74997499
},
75007500
},
75017501
},
7502+
{
7503+
Name: "coalesce with system types",
7504+
SetUpScript: []string{
7505+
"create table t as select @@admin_port as port1, @@port as port2, COALESCE(@@admin_port, @@port) as\n port3;",
7506+
},
7507+
Assertions: []ScriptTestAssertion{
7508+
{
7509+
Query: "describe t;",
7510+
Expected: []sql.Row{
7511+
{"port1", "bigint", "NO", "", nil, ""},
7512+
{"port2", "bigint", "NO", "", nil, ""},
7513+
{"port3", "bigint", "NO", "", nil, ""},
7514+
},
7515+
},
7516+
},
7517+
},
75027518
}
75037519

75047520
var SpatialScriptTests = []ScriptTest{

sql/analyzer/resolve_create_select.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.
3131
for i, col := range mergedSchema {
3232
tempCol := *col
3333
tempCol.Source = ct.Name()
34+
// replace system variable types with their underlying types
35+
if sysType, isSysTyp := tempCol.Type.(sql.SystemVariableType); isSysTyp {
36+
tempCol.Type = sysType.UnderlyingType()
37+
}
3438
newSch[i] = &tempCol
3539
}
3640

sql/expression/function/coalesce.go

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ import (
1818
"fmt"
1919
"strings"
2020

21-
"github.com/dolthub/go-mysql-server/sql"
2221
"github.com/dolthub/go-mysql-server/sql/expression"
22+
23+
"github.com/dolthub/go-mysql-server/sql"
2324
"github.com/dolthub/go-mysql-server/sql/types"
2425
)
2526

@@ -53,61 +54,73 @@ func (c *Coalesce) Description() string {
5354
// Type implements the sql.Expression interface.
5455
// The return type of Type() is the aggregated type of the argument types.
5556
func (c *Coalesce) Type() sql.Type {
56-
typ := types.Null
57-
for _, arg := range c.args {
57+
retType := types.Null
58+
for i, arg := range c.args {
5859
if arg == nil {
5960
continue
6061
}
61-
t := arg.Type()
62+
argType := arg.Type()
63+
if sysVarType, ok := argType.(sql.SystemVariableType); ok {
64+
argType = sysVarType.UnderlyingType()
65+
}
66+
if i == 0 {
67+
retType = argType
68+
continue
69+
}
70+
if argType == nil || argType == types.Null {
71+
continue
72+
}
73+
if retType.Equals(argType) {
74+
continue
75+
}
76+
6277
// special case for signed and unsigned integers
63-
if (types.IsSigned(typ) && types.IsUnsigned(t)) || (types.IsUnsigned(typ) && types.IsSigned(t)) {
64-
typ = types.MustCreateDecimalType(20, 0)
78+
if (types.IsSigned(retType) && types.IsUnsigned(argType)) || (types.IsUnsigned(retType) && types.IsSigned(argType)) {
79+
retType = types.MustCreateDecimalType(20, 0)
6580
continue
6681
}
6782

68-
if t != nil && t != types.Null {
69-
convType := expression.GetConvertToType(typ, t)
70-
switch convType {
71-
case expression.ConvertToChar:
72-
// special case for float64s
73-
if (t == types.Float64 || typ == types.Float64) && !types.IsText(t) && !types.IsText(typ) {
74-
typ = types.Float64
75-
continue
76-
}
77-
// Can't get any larger than this
78-
return types.LongText
79-
case expression.ConvertToDecimal:
80-
if typ == types.Float64 || t == types.Float64 {
81-
typ = types.Float64
82-
} else if types.IsDecimal(t) {
83-
typ = t
84-
} else if !types.IsDecimal(typ) {
85-
typ = types.MustCreateDecimalType(10, 0)
86-
}
87-
case expression.ConvertToUnsigned:
88-
if typ == types.Uint64 || t == types.Uint64 {
89-
typ = types.Uint64
90-
} else {
91-
typ = types.Uint32
92-
}
93-
case expression.ConvertToSigned:
94-
if typ == types.Int64 || t == types.Int64 {
95-
typ = types.Int64
96-
} else {
97-
typ = types.Int32
98-
}
99-
case expression.ConvertToFloat:
100-
if typ == types.Float64 || t == types.Float64 {
101-
typ = types.Float64
102-
} else {
103-
typ = types.Float32
104-
}
105-
default:
83+
convType := expression.GetConvertToType(retType, argType)
84+
switch convType {
85+
case expression.ConvertToChar:
86+
// special case for float64s
87+
if (argType == types.Float64 || retType == types.Float64) && !types.IsText(argType) && !types.IsText(retType) {
88+
retType = types.Float64
89+
continue
90+
}
91+
// Can't get any larger than this
92+
return types.LongText
93+
case expression.ConvertToDecimal:
94+
if retType == types.Float64 || argType == types.Float64 {
95+
retType = types.Float64
96+
} else if types.IsDecimal(argType) {
97+
retType = argType
98+
} else if !types.IsDecimal(retType) {
99+
retType = types.MustCreateDecimalType(10, 0)
100+
}
101+
case expression.ConvertToUnsigned:
102+
if retType == types.Uint64 || argType == types.Uint64 {
103+
retType = types.Uint64
104+
} else {
105+
retType = types.Uint32
106+
}
107+
case expression.ConvertToSigned:
108+
if retType == types.Int64 || argType == types.Int64 {
109+
retType = types.Int64
110+
} else {
111+
retType = types.Int32
112+
}
113+
case expression.ConvertToFloat:
114+
if retType == types.Float64 || argType == types.Float64 {
115+
retType = types.Float64
116+
} else {
117+
retType = types.Float32
106118
}
119+
default:
107120
}
108121
}
109122

110-
return typ
123+
return retType
111124
}
112125

113126
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/expression/function/coalesce_test.go

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,86 @@ func TestCoalesce(t *testing.T) {
152152
typ: types.Float64,
153153
nullable: false,
154154
},
155+
{
156+
name: "coalesce(sysInt, sysInt)",
157+
input: []sql.Expression{
158+
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
159+
expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)),
160+
},
161+
expected: 1,
162+
typ: types.Int64,
163+
nullable: false,
164+
},
165+
{
166+
name: "coalesce(sysInt, sysUint)",
167+
input: []sql.Expression{
168+
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
169+
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
170+
},
171+
expected: 1,
172+
typ: types.MustCreateDecimalType(20, 0),
173+
nullable: false,
174+
},
175+
{
176+
name: "coalesce(sysUint, sysUint)",
177+
input: []sql.Expression{
178+
expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)),
179+
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
180+
},
181+
expected: 1,
182+
typ: types.Uint64,
183+
nullable: false,
184+
},
185+
{
186+
name: "coalesce(sysDouble, sysDouble)",
187+
input: []sql.Expression{
188+
expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)),
189+
expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)),
190+
},
191+
expected: 1.0,
192+
typ: types.Float64,
193+
nullable: false,
194+
},
195+
{
196+
name: "coalesce(sysText)",
197+
input: []sql.Expression{
198+
expression.NewLiteral("abc", types.NewSystemStringType("str1")),
199+
},
200+
expected: "abc",
201+
typ: types.LongText,
202+
nullable: false,
203+
},
204+
{
205+
name: "coalesce(sysEnum)",
206+
input: []sql.Expression{
207+
expression.NewLiteral("abc", types.NewSystemEnumType("str1")),
208+
},
209+
expected: "abc",
210+
typ: types.EnumType{},
211+
nullable: false,
212+
},
213+
{
214+
name: "coalesce(sysSet)",
215+
input: []sql.Expression{
216+
expression.NewLiteral("abc", types.NewSystemSetType("str1", "abc")),
217+
},
218+
expected: "abc",
219+
typ: types.MustCreateSetType([]string{"abc"}, sql.Collation_Default),
220+
nullable: false,
221+
},
155222
}
156223

157224
for _, tt := range testCases {
158-
c, err := NewCoalesce(tt.input...)
159-
require.NoError(t, err)
225+
t.Run(tt.name, func(t *testing.T) {
226+
c, err := NewCoalesce(tt.input...)
227+
require.NoError(t, err)
160228

161-
require.Equal(t, tt.typ, c.Type())
162-
require.Equal(t, tt.nullable, c.IsNullable())
163-
v, err := c.Eval(sql.NewEmptyContext(), nil)
164-
require.NoError(t, err)
165-
require.Equal(t, tt.expected, v)
229+
require.Equal(t, tt.typ, c.Type())
230+
require.Equal(t, tt.nullable, c.IsNullable())
231+
v, err := c.Eval(sql.NewEmptyContext(), nil)
232+
require.NoError(t, err)
233+
require.Equal(t, tt.expected, v)
234+
})
166235
}
167236
}
168237

sql/types/typecheck.go

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ func IsNull(ex sql.Expression) bool {
9191

9292
// IsNumber checks if t is a number type
9393
func IsNumber(t sql.Type) bool {
94-
switch t.(type) {
95-
case NumberTypeImpl_, DecimalType_, BitType_, YearType_, SystemBoolType:
94+
switch typ := t.(type) {
95+
case sql.SystemVariableType:
96+
return IsNumber(typ.UnderlyingType())
97+
case NumberTypeImpl_, DecimalType_, BitType_, YearType_:
9698
return true
9799
default:
98100
return false
@@ -101,23 +103,25 @@ func IsNumber(t sql.Type) bool {
101103

102104
// IsSigned checks if t is a signed type.
103105
func IsSigned(t sql.Type) bool {
104-
// systemBoolType is Int8
105-
if _, ok := t.(SystemBoolType); ok {
106-
return true
106+
if svt, ok := t.(sql.SystemVariableType); ok {
107+
t = svt.UnderlyingType()
107108
}
108109
return t == Int8 || t == Int16 || t == Int24 || t == Int32 || t == Int64 || t == Boolean
109110
}
110111

111112
// IsText checks if t is a CHAR, VARCHAR, TEXT, BINARY, VARBINARY, or BLOB (including TEXT and BLOB variants).
112113
func IsText(t sql.Type) bool {
113-
if _, ok := t.(StringType); ok {
114-
return ok
115-
}
116-
if extendedType, ok := t.(ExtendedType); ok {
117-
_, isString := extendedType.Zero().(string)
114+
switch typ := t.(type) {
115+
case sql.SystemVariableType:
116+
return IsText(typ.UnderlyingType())
117+
case StringType:
118+
return true
119+
case ExtendedType:
120+
_, isString := typ.Zero().(string)
118121
return isString
122+
default:
123+
return false
119124
}
120-
return false
121125
}
122126

123127
// IsTextBlob checks if t is one of the TEXTs or BLOBs.
@@ -178,14 +182,26 @@ func IsTimestampType(t sql.Type) bool {
178182

179183
// IsEnum checks if t is a enum
180184
func IsEnum(t sql.Type) bool {
181-
_, ok := t.(EnumType)
182-
return ok
185+
switch typ := t.(type) {
186+
case sql.SystemVariableType:
187+
return IsEnum(typ.UnderlyingType())
188+
case EnumType:
189+
return true
190+
default:
191+
return false
192+
}
183193
}
184194

185195
// IsSet checks if t is a set
186196
func IsSet(t sql.Type) bool {
187-
_, ok := t.(SetType)
188-
return ok
197+
switch typ := t.(type) {
198+
case sql.SystemVariableType:
199+
return IsSet(typ.UnderlyingType())
200+
case SetType:
201+
return true
202+
default:
203+
return false
204+
}
189205
}
190206

191207
// IsTuple checks if t is a tuple type.
@@ -201,7 +217,6 @@ func IsUnsigned(t sql.Type) bool {
201217
if svt, ok := t.(sql.SystemVariableType); ok {
202218
t = svt.UnderlyingType()
203219
}
204-
205220
return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64
206221
}
207222

0 commit comments

Comments
 (0)