Skip to content

Commit cb66eeb

Browse files
authored
Merge pull request #3025 from dolthub/angela/groupby
Convert if.Eval result to correct type
2 parents 9688af8 + 04150db commit cb66eeb

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

enginetest/queries/order_by_group_by_queries.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,17 @@ var OrderByGroupByScriptTests = []ScriptTest{
305305
},
306306
},
307307
},
308+
{
309+
Name: "Group by true and 1",
310+
SetUpScript: []string{
311+
"create table t0(c0 int)",
312+
"insert into t0(c0) values(1),(123)",
313+
},
314+
Assertions: []ScriptTestAssertion{
315+
{
316+
Query: "select if(t0.c0 = 123, TRUE, t0.c0) AS ref0, min(t0.c0) as ref1 from t0 group by ref0",
317+
Expected: []sql.Row{{1, 1}},
318+
},
319+
},
320+
},
308321
}

enginetest/queries/queries.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6123,7 +6123,7 @@ SELECT * FROM cte WHERE d = 2;`,
61236123
{
61246124
Query: `SELECT if(0, "abc", 456)`,
61256125
Expected: []sql.Row{
6126-
{456},
6126+
{"456"},
61276127
},
61286128
},
61296129
{
@@ -9768,7 +9768,7 @@ from typestable`,
97689768
{
97699769
Query: "select if('', 1, char(''));",
97709770
Expected: []sql.Row{
9771-
{[]byte{0}},
9771+
{"\x00"},
97729772
},
97739773
},
97749774
{

sql/expression/function/if.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,20 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
7777
}
7878
}
7979

80+
var eval interface{}
8081
if asBool {
81-
return f.ifTrue.Eval(ctx, row)
82+
eval, err = f.ifTrue.Eval(ctx, row)
83+
if err != nil {
84+
return nil, err
85+
}
8286
} else {
83-
return f.ifFalse.Eval(ctx, row)
87+
eval, err = f.ifFalse.Eval(ctx, row)
88+
if err != nil {
89+
return nil, err
90+
}
8491
}
92+
eval, _, err = f.Type().Convert(ctx, eval)
93+
return eval, err
8594
}
8695

8796
// Type implements the Expression interface.

sql/expression/function/if_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,22 @@ func TestIf(t *testing.T) {
2929
expr sql.Expression
3030
row sql.Row
3131
expected interface{}
32+
type1 sql.Type
33+
type2 sql.Type
3234
}{
33-
{eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "a"},
34-
{eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{"a", "b"}, "b"},
35-
{eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{1, 2}, 1},
36-
{eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{1, 2}, 2},
37-
{eq(lit(nil, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "b"},
38-
{eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{nil, "b"}, nil},
35+
{eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "a", types.Text, types.Text},
36+
{eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{"a", "b"}, "b", types.Text, types.Text},
37+
{eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{1, 2}, int64(1), types.Int64, types.Int64},
38+
{eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{1, 2}, int64(2), types.Int64, types.Int64},
39+
{eq(lit(nil, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "b", types.Text, types.Text},
40+
{eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{nil, "b"}, nil, nil, types.Text},
3941
}
4042

4143
for _, tc := range testCases {
4244
f := NewIf(
4345
tc.expr,
44-
expression.NewGetField(0, types.LongText, "true", true),
45-
expression.NewGetField(1, types.LongText, "false", true),
46+
expression.NewGetField(0, tc.type1, "true", true),
47+
expression.NewGetField(1, tc.type2, "false", true),
4648
)
4749

4850
v, err := f.Eval(sql.NewEmptyContext(), tc.row)

0 commit comments

Comments
 (0)