Skip to content

Commit c6cd9da

Browse files
committed
added tests
1 parent 603321d commit c6cd9da

File tree

5 files changed

+69
-21
lines changed

5 files changed

+69
-21
lines changed

enginetest/queries/script_queries.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8712,6 +8712,25 @@ where
87128712
},
87138713
},
87148714
},
8715+
{
8716+
Name: "tinyint column does not restrict IF or IFNULL output",
8717+
// https://github.com/dolthub/dolt/issues/9321
8718+
SetUpScript: []string{
8719+
"create table t0 (c0 tinyint);",
8720+
"insert into t0 values (null);",
8721+
},
8722+
Assertions: []ScriptTestAssertion{
8723+
{
8724+
Query: "select ifnull(t0.c0, 128) as ref0 from t0",
8725+
Expected: []sql.Row{
8726+
{128},
8727+
},
8728+
},
8729+
{
8730+
Query: "select if(t0.c0 = 1, t0.c0, 128) as ref0 from t0",
8731+
},
8732+
},
8733+
},
87158734
}
87168735

87178736
var SpatialScriptTests = []ScriptTest{

sql/expression/function/ifnull.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,16 @@ func (f *IfNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
5757
return nil, err
5858
}
5959
if left != nil {
60-
return left, nil
60+
left, _, err = f.Type().Convert(ctx, left)
61+
return left, err
6162
}
6263

6364
right, err := f.RightChild.Eval(ctx, row)
6465
if err != nil {
6566
return nil, err
6667
}
67-
return right, nil
68+
right, _, err = f.Type().Convert(ctx, right)
69+
return right, err
6870
}
6971

7072
// Type implements the Expression interface.

sql/expression/function/ifnull_test.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,28 @@ import (
2626

2727
func TestIfNull(t *testing.T) {
2828
testCases := []struct {
29-
expression interface{}
30-
value interface{}
31-
expected interface{}
29+
expression interface{}
30+
expressionType sql.Type
31+
value interface{}
32+
valueType sql.Type
33+
expected interface{}
34+
expectedType sql.Type
3235
}{
33-
{"foo", "bar", "foo"},
34-
{"foo", "foo", "foo"},
35-
{nil, "foo", "foo"},
36-
{"foo", nil, "foo"},
37-
{nil, nil, nil},
38-
{"", nil, ""},
36+
{"foo", types.LongText, "bar", types.LongText, "foo", types.LongText},
37+
{"foo", types.LongText, "foo", types.LongText, "foo", types.LongText},
38+
{nil, types.LongText, "foo", types.LongText, "foo", types.LongText},
39+
{"foo", types.LongText, nil, types.LongText, "foo", types.LongText},
40+
{nil, types.LongText, nil, types.LongText, nil, types.LongText},
41+
{"", types.LongText, nil, types.LongText, "", types.LongText},
42+
{nil, types.Int8, 128, types.Int64, int64(128), types.Int64},
3943
}
4044

41-
f := NewIfNull(
42-
expression.NewGetField(0, types.LongText, "expression", true),
43-
expression.NewGetField(1, types.LongText, "value", true),
44-
)
45-
require.Equal(t, types.LongText, f.Type())
46-
4745
for _, tc := range testCases {
46+
f := NewIfNull(
47+
expression.NewGetField(0, tc.expressionType, "expression", true),
48+
expression.NewGetField(1, tc.valueType, "value", true),
49+
)
50+
require.Equal(t, tc.expectedType, f.Type())
4851
v, err := f.Eval(sql.NewEmptyContext(), sql.NewRow(tc.expression, tc.value))
4952
require.NoError(t, err)
5053
require.Equal(t, tc.expected, v)

sql/types/conversion.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,11 +558,13 @@ func TypesEqual(a, b sql.Type) bool {
558558
// GeneralizeTypes returns the more "general" of two types as defined by
559559
// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html#function_if and
560560
// https://dev.mysql.com/doc/refman/8.4/en/flow-control-functions.html#function_ifnull
561-
// TODO: Currently returns the most general type. Update to match MySQL (pick the more general of the two given types)
561+
// TODO: Currently returns the most general type via Promote(). Update to match MySQL (pick the more general of the two
562+
//
563+
// given types)
562564
func GeneralizeTypes(a, b sql.Type) sql.Type {
563565
if IsText(a) || IsText(b) {
564566
// TODO: handle case-sensitive strings
565-
return Text
567+
return LongText
566568
}
567569

568570
if IsFloat(a) || IsFloat(b) {

sql/types/conversion_test.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func TestColumnTypeToType_Time(t *testing.T) {
119119
}
120120

121121
func TestColumnCharTypes(t *testing.T) {
122-
test := []struct {
122+
tests := []struct {
123123
typ string
124124
len int64
125125
exp sql.Type
@@ -146,7 +146,7 @@ func TestColumnCharTypes(t *testing.T) {
146146
},
147147
}
148148

149-
for _, test := range test {
149+
for _, test := range tests {
150150
t.Run(fmt.Sprintf("%v %v", test.typ, test.exp), func(t *testing.T) {
151151
ct := &sqlparser.ColumnType{
152152
Type: test.typ,
@@ -158,3 +158,25 @@ func TestColumnCharTypes(t *testing.T) {
158158
})
159159
}
160160
}
161+
162+
func TestGeneralizeTypes(t *testing.T) {
163+
tests := []struct {
164+
typeA sql.Type
165+
typeB sql.Type
166+
expected sql.Type
167+
}{
168+
{Text, Text, LongText},
169+
{Text, Float64, LongText},
170+
{Int64, Text, LongText},
171+
{Float32, Float32, Float64},
172+
{Int64, Float64, Float64},
173+
{Int32, Int32, Int64},
174+
{Null, Null, Null},
175+
}
176+
for _, test := range tests {
177+
t.Run(fmt.Sprintf("%v %v %v", test.typeA, test.typeB, test.expected), func(t *testing.T) {
178+
res := GeneralizeTypes(test.typeA, test.typeB)
179+
assert.Equal(t, test.expected, res)
180+
})
181+
}
182+
}

0 commit comments

Comments
 (0)