Skip to content

Commit c857474

Browse files
author
James Cor
committed
compare and convert system types properly
1 parent d97de81 commit c857474

File tree

6 files changed

+249
-73
lines changed

6 files changed

+249
-73
lines changed

enginetest/queries/script_queries.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7499,6 +7499,22 @@ where
74997499
},
75007500
},
75017501
},
7502+
{
7503+
Name: "coalesce with system types",
7504+
SetUpScript: []string{
7505+
"create table t as select @@admin_port as port1, @@port as port2, COALESCE(@@admin_port, @@port) as\n port3;",
7506+
},
7507+
Assertions: []ScriptTestAssertion{
7508+
{
7509+
Query: "describe t;",
7510+
Expected: []sql.Row{
7511+
{"port1", "bigint", "NO", "", nil, ""},
7512+
{"port2", "bigint", "NO", "", nil, ""},
7513+
{"port3", "bigint", "NO", "", nil, ""},
7514+
},
7515+
},
7516+
},
7517+
},
75027518
}
75037519

75047520
var SpatialScriptTests = []ScriptTest{

sql/analyzer/resolve_create_select.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.
3131
for i, col := range mergedSchema {
3232
tempCol := *col
3333
tempCol.Source = ct.Name()
34+
// replace system variable types with their underlying types
35+
if sysType, isSysTyp := tempCol.Type.(sql.SystemVariableType); isSysTyp {
36+
tempCol.Type = sysType.UnderlyingType()
37+
}
3438
newSch[i] = &tempCol
3539
}
3640

sql/expression/function/coalesce.go

Lines changed: 59 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ package function
1616

1717
import (
1818
"fmt"
19-
"strings"
19+
"github.com/dolthub/go-mysql-server/sql/expression"
20+
"strings"
2021

2122
"github.com/dolthub/go-mysql-server/sql"
22-
"github.com/dolthub/go-mysql-server/sql/expression"
23-
"github.com/dolthub/go-mysql-server/sql/types"
23+
"github.com/dolthub/go-mysql-server/sql/types"
2424
)
2525

2626
// Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values.
@@ -53,61 +53,73 @@ func (c *Coalesce) Description() string {
5353
// Type implements the sql.Expression interface.
5454
// The return type of Type() is the aggregated type of the argument types.
5555
func (c *Coalesce) Type() sql.Type {
56-
typ := types.Null
57-
for _, arg := range c.args {
56+
retType := types.Null
57+
for i, arg := range c.args {
5858
if arg == nil {
5959
continue
6060
}
61-
t := arg.Type()
61+
argType := arg.Type()
62+
if sysVarType, ok := argType.(sql.SystemVariableType); ok {
63+
argType = sysVarType.UnderlyingType()
64+
}
65+
if i == 0 {
66+
retType = argType
67+
continue
68+
}
69+
if argType == nil || argType == types.Null {
70+
continue
71+
}
72+
if retType.Equals(argType) {
73+
continue
74+
}
75+
6276
// special case for signed and unsigned integers
63-
if (types.IsSigned(typ) && types.IsUnsigned(t)) || (types.IsUnsigned(typ) && types.IsSigned(t)) {
64-
typ = types.MustCreateDecimalType(20, 0)
77+
if (types.IsSigned(retType) && types.IsUnsigned(argType)) || (types.IsUnsigned(retType) && types.IsSigned(argType)) {
78+
retType = types.MustCreateDecimalType(20, 0)
6579
continue
6680
}
6781

68-
if t != nil && t != types.Null {
69-
convType := expression.GetConvertToType(typ, t)
70-
switch convType {
71-
case expression.ConvertToChar:
72-
// special case for float64s
73-
if (t == types.Float64 || typ == types.Float64) && !types.IsText(t) && !types.IsText(typ) {
74-
typ = types.Float64
75-
continue
76-
}
77-
// Can't get any larger than this
78-
return types.LongText
79-
case expression.ConvertToDecimal:
80-
if typ == types.Float64 || t == types.Float64 {
81-
typ = types.Float64
82-
} else if types.IsDecimal(t) {
83-
typ = t
84-
} else if !types.IsDecimal(typ) {
85-
typ = types.MustCreateDecimalType(10, 0)
86-
}
87-
case expression.ConvertToUnsigned:
88-
if typ == types.Uint64 || t == types.Uint64 {
89-
typ = types.Uint64
90-
} else {
91-
typ = types.Uint32
92-
}
93-
case expression.ConvertToSigned:
94-
if typ == types.Int64 || t == types.Int64 {
95-
typ = types.Int64
96-
} else {
97-
typ = types.Int32
98-
}
99-
case expression.ConvertToFloat:
100-
if typ == types.Float64 || t == types.Float64 {
101-
typ = types.Float64
102-
} else {
103-
typ = types.Float32
104-
}
105-
default:
82+
convType := expression.GetConvertToType(retType, argType)
83+
switch convType {
84+
case expression.ConvertToChar:
85+
// special case for float64s
86+
if (argType == types.Float64 || retType == types.Float64) && !types.IsText(argType) && !types.IsText(retType) {
87+
retType = types.Float64
88+
continue
89+
}
90+
// Can't get any larger than this
91+
return types.LongText
92+
case expression.ConvertToDecimal:
93+
if retType == types.Float64 || argType == types.Float64 {
94+
retType = types.Float64
95+
} else if types.IsDecimal(argType) {
96+
retType = argType
97+
} else if !types.IsDecimal(retType) {
98+
retType = types.MustCreateDecimalType(10, 0)
99+
}
100+
case expression.ConvertToUnsigned:
101+
if retType == types.Uint64 || argType == types.Uint64 {
102+
retType = types.Uint64
103+
} else {
104+
retType = types.Uint32
105+
}
106+
case expression.ConvertToSigned:
107+
if retType == types.Int64 || argType == types.Int64 {
108+
retType = types.Int64
109+
} else {
110+
retType = types.Int32
111+
}
112+
case expression.ConvertToFloat:
113+
if retType == types.Float64 || argType == types.Float64 {
114+
retType = types.Float64
115+
} else {
116+
retType = types.Float32
106117
}
118+
default:
107119
}
108120
}
109121

110-
return typ
122+
return retType
111123
}
112124

113125
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/expression/function/coalesce_test.go

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
package function
1616

1717
import (
18-
"testing"
19-
2018
"github.com/shopspring/decimal"
21-
"github.com/stretchr/testify/require"
19+
"testing"
20+
21+
"github.com/stretchr/testify/require"
2222

2323
"github.com/dolthub/go-mysql-server/sql"
2424
"github.com/dolthub/go-mysql-server/sql/expression"
@@ -152,17 +152,86 @@ func TestCoalesce(t *testing.T) {
152152
typ: types.Float64,
153153
nullable: false,
154154
},
155+
{
156+
name: "coalesce(sysInt, sysInt)",
157+
input: []sql.Expression{
158+
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
159+
expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)),
160+
},
161+
expected: 1,
162+
typ: types.Int64,
163+
nullable: false,
164+
},
165+
{
166+
name: "coalesce(sysInt, sysUint)",
167+
input: []sql.Expression{
168+
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
169+
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
170+
},
171+
expected: 1,
172+
typ: types.MustCreateDecimalType(20, 0),
173+
nullable: false,
174+
},
175+
{
176+
name: "coalesce(sysUint, sysUint)",
177+
input: []sql.Expression{
178+
expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)),
179+
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
180+
},
181+
expected: 1,
182+
typ: types.Uint64,
183+
nullable: false,
184+
},
185+
{
186+
name: "coalesce(sysDouble, sysDouble)",
187+
input: []sql.Expression{
188+
expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)),
189+
expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)),
190+
},
191+
expected: 1.0,
192+
typ: types.Float64,
193+
nullable: false,
194+
},
195+
{
196+
name: "coalesce(sysText)",
197+
input: []sql.Expression{
198+
expression.NewLiteral("abc", types.NewSystemStringType("str1")),
199+
},
200+
expected: "abc",
201+
typ: types.LongText,
202+
nullable: false,
203+
},
204+
{
205+
name: "coalesce(sysEnum)",
206+
input: []sql.Expression{
207+
expression.NewLiteral("abc", types.NewSystemEnumType("str1")),
208+
},
209+
expected: "abc",
210+
typ: types.EnumType{},
211+
nullable: false,
212+
},
213+
{
214+
name: "coalesce(sysSet)",
215+
input: []sql.Expression{
216+
expression.NewLiteral("abc", types.NewSystemSetType("str1", "abc")),
217+
},
218+
expected: "abc",
219+
typ: types.MustCreateSetType([]string{"abc"}, sql.Collation_Default),
220+
nullable: false,
221+
},
155222
}
156223

157224
for _, tt := range testCases {
158-
c, err := NewCoalesce(tt.input...)
159-
require.NoError(t, err)
225+
t.Run(tt.name, func(t *testing.T) {
226+
c, err := NewCoalesce(tt.input...)
227+
require.NoError(t, err)
160228

161-
require.Equal(t, tt.typ, c.Type())
162-
require.Equal(t, tt.nullable, c.IsNullable())
163-
v, err := c.Eval(sql.NewEmptyContext(), nil)
164-
require.NoError(t, err)
165-
require.Equal(t, tt.expected, v)
229+
require.Equal(t, tt.typ, c.Type())
230+
require.Equal(t, tt.nullable, c.IsNullable())
231+
v, err := c.Eval(sql.NewEmptyContext(), nil)
232+
require.NoError(t, err)
233+
require.Equal(t, tt.expected, v)
234+
})
166235
}
167236
}
168237

sql/types/typecheck.go

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ func IsNull(ex sql.Expression) bool {
9191

9292
// IsNumber checks if t is a number type
9393
func IsNumber(t sql.Type) bool {
94-
switch t.(type) {
95-
case NumberTypeImpl_, DecimalType_, BitType_, YearType_, SystemBoolType:
94+
switch typ := t.(type) {
95+
case sql.SystemVariableType:
96+
return IsNumber(typ.UnderlyingType())
97+
case NumberTypeImpl_, DecimalType_, BitType_, YearType_:
9698
return true
9799
default:
98100
return false
@@ -101,23 +103,25 @@ func IsNumber(t sql.Type) bool {
101103

102104
// IsSigned checks if t is a signed type.
103105
func IsSigned(t sql.Type) bool {
104-
// systemBoolType is Int8
105-
if _, ok := t.(SystemBoolType); ok {
106-
return true
106+
if svt, ok := t.(sql.SystemVariableType); ok {
107+
t = svt.UnderlyingType()
107108
}
108109
return t == Int8 || t == Int16 || t == Int24 || t == Int32 || t == Int64 || t == Boolean
109110
}
110111

111112
// IsText checks if t is a CHAR, VARCHAR, TEXT, BINARY, VARBINARY, or BLOB (including TEXT and BLOB variants).
112113
func IsText(t sql.Type) bool {
113-
if _, ok := t.(StringType); ok {
114-
return ok
115-
}
116-
if extendedType, ok := t.(ExtendedType); ok {
117-
_, isString := extendedType.Zero().(string)
114+
switch typ := t.(type) {
115+
case sql.SystemVariableType:
116+
return IsText(typ.UnderlyingType())
117+
case StringType:
118+
return true
119+
case ExtendedType:
120+
_, isString := typ.Zero().(string)
118121
return isString
122+
default:
123+
return false
119124
}
120-
return false
121125
}
122126

123127
// IsTextBlob checks if t is one of the TEXTs or BLOBs.
@@ -178,14 +182,26 @@ func IsTimestampType(t sql.Type) bool {
178182

179183
// IsEnum checks if t is a enum
180184
func IsEnum(t sql.Type) bool {
181-
_, ok := t.(EnumType)
182-
return ok
185+
switch typ := t.(type) {
186+
case sql.SystemVariableType:
187+
return IsEnum(typ.UnderlyingType())
188+
case EnumType:
189+
return true
190+
default:
191+
return false
192+
}
183193
}
184194

185195
// IsSet checks if t is a set
186196
func IsSet(t sql.Type) bool {
187-
_, ok := t.(SetType)
188-
return ok
197+
switch typ := t.(type) {
198+
case sql.SystemVariableType:
199+
return IsSet(typ.UnderlyingType())
200+
case SetType:
201+
return true
202+
default:
203+
return false
204+
}
189205
}
190206

191207
// IsTuple checks if t is a tuple type.
@@ -201,7 +217,6 @@ func IsUnsigned(t sql.Type) bool {
201217
if svt, ok := t.(sql.SystemVariableType); ok {
202218
t = svt.UnderlyingType()
203219
}
204-
205220
return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64
206221
}
207222

0 commit comments

Comments
 (0)