From 85a753d0b377e5f7e45e0c86afe385574a25b25f Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Wed, 11 Jun 2025 15:01:31 -0700 Subject: [PATCH 1/2] convert if eval result to correct type --- enginetest/queries/order_by_group_by_queries.go | 13 +++++++++++++ enginetest/queries/queries.go | 4 ++-- sql/expression/function/if.go | 13 +++++++++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/enginetest/queries/order_by_group_by_queries.go b/enginetest/queries/order_by_group_by_queries.go index 84de8445ea..579244cc7f 100644 --- a/enginetest/queries/order_by_group_by_queries.go +++ b/enginetest/queries/order_by_group_by_queries.go @@ -305,4 +305,17 @@ var OrderByGroupByScriptTests = []ScriptTest{ }, }, }, + { + Name: "Group by true and 1", + SetUpScript: []string{ + "create table t0(c0 int)", + "insert into t0(c0) values(1),(123)", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "select if(t0.c0 = 123, TRUE, t0.c0) AS ref0, min(t0.c0) as ref1 from t0 group by ref0", + Expected: []sql.Row{{1, 1}}, + }, + }, + }, } diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 9e350bbeee..46e7f54d43 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -6123,7 +6123,7 @@ SELECT * FROM cte WHERE d = 2;`, { Query: `SELECT if(0, "abc", 456)`, Expected: []sql.Row{ - {456}, + {"456"}, }, }, { @@ -9696,7 +9696,7 @@ from typestable`, { Query: "select if('', 1, char(''));", Expected: []sql.Row{ - {[]byte{0}}, + {"\x00"}, }, }, { diff --git a/sql/expression/function/if.go b/sql/expression/function/if.go index ebbe34a02b..55e24e5fdf 100644 --- a/sql/expression/function/if.go +++ b/sql/expression/function/if.go @@ -77,11 +77,20 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + var eval interface{} if asBool { - return f.ifTrue.Eval(ctx, row) + eval, err = f.ifTrue.Eval(ctx, row) + if err != nil { + return nil, err + } } else { - return f.ifFalse.Eval(ctx, row) + eval, err = f.ifFalse.Eval(ctx, row) + if err != nil { + return nil, err + } } + eval, _, err = f.Type().Convert(ctx, eval) + return eval, err } // Type implements the Expression interface. From 04150db56225ebd3d97afe2e1a1f44362c4a6369 Mon Sep 17 00:00:00 2001 From: Angela Xie Date: Wed, 11 Jun 2025 15:45:45 -0700 Subject: [PATCH 2/2] fix if test --- sql/expression/function/if_test.go | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sql/expression/function/if_test.go b/sql/expression/function/if_test.go index 2559f40438..946912ff46 100644 --- a/sql/expression/function/if_test.go +++ b/sql/expression/function/if_test.go @@ -29,20 +29,22 @@ func TestIf(t *testing.T) { expr sql.Expression row sql.Row expected interface{} + type1 sql.Type + type2 sql.Type }{ - {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "a"}, - {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{"a", "b"}, "b"}, - {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{1, 2}, 1}, - {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{1, 2}, 2}, - {eq(lit(nil, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "b"}, - {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{nil, "b"}, nil}, + {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "a", types.Text, types.Text}, + {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{"a", "b"}, "b", types.Text, types.Text}, + {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{1, 2}, int64(1), types.Int64, types.Int64}, + {eq(lit(1, types.Int64), lit(0, types.Int64)), sql.Row{1, 2}, int64(2), types.Int64, types.Int64}, + {eq(lit(nil, types.Int64), lit(1, types.Int64)), sql.Row{"a", "b"}, "b", types.Text, types.Text}, + {eq(lit(1, types.Int64), lit(1, types.Int64)), sql.Row{nil, "b"}, nil, nil, types.Text}, } for _, tc := range testCases { f := NewIf( tc.expr, - expression.NewGetField(0, types.LongText, "true", true), - expression.NewGetField(1, types.LongText, "false", true), + expression.NewGetField(0, tc.type1, "true", true), + expression.NewGetField(1, tc.type2, "false", true), ) v, err := f.Eval(sql.NewEmptyContext(), tc.row)