@@ -80,7 +80,7 @@ pub struct AggregateInfo {
8080 pub group_items_map : HashMap < String , usize > ,
8181
8282 /// Column binding of the virtual column `grouping_id`. It is `Some` only if `grouping_sets` is not empty.
83- pub grouping_id_index : IndexType ,
83+ pub grouping_id_column : Option < ColumnBinding > ,
8484 /// Each grouping set is a list of column indices in `group_items`.
8585 pub grouping_sets : Vec < Vec < IndexType > > ,
8686}
@@ -124,6 +124,9 @@ impl<'a> AggregateRewriter<'a> {
124124 }
125125 . into ( ) ) ,
126126 ScalarExpr :: FunctionCall ( func) => {
127+ if func. func_name . eq_ignore_ascii_case ( "grouping" ) {
128+ return self . replace_grouping ( func) ;
129+ }
127130 let new_args = func
128131 . arguments
129132 . iter ( )
@@ -225,6 +228,46 @@ impl<'a> AggregateRewriter<'a> {
225228
226229 Ok ( replaced_agg. into ( ) )
227230 }
231+
232+ fn replace_grouping ( & mut self , function : & FunctionCall ) -> Result < ScalarExpr > {
233+ let agg_info = & mut self . bind_context . aggregate_info ;
234+ if agg_info. grouping_id_column . is_none ( ) {
235+ return Err ( ErrorCode :: SemanticError (
236+ "grouping can only be called in GROUP BY GROUPING SETS clauses" ,
237+ ) ) ;
238+ }
239+ let grouping_id_column = agg_info. grouping_id_column . clone ( ) . unwrap ( ) ;
240+
241+ // Rewrite the args to params.
242+ // The params are the index offset in `grouping_id`.
243+ // Here is an example:
244+ // If the query is `select grouping(b, a) from group by grouping sets ((a, b), (a));`
245+ // The group-by items are: [a, b].
246+ // The group ids will be (a: 0, b: 1):
247+ // ba -> 00 -> 0
248+ // _a -> 01 -> 1
249+ // grouping(b, a) will be rewritten to grouping<1, 0>(grouping_id).
250+ let mut replaced_params = Vec :: with_capacity ( function. arguments . len ( ) ) ;
251+ for arg in & function. arguments {
252+ if let Some ( index) = agg_info. group_items_map . get ( & format ! ( "{:?}" , arg) ) {
253+ replaced_params. push ( * index) ;
254+ } else {
255+ return Err ( ErrorCode :: BadArguments (
256+ "Arguments of grouping should be group by expressions" ,
257+ ) ) ;
258+ }
259+ }
260+
261+ let replaced_func = FunctionCall {
262+ func_name : function. func_name . clone ( ) ,
263+ params : replaced_params,
264+ arguments : vec ! [ ScalarExpr :: BoundColumnRef ( BoundColumnRef {
265+ column: grouping_id_column,
266+ } ) ] ,
267+ } ;
268+
269+ Ok ( replaced_func. into ( ) )
270+ }
228271}
229272
230273impl Binder {
@@ -331,8 +374,12 @@ impl Binder {
331374 aggregate_functions : bind_context. aggregate_info . aggregate_functions . clone ( ) ,
332375 from_distinct : false ,
333376 limit : None ,
334- grouping_id_index : agg_info. grouping_id_index ,
335377 grouping_sets : agg_info. grouping_sets . clone ( ) ,
378+ grouping_id_index : agg_info
379+ . grouping_id_column
380+ . as_ref ( )
381+ . map ( |g| g. index )
382+ . unwrap_or ( 0 ) ,
336383 } ;
337384 new_expr = SExpr :: create_unary ( aggregate_plan. into ( ) , new_expr) ;
338385
@@ -358,15 +405,16 @@ impl Binder {
358405 )
359406 . await ?;
360407 }
408+ let agg_info = & mut bind_context. aggregate_info ;
361409 // `grouping_sets` stores formatted `ScalarExpr` for each grouping set.
362410 let grouping_sets = grouping_sets
363411 . into_iter ( )
364412 . map ( |set| {
365413 let mut set = set
366414 . into_iter ( )
367415 . map ( |s| {
368- let offset = * bind_context . aggregate_info . group_items_map . get ( & s) . unwrap ( ) ;
369- bind_context . aggregate_info . group_items [ offset] . index
416+ let offset = * agg_info . group_items_map . get ( & s) . unwrap ( ) ;
417+ agg_info . group_items [ offset] . index
370418 } )
371419 . collect :: < Vec < _ > > ( ) ;
372420 // Grouping sets with the same items should be treated as the same.
@@ -375,7 +423,7 @@ impl Binder {
375423 } )
376424 . collect :: < Vec < _ > > ( ) ;
377425 let grouping_sets = grouping_sets. into_iter ( ) . unique ( ) . collect ( ) ;
378- bind_context . aggregate_info . grouping_sets = grouping_sets;
426+ agg_info . grouping_sets = grouping_sets;
379427 // Add a virtual column `_grouping_id` to group items.
380428 let grouping_id_column = self . create_column_binding (
381429 None ,
@@ -384,8 +432,17 @@ impl Binder {
384432 DataType :: Number ( NumberDataType :: UInt32 ) ,
385433 ) ;
386434 let index = grouping_id_column. index ;
387- bind_context. aggregate_info . grouping_id_index = index;
388- bind_context. aggregate_info . group_items . push ( ScalarItem {
435+ agg_info. grouping_id_column = Some ( grouping_id_column. clone ( ) ) ;
436+ agg_info. group_items_map . insert (
437+ format ! (
438+ "{:?}" ,
439+ ScalarExpr :: BoundColumnRef ( BoundColumnRef {
440+ column: grouping_id_column. clone( )
441+ } )
442+ ) ,
443+ agg_info. group_items . len ( ) ,
444+ ) ;
445+ agg_info. group_items . push ( ScalarItem {
389446 index,
390447 scalar : ScalarExpr :: BoundColumnRef ( BoundColumnRef {
391448 column : grouping_id_column,
@@ -485,6 +542,11 @@ impl Binder {
485542 ) ;
486543 }
487544
545+ // If it's `GROUP BY GROUPING SETS`, ignore the optimization below.
546+ if collect_grouping_sets {
547+ return Ok ( ( ) ) ;
548+ }
549+
488550 // Remove dependent group items, group by a, f(a, b), f(a), b ---> group by a,b
489551 let mut results = vec ! [ ] ;
490552 for item in bind_context. aggregate_info . group_items . iter ( ) {
0 commit comments