@@ -12,6 +12,46 @@ use std::fmt::Display;
1212use std:: sync:: Arc ;
1313
1414use super :: Column ;
15+ use crate :: datatypes:: scalar:: ScalarValue ;
16+
17+ /// Format an expression for *naming* purposes, stripping out CAST(...) (and nested alias) wrappers.
18+ ///
19+ /// Type coercion may insert casts without changing the logical meaning of an expression; we don't
20+ /// want those casts to affect output field names, otherwise downstream column lookups can break
21+ /// (e.g. SUM(a*b) vs SUM(a*CAST(b AS ...))).
22+ fn fmt_expr_for_name ( expr : & LogicalExpr ) -> String {
23+ match expr {
24+ LogicalExpr :: Cast ( c) => fmt_expr_for_name ( & c. expr ) ,
25+ LogicalExpr :: Alias ( a) => fmt_expr_for_name ( & a. expr ) ,
26+ LogicalExpr :: Column ( c) => c. to_string ( ) ,
27+ LogicalExpr :: Literal ( v) => v. to_string ( ) ,
28+ LogicalExpr :: Negative ( e) => format ! ( "- {}" , fmt_expr_for_name( e) ) ,
29+ LogicalExpr :: BinaryExpr ( b) => format ! (
30+ "{} {} {}" ,
31+ fmt_expr_for_name( & b. left) ,
32+ b. op,
33+ fmt_expr_for_name( & b. right)
34+ ) ,
35+ LogicalExpr :: Case ( case) => {
36+ let mut s = String :: from ( "CASE" ) ;
37+ if let Some ( op) = & case. operand {
38+ s. push ( ' ' ) ;
39+ s. push_str ( & fmt_expr_for_name ( op) ) ;
40+ }
41+ for ( w, t) in & case. when_then {
42+ s. push_str ( " WHEN " ) ;
43+ s. push_str ( & fmt_expr_for_name ( w) ) ;
44+ s. push_str ( " THEN " ) ;
45+ s. push_str ( & fmt_expr_for_name ( t) ) ;
46+ }
47+ s. push_str ( " ELSE " ) ;
48+ s. push_str ( & fmt_expr_for_name ( & case. else_expr ) ) ;
49+ s. push_str ( " END" ) ;
50+ s
51+ }
52+ other => other. to_string ( ) ,
53+ }
54+ }
1555
1656#[ derive( Debug , Clone , PartialEq , Eq , Hash ) ]
1757pub enum AggregateOperator {
@@ -96,10 +136,16 @@ pub struct AggregateExpr {
96136impl AggregateExpr {
97137 pub fn field ( & self , plan : & LogicalPlan ) -> Result < FieldRef > {
98138 self . expr . field ( plan) . and_then ( |field| {
99- let col_name = if let LogicalExpr :: Column ( inner) = self . expr . as_ref ( ) {
100- & inner. qualified_name ( )
101- } else {
102- field. name ( )
139+ // Use the *expression string* for non-column arguments, otherwise we may generate
140+ // names like COUNT(i32) from Arrow field names which won't match expression display.
141+ //
142+ // Special case: COUNT(*) is rewritten to COUNT(1) by `CountWildcardRule`, but the
143+ // output column name must remain COUNT(*) for SQL compatibility / tests.
144+ let col_name = match ( self . op . clone ( ) , self . expr . as_ref ( ) ) {
145+ ( AggregateOperator :: Count , LogicalExpr :: Literal ( ScalarValue :: Int32 ( Some ( 1 ) ) ) )
146+ | ( AggregateOperator :: Count , LogicalExpr :: Literal ( ScalarValue :: Int64 ( Some ( 1 ) ) ) ) => "*" . to_string ( ) ,
147+ ( _, LogicalExpr :: Column ( inner) ) => inner. qualified_name ( ) ,
148+ ( _, other) => fmt_expr_for_name ( other) ,
103149 } ;
104150
105151 Ok ( Arc :: new ( Field :: new (
@@ -111,13 +157,30 @@ impl AggregateExpr {
111157 }
112158
113159 pub ( crate ) fn as_column ( & self ) -> Result < LogicalExpr > {
114- self . expr . as_column ( ) . map ( |inner_col| {
115- LogicalExpr :: Column ( Column {
116- name : format ! ( "{}({})" , self . op, inner_col) ,
117- relation : None ,
118- is_outer_ref : false ,
119- } )
120- } )
160+ // Keep COUNT(*) naming stable even if it was rewritten to COUNT(1) internally.
161+ if self . op == AggregateOperator :: Count {
162+ if matches ! (
163+ self . expr. as_ref( ) ,
164+ LogicalExpr :: Literal ( ScalarValue :: Int32 ( Some ( 1 ) ) ) | LogicalExpr :: Literal ( ScalarValue :: Int64 ( Some ( 1 ) ) )
165+ ) {
166+ return Ok ( LogicalExpr :: Column ( Column {
167+ name : "COUNT(*)" . to_string ( ) ,
168+ relation : None ,
169+ is_outer_ref : false ,
170+ } ) ) ;
171+ }
172+ }
173+
174+ let arg_name = match self . expr . as_ref ( ) {
175+ LogicalExpr :: Column ( c) => c. to_string ( ) ,
176+ other => fmt_expr_for_name ( other) ,
177+ } ;
178+
179+ Ok ( LogicalExpr :: Column ( Column {
180+ name : format ! ( "{}({})" , self . op, arg_name) ,
181+ relation : None ,
182+ is_outer_ref : false ,
183+ } ) )
121184 }
122185}
123186
@@ -126,3 +189,117 @@ impl Display for AggregateExpr {
126189 write ! ( f, "{}({})" , self . op, self . expr)
127190 }
128191}
192+
193+ #[ cfg( test) ]
194+ mod tests {
195+ use super :: * ;
196+ use crate :: datatypes:: operator:: Operator ;
197+ use crate :: logical:: expr:: { BinaryExpr , CaseExpr , CastExpr , Column , LogicalExpr } ;
198+ use crate :: logical:: plan:: { EmptyRelation , LogicalPlan } ;
199+ use arrow:: datatypes:: { DataType , Field , Schema } ;
200+ use std:: sync:: Arc ;
201+
202+ fn empty_plan_with_schema ( fields : Vec < Field > ) -> LogicalPlan {
203+ LogicalPlan :: EmptyRelation ( EmptyRelation {
204+ produce_one_row : true ,
205+ schema : Arc :: new ( Schema :: new ( fields) ) ,
206+ } )
207+ }
208+
209+ #[ test]
210+ fn count_star_keeps_output_name_after_rewrite_to_count_1 ( ) {
211+ // Optimizer rule rewrites COUNT(*) -> COUNT(1) for execution.
212+ // However, the output column name must remain COUNT(*) to match SQL surface semantics
213+ // and sqllogictest expectations.
214+ let plan = empty_plan_with_schema ( vec ! [ ] ) ;
215+ let agg = AggregateExpr {
216+ op : AggregateOperator :: Count ,
217+ expr : Box :: new ( LogicalExpr :: Literal ( ScalarValue :: Int32 ( Some ( 1 ) ) ) ) ,
218+ } ;
219+
220+ let field = agg. field ( & plan) . unwrap ( ) ;
221+ assert_eq ! ( field. name( ) , "COUNT(*)" ) ;
222+
223+ let col = agg. as_column ( ) . unwrap ( ) ;
224+ assert_eq ! ( col. to_string( ) , "COUNT(*)" ) ;
225+ }
226+
227+ #[ test]
228+ fn aggregate_naming_ignores_casts_in_argument_expression ( ) {
229+ // TypeCoercion may insert CASTs (e.g. to make DECIMAL * INT valid),
230+ // but we don't want those CASTs to affect the aggregate output column name.
231+ let plan = empty_plan_with_schema ( vec ! [
232+ Field :: new( "a" , DataType :: Decimal128 ( 15 , 2 ) , false ) ,
233+ Field :: new( "b" , DataType :: Int64 , false ) ,
234+ ] ) ;
235+
236+ let expr = LogicalExpr :: BinaryExpr ( BinaryExpr :: new (
237+ LogicalExpr :: Column ( Column :: new (
238+ "a" ,
239+ None :: < crate :: common:: table_relation:: TableRelation > ,
240+ false ,
241+ ) ) ,
242+ Operator :: Mul ,
243+ LogicalExpr :: Cast ( CastExpr :: new (
244+ LogicalExpr :: Column ( Column :: new (
245+ "b" ,
246+ None :: < crate :: common:: table_relation:: TableRelation > ,
247+ false ,
248+ ) ) ,
249+ DataType :: Decimal128 ( 20 , 0 ) ,
250+ ) ) ,
251+ ) ) ;
252+
253+ let agg = AggregateExpr {
254+ op : AggregateOperator :: Sum ,
255+ expr : Box :: new ( expr) ,
256+ } ;
257+
258+ let field = agg. field ( & plan) . unwrap ( ) ;
259+ // cast is ignored for naming: b (not CAST(b AS ...))
260+ assert_eq ! ( field. name( ) , "SUM(a * b)" ) ;
261+ assert_eq ! ( agg. as_column( ) . unwrap( ) . to_string( ) , "SUM(a * b)" ) ;
262+ }
263+
264+ #[ test]
265+ fn aggregate_naming_ignores_casts_inside_case_expression ( ) {
266+ // Similar to TPCH Q8: CASE branch literals may get casted by type coercion,
267+ // but the aggregate output name should stay stable.
268+ let plan = empty_plan_with_schema ( vec ! [
269+ Field :: new( "cond" , DataType :: Boolean , false ) ,
270+ Field :: new( "v" , DataType :: Decimal128 ( 38 , 4 ) , false ) ,
271+ ] ) ;
272+
273+ let case = CaseExpr {
274+ operand : None ,
275+ when_then : vec ! [ (
276+ LogicalExpr :: Column ( Column :: new(
277+ "cond" ,
278+ None :: <crate :: common:: table_relation:: TableRelation >,
279+ false ,
280+ ) ) ,
281+ LogicalExpr :: Column ( Column :: new(
282+ "v" ,
283+ None :: <crate :: common:: table_relation:: TableRelation >,
284+ false ,
285+ ) ) ,
286+ ) ] ,
287+ else_expr : Box :: new ( LogicalExpr :: Cast ( CastExpr :: new (
288+ LogicalExpr :: Literal ( ScalarValue :: Int64 ( Some ( 0 ) ) ) ,
289+ DataType :: Decimal128 ( 38 , 4 ) ,
290+ ) ) ) ,
291+ } ;
292+
293+ let agg = AggregateExpr {
294+ op : AggregateOperator :: Sum ,
295+ expr : Box :: new ( LogicalExpr :: Case ( case) ) ,
296+ } ;
297+
298+ let field = agg. field ( & plan) . unwrap ( ) ;
299+ assert_eq ! ( field. name( ) , "SUM(CASE WHEN cond THEN v ELSE Int64(0) END)" ) ;
300+ assert_eq ! (
301+ agg. as_column( ) . unwrap( ) . to_string( ) ,
302+ "SUM(CASE WHEN cond THEN v ELSE Int64(0) END)"
303+ ) ;
304+ }
305+ }
0 commit comments