Skip to content

Commit 710ac92

Browse files
authored
[expression] coalesce shoudl cast return values (#2853)
1 parent debff7a commit 710ac92

File tree

3 files changed

+47
-16
lines changed

3 files changed

+47
-16
lines changed

enginetest/queries/queries.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,26 @@ var QueryTests = []QueryTest{
842842
Query: "SELECT 1 WHERE ((1 IN (NULL * 1)) IS NULL);",
843843
Expected: []sql.Row{{1}},
844844
},
845+
{
846+
Query: "select coalesce(1, 0.0);",
847+
Expected: []sql.Row{{"1"}},
848+
},
849+
{
850+
Query: "select coalesce(1, '0');",
851+
Expected: []sql.Row{{"1"}},
852+
},
853+
{
854+
Query: "select coalesce(1, 'x');",
855+
Expected: []sql.Row{{"1"}},
856+
},
857+
{
858+
Query: "select coalesce(1, 1);",
859+
Expected: []sql.Row{{1}},
860+
},
861+
{
862+
Query: "select coalesce(1, CAST('2017-08-29' AS DATE))",
863+
Expected: []sql.Row{{"1"}},
864+
},
845865
{
846866
Query: "SELECT count(*) from mytable WHERE ((i IN (NULL >= 1)) IS NULL);",
847867
Expected: []sql.Row{{3}},

sql/expression/function/coalesce.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
// Coalesce returns the first non-NULL value in the list, or NULL if there are no non-NULL values.
2828
type Coalesce struct {
2929
args []sql.Expression
30+
typ sql.Type
3031
}
3132

3233
var _ sql.FunctionExpression = (*Coalesce)(nil)
@@ -38,7 +39,7 @@ func NewCoalesce(args ...sql.Expression) (sql.Expression, error) {
3839
return nil, sql.ErrInvalidArgumentNumber.New("COALESCE", "1 or more", 0)
3940
}
4041

41-
return &Coalesce{args}, nil
42+
return &Coalesce{args: args}, nil
4243
}
4344

4445
// FunctionName implements sql.FunctionExpression
@@ -54,6 +55,9 @@ func (c *Coalesce) Description() string {
5455
// Type implements the sql.Expression interface.
5556
// The return type of Type() is the aggregated type of the argument types.
5657
func (c *Coalesce) Type() sql.Type {
58+
if c.typ != nil {
59+
return c.typ
60+
}
5761
retType := types.Null
5862
for i, arg := range c.args {
5963
if arg == nil {
@@ -120,6 +124,7 @@ func (c *Coalesce) Type() sql.Type {
120124
}
121125
}
122126

127+
c.typ = retType
123128
return retType
124129
}
125130

@@ -201,6 +206,12 @@ func (c *Coalesce) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
201206
continue
202207
}
203208

209+
if !types.IsEnum(c.Type()) && !types.IsSet(c.Type()) {
210+
val, _, err = c.Type().Convert(val)
211+
if err != nil {
212+
return nil, err
213+
}
214+
}
204215
return val, nil
205216
}
206217

sql/expression/function/coalesce_test.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func TestCoalesce(t *testing.T) {
4545
expression.NewLiteral(2, types.Int32),
4646
expression.NewLiteral(3, types.Int32),
4747
},
48-
expected: 1,
48+
expected: int32(1),
4949
typ: types.Int32,
5050
nullable: false,
5151
},
@@ -56,7 +56,7 @@ func TestCoalesce(t *testing.T) {
5656
nil,
5757
expression.NewLiteral(3, types.Int32),
5858
},
59-
expected: 3,
59+
expected: int32(3),
6060
typ: types.Int32,
6161
nullable: false,
6262
},
@@ -100,7 +100,7 @@ func TestCoalesce(t *testing.T) {
100100
expression.NewLiteral(decimal.NewFromFloat(2.0), types.MustCreateDecimalType(10, 0)),
101101
expression.NewLiteral("3", types.LongText),
102102
},
103-
expected: 1,
103+
expected: "1",
104104
typ: types.LongText,
105105
nullable: false,
106106
},
@@ -110,7 +110,7 @@ func TestCoalesce(t *testing.T) {
110110
expression.NewLiteral(1, types.Int32),
111111
expression.NewLiteral(2, types.Uint32),
112112
},
113-
expected: 1,
113+
expected: decimal.New(1, 0),
114114
typ: types.MustCreateDecimalType(20, 0),
115115
nullable: false,
116116
},
@@ -120,7 +120,7 @@ func TestCoalesce(t *testing.T) {
120120
expression.NewLiteral(1, types.Int32),
121121
expression.NewLiteral(2, types.Uint32),
122122
},
123-
expected: 1,
123+
expected: decimal.New(1, 0),
124124
typ: types.MustCreateDecimalType(20, 0),
125125
nullable: false,
126126
},
@@ -130,7 +130,7 @@ func TestCoalesce(t *testing.T) {
130130
expression.NewLiteral(1, types.MustCreateDecimalType(10, 0)),
131131
expression.NewLiteral(2, types.Float64),
132132
},
133-
expected: 1,
133+
expected: float64(1),
134134
typ: types.Float64,
135135
nullable: false,
136136
},
@@ -139,7 +139,7 @@ func TestCoalesce(t *testing.T) {
139139
input: []sql.Expression{
140140
expression.NewLiteral(2, types.Float64),
141141
},
142-
expected: 2,
142+
expected: float64(2),
143143
typ: types.Float64,
144144
nullable: false,
145145
},
@@ -148,7 +148,7 @@ func TestCoalesce(t *testing.T) {
148148
input: []sql.Expression{
149149
expression.NewLiteral(1, types.Float64),
150150
},
151-
expected: 1,
151+
expected: float64(1),
152152
typ: types.Float64,
153153
nullable: false,
154154
},
@@ -158,7 +158,7 @@ func TestCoalesce(t *testing.T) {
158158
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
159159
expression.NewLiteral(2, types.NewSystemIntType("int2", 0, 10, false)),
160160
},
161-
expected: 1,
161+
expected: int64(1),
162162
typ: types.Int64,
163163
nullable: false,
164164
},
@@ -168,7 +168,7 @@ func TestCoalesce(t *testing.T) {
168168
expression.NewLiteral(1, types.NewSystemIntType("int1", 0, 10, false)),
169169
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
170170
},
171-
expected: 1,
171+
expected: decimal.New(1, 0),
172172
typ: types.MustCreateDecimalType(20, 0),
173173
nullable: false,
174174
},
@@ -178,7 +178,7 @@ func TestCoalesce(t *testing.T) {
178178
expression.NewLiteral(1, types.NewSystemUintType("int1", 0, 10)),
179179
expression.NewLiteral(2, types.NewSystemUintType("int2", 0, 10)),
180180
},
181-
expected: 1,
181+
expected: uint64(1),
182182
typ: types.Uint64,
183183
nullable: false,
184184
},
@@ -188,7 +188,7 @@ func TestCoalesce(t *testing.T) {
188188
expression.NewLiteral(1.0, types.NewSystemDoubleType("dbl1", 0.0, 10.0)),
189189
expression.NewLiteral(2.0, types.NewSystemDoubleType("dbl2", 0.0, 10.0)),
190190
},
191-
expected: 1.0,
191+
expected: float64(1),
192192
typ: types.Float64,
193193
nullable: false,
194194
},
@@ -249,19 +249,19 @@ func TestComposeCoalasce(t *testing.T) {
249249
require.Equal(t, types.Int32, c2.Type())
250250
v, err = c2.Eval(ctx, nil)
251251
require.NoError(t, err)
252-
require.Equal(t, 1, v)
252+
require.Equal(t, int32(1), v)
253253

254254
c3, err := NewCoalesce(nil, c1, c2)
255255
require.NoError(t, err)
256256
require.Equal(t, types.Int32, c3.Type())
257257
v, err = c3.Eval(ctx, nil)
258258
require.NoError(t, err)
259-
require.Equal(t, 1, v)
259+
require.Equal(t, int32(1), v)
260260

261261
c4, err := NewCoalesce(expression.NewLiteral(nil, types.Null), c1, c2)
262262
require.NoError(t, err)
263263
require.Equal(t, types.Int32, c4.Type())
264264
v, err = c4.Eval(ctx, nil)
265265
require.NoError(t, err)
266-
require.Equal(t, 1, v)
266+
require.Equal(t, int32(1), v)
267267
}

0 commit comments

Comments
 (0)