Skip to content

Commit 23e9216

Browse files
authored
Merge pull request #10592 from RinChanNOWWW/grouping
feat: new function `grouping`.
2 parents 218bf9e + 65af315 commit 23e9216

File tree

18 files changed

+232
-62
lines changed

18 files changed

+232
-62
lines changed

src/query/functions/src/scalars/other.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,16 @@ use common_expression::types::DateType;
3535
use common_expression::types::GenericType;
3636
use common_expression::types::NullType;
3737
use common_expression::types::NullableType;
38+
use common_expression::types::NumberColumn;
39+
use common_expression::types::NumberDataType;
40+
use common_expression::types::NumberScalar;
3841
use common_expression::types::NumberType;
3942
use common_expression::types::SimpleDomain;
4043
use common_expression::types::StringType;
4144
use common_expression::types::TimestampType;
4245
use common_expression::types::ValueType;
4346
use common_expression::vectorize_with_builder_1_arg;
47+
use common_expression::Column;
4448
use common_expression::Domain;
4549
use common_expression::EvalContext;
4650
use common_expression::Function;
@@ -49,6 +53,7 @@ use common_expression::FunctionProperty;
4953
use common_expression::FunctionRegistry;
5054
use common_expression::FunctionSignature;
5155
use common_expression::Scalar;
56+
use common_expression::ScalarRef;
5257
use common_expression::Value;
5358
use common_expression::ValueRef;
5459
use ordered_float::OrderedFloat;
@@ -66,6 +71,7 @@ pub fn register(registry: &mut FunctionRegistry) {
6671
register_inet_aton(registry);
6772
register_inet_ntoa(registry);
6873
register_run_diff(registry);
74+
register_grouping(registry);
6975

7076
registry.register_passthrough_nullable_1_arg::<Float64Type, StringType, _, _>(
7177
"humanize_size",
@@ -343,3 +349,48 @@ fn register_run_diff(registry: &mut FunctionRegistry) {
343349
OrderedFloat(0.0)
344350
);
345351
}
352+
353+
fn register_grouping(registry: &mut FunctionRegistry) {
354+
registry.register_function_factory("grouping", |params, arg_type| {
355+
if arg_type.len() != 1 {
356+
return None;
357+
}
358+
359+
let params = params.to_vec();
360+
Some(Arc::new(Function {
361+
signature: FunctionSignature {
362+
name: "grouping".to_string(),
363+
args_type: vec![DataType::Number(NumberDataType::UInt32)],
364+
return_type: DataType::Number(NumberDataType::UInt32),
365+
property: FunctionProperty::default(),
366+
},
367+
calc_domain: Box::new(|_| FunctionDomain::Full),
368+
eval: Box::new(move |args, _| match &args[0] {
369+
ValueRef::Scalar(ScalarRef::Number(NumberScalar::UInt32(v))) => Value::Scalar(
370+
Scalar::Number(NumberScalar::UInt32(compute_grouping(&params, *v))),
371+
),
372+
ValueRef::Column(Column::Number(NumberColumn::UInt32(col))) => {
373+
let output = col
374+
.iter()
375+
.map(|v| compute_grouping(&params, *v))
376+
.collect::<Vec<_>>();
377+
Value::Column(Column::Number(NumberColumn::UInt32(output.into())))
378+
}
379+
_ => unreachable!(),
380+
}),
381+
}))
382+
})
383+
}
384+
385+
/// Compute the value of `grouping` from `grouping_id` and `cols`.
///
/// Each entry of `cols` is a bit position inside `grouping_id`. The selected
/// bits are packed into the result so that the LAST entry of `cols` ends up in
/// bit 0, the second to last in bit 1, and so on. The order of `cols` therefore
/// influences the result.
///
/// Bit positions that do not exist in a `u32` (>= 32) are read as `0`, and
/// result bits that would not fit in a `u32` are dropped. This avoids the shift
/// overflow (panic in debug builds, wrapped shift amount in release builds)
/// that an unchecked `<<` / `>>` would cause.
#[inline(always)]
pub fn compute_grouping(cols: &[usize], grouping_id: u32) -> u32 {
    const WIDTH: usize = u32::BITS as usize;

    cols.iter()
        .rev()
        .enumerate()
        // Only the first `WIDTH` output bits can be represented in the result.
        .take(WIDTH)
        // A source bit outside of `grouping_id` contributes nothing.
        .filter(|&(_, &src_bit)| src_bit < WIDTH)
        .fold(0u32, |grouping, (out_bit, &src_bit)| {
            grouping | (((grouping_id >> src_bit) & 1) << out_bit)
        })
}

src/query/functions/tests/it/scalars/testdata/function_list.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,6 +1610,7 @@ Functions overloads:
16101610
1 great_circle_angle(Float64 NULL, Float64 NULL, Float64 NULL, Float64 NULL) :: Float32 NULL
16111611
0 great_circle_distance(Float64, Float64, Float64, Float64) :: Float32
16121612
1 great_circle_distance(Float64 NULL, Float64 NULL, Float64 NULL, Float64 NULL) :: Float32 NULL
1613+
0 grouping FACTORY
16131614
0 gt(Variant, Variant) :: Boolean
16141615
1 gt(Variant NULL, Variant NULL) :: Boolean NULL
16151616
2 gt(String, String) :: Boolean

src/query/service/src/pipelines/pipeline_builder.rs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -434,19 +434,11 @@ impl PipelineBuilder {
434434
let group_bys = expand
435435
.group_bys
436436
.iter()
437-
.filter_map(|i| {
438-
// Do not collect virtual column "_grouping_id".
439-
if *i != expand.grouping_id_index {
440-
match input_schema.index_of(&i.to_string()) {
441-
Ok(index) => {
442-
let ty = input_schema.field(index).data_type().clone();
443-
Some(Ok((index, ty)))
444-
}
445-
Err(e) => Some(Err(e)),
446-
}
447-
} else {
448-
None
449-
}
437+
.take(expand.group_bys.len() - 1) // The last group-by will be virtual column `_grouping_id`
438+
.map(|i| {
439+
let index = input_schema.index_of(&i.to_string())?;
440+
let ty = input_schema.field(index).data_type();
441+
Ok((index, ty.clone()))
450442
})
451443
.collect::<Result<Vec<_>>>()?;
452444
let grouping_sets = expand
@@ -463,6 +455,7 @@ impl PipelineBuilder {
463455
})
464456
.collect::<Result<Vec<_>>>()?;
465457
let mut grouping_ids = Vec::with_capacity(grouping_sets.len());
458+
let mask = (1 << group_bys.len()) - 1;
466459
for set in grouping_sets {
467460
let mut id = 0;
468461
for i in set {
@@ -474,7 +467,7 @@ impl PipelineBuilder {
474467
// group_bys: [a, b]
475468
// grouping_sets: [[0, 1], [0], [1], []]
476469
// grouping_ids: 00, 01, 10, 11
477-
grouping_ids.push(!id);
470+
grouping_ids.push(!id & mask);
478471
}
479472

480473
self.main_pipeline.add_transform(|input, output| {

src/query/sql/src/executor/physical_plan.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,24 +174,29 @@ pub struct AggregateExpand {
174174
pub plan_id: u32,
175175

176176
pub input: Box<PhysicalPlan>,
177-
pub group_bys: Vec<usize>,
177+
pub group_bys: Vec<IndexType>,
178178
pub grouping_id_index: IndexType,
179-
pub grouping_sets: Vec<Vec<usize>>,
179+
pub grouping_sets: Vec<Vec<IndexType>>,
180180
/// Only used for explain
181181
pub stat_info: Option<PlanStatsInfo>,
182182
}
183183

184184
impl AggregateExpand {
185185
pub fn output_schema(&self) -> Result<DataSchemaRef> {
186186
let input_schema = self.input.output_schema()?;
187-
let input_fields = input_schema.fields();
188-
let mut output_fields = Vec::with_capacity(input_fields.len() + 1);
189-
for field in input_fields {
190-
output_fields.push(DataField::new(
191-
field.name(),
192-
field.data_type().wrap_nullable(),
193-
));
187+
let mut output_fields = input_schema.fields().clone();
188+
189+
for group_by in self
190+
.group_bys
191+
.iter()
192+
.filter(|&index| *index != self.grouping_id_index)
193+
{
194+
// All group by columns will wrap nullable.
195+
let i = input_schema.index_of(&group_by.to_string())?;
196+
let f = &mut output_fields[i];
197+
*f = DataField::new(f.name(), f.data_type().wrap_nullable())
194198
}
199+
195200
output_fields.push(DataField::new(
196201
&self.grouping_id_index.to_string(),
197202
DataType::Number(NumberDataType::UInt32),

src/query/sql/src/executor/physical_plan_builder.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,7 @@ impl PhysicalPlanBuilder {
497497
output_column: v.index,
498498
args: agg.args.iter().map(|arg| {
499499
if let ScalarExpr::BoundColumnRef(col) = arg {
500-
let col_index = input_schema.index_of(&col.column.index.to_string())?;
501-
Ok(col_index)
500+
input_schema.index_of(&col.column.index.to_string())
502501
} else {
503502
Err(ErrorCode::Internal(
504503
"Aggregate function argument must be a BoundColumnRef".to_string()

src/query/sql/src/planner/binder/aggregate.rs

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub struct AggregateInfo {
8080
pub group_items_map: HashMap<String, usize>,
8181

8282
/// Index for virtual column `grouping_id`. It's valid only if `grouping_sets` is not empty.
83-
pub grouping_id_index: IndexType,
83+
pub grouping_id_column: Option<ColumnBinding>,
8484
/// Each grouping set is a list of column indices in `group_items`.
8585
pub grouping_sets: Vec<Vec<IndexType>>,
8686
}
@@ -124,6 +124,9 @@ impl<'a> AggregateRewriter<'a> {
124124
}
125125
.into()),
126126
ScalarExpr::FunctionCall(func) => {
127+
if func.func_name.eq_ignore_ascii_case("grouping") {
128+
return self.replace_grouping(func);
129+
}
127130
let new_args = func
128131
.arguments
129132
.iter()
@@ -225,6 +228,46 @@ impl<'a> AggregateRewriter<'a> {
225228

226229
Ok(replaced_agg.into())
227230
}
231+
232+
fn replace_grouping(&mut self, function: &FunctionCall) -> Result<ScalarExpr> {
233+
let agg_info = &mut self.bind_context.aggregate_info;
234+
if agg_info.grouping_id_column.is_none() {
235+
return Err(ErrorCode::SemanticError(
236+
"grouping can only be called in GROUP BY GROUPING SETS clauses",
237+
));
238+
}
239+
let grouping_id_column = agg_info.grouping_id_column.clone().unwrap();
240+
241+
// Rewrite the args to params.
242+
// The params are the index offset in `grouping_id`.
243+
// Here is an example:
244+
// If the query is `select grouping(b, a) from group by grouping sets ((a, b), (a));`
245+
// The group-by items are: [a, b].
246+
// The group ids will be (a: 0, b: 1):
247+
// ba -> 00 -> 0
248+
// _a -> 01 -> 1
249+
// grouping(b, a) will be rewritten to grouping<1, 0>(grouping_id).
250+
let mut replaced_params = Vec::with_capacity(function.arguments.len());
251+
for arg in &function.arguments {
252+
if let Some(index) = agg_info.group_items_map.get(&format!("{:?}", arg)) {
253+
replaced_params.push(*index);
254+
} else {
255+
return Err(ErrorCode::BadArguments(
256+
"Arguments of grouping should be group by expressions",
257+
));
258+
}
259+
}
260+
261+
let replaced_func = FunctionCall {
262+
func_name: function.func_name.clone(),
263+
params: replaced_params,
264+
arguments: vec![ScalarExpr::BoundColumnRef(BoundColumnRef {
265+
column: grouping_id_column,
266+
})],
267+
};
268+
269+
Ok(replaced_func.into())
270+
}
228271
}
229272

230273
impl Binder {
@@ -331,8 +374,12 @@ impl Binder {
331374
aggregate_functions: bind_context.aggregate_info.aggregate_functions.clone(),
332375
from_distinct: false,
333376
limit: None,
334-
grouping_id_index: agg_info.grouping_id_index,
335377
grouping_sets: agg_info.grouping_sets.clone(),
378+
grouping_id_index: agg_info
379+
.grouping_id_column
380+
.as_ref()
381+
.map(|g| g.index)
382+
.unwrap_or(0),
336383
};
337384
new_expr = SExpr::create_unary(aggregate_plan.into(), new_expr);
338385

@@ -358,15 +405,16 @@ impl Binder {
358405
)
359406
.await?;
360407
}
408+
let agg_info = &mut bind_context.aggregate_info;
361409
// `grouping_sets` stores formatted `ScalarExpr` for each grouping set.
362410
let grouping_sets = grouping_sets
363411
.into_iter()
364412
.map(|set| {
365413
let mut set = set
366414
.into_iter()
367415
.map(|s| {
368-
let offset = *bind_context.aggregate_info.group_items_map.get(&s).unwrap();
369-
bind_context.aggregate_info.group_items[offset].index
416+
let offset = *agg_info.group_items_map.get(&s).unwrap();
417+
agg_info.group_items[offset].index
370418
})
371419
.collect::<Vec<_>>();
372420
// Grouping sets with the same items should be treated as the same.
@@ -375,7 +423,7 @@ impl Binder {
375423
})
376424
.collect::<Vec<_>>();
377425
let grouping_sets = grouping_sets.into_iter().unique().collect();
378-
bind_context.aggregate_info.grouping_sets = grouping_sets;
426+
agg_info.grouping_sets = grouping_sets;
379427
// Add a virtual column `_grouping_id` to group items.
380428
let grouping_id_column = self.create_column_binding(
381429
None,
@@ -384,8 +432,17 @@ impl Binder {
384432
DataType::Number(NumberDataType::UInt32),
385433
);
386434
let index = grouping_id_column.index;
387-
bind_context.aggregate_info.grouping_id_index = index;
388-
bind_context.aggregate_info.group_items.push(ScalarItem {
435+
agg_info.grouping_id_column = Some(grouping_id_column.clone());
436+
agg_info.group_items_map.insert(
437+
format!(
438+
"{:?}",
439+
ScalarExpr::BoundColumnRef(BoundColumnRef {
440+
column: grouping_id_column.clone()
441+
})
442+
),
443+
agg_info.group_items.len(),
444+
);
445+
agg_info.group_items.push(ScalarItem {
389446
index,
390447
scalar: ScalarExpr::BoundColumnRef(BoundColumnRef {
391448
column: grouping_id_column,
@@ -485,6 +542,11 @@ impl Binder {
485542
);
486543
}
487544

545+
// If it's `GROUP BY GROUPING SETS`, ignore the optimization below.
546+
if collect_grouping_sets {
547+
return Ok(());
548+
}
549+
488550
// Remove dependent group items, group by a, f(a, b), f(a), b ---> group by a,b
489551
let mut results = vec![];
490552
for item in bind_context.aggregate_info.group_items.iter() {

src/query/sql/src/planner/binder/select.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ impl Binder {
101101
.normalize_select_list(&mut from_context, &stmt.select_list)
102102
.await?;
103103

104-
let (mut scalar_items, projections) = self.analyze_projection(&select_list)?;
105-
106104
// This will potentially add some alias group items to `from_context` if find some.
107105
if let Some(group_by) = stmt.group_by.as_ref() {
108106
self.analyze_group_items(&mut from_context, &select_list, group_by)
@@ -111,6 +109,9 @@ impl Binder {
111109

112110
self.analyze_aggregate_select(&mut from_context, &mut select_list)?;
113111

112+
// `analyze_projection` must run after `analyze_aggregate_select`, because `analyze_aggregate_select` rewrites `grouping`.
113+
let (mut scalar_items, projections) = self.analyze_projection(&select_list)?;
114+
114115
let having = if let Some(having) = &stmt.having {
115116
Some(
116117
self.analyze_aggregate_having(&mut from_context, &select_list, having)

src/query/sql/src/planner/semantic/grouping_check.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ impl<'a> GroupingChecker<'a> {
4949
.get(&format!("{:?}", scalar))
5050
{
5151
let column = &self.bind_context.aggregate_info.group_items[*index];
52-
let column_binding = if let ScalarExpr::BoundColumnRef(column_ref) = &column.scalar {
52+
let mut column_binding = if let ScalarExpr::BoundColumnRef(column_ref) = &column.scalar
53+
{
5354
column_ref.column.clone()
5455
} else {
5556
ColumnBinding {
@@ -61,6 +62,13 @@ impl<'a> GroupingChecker<'a> {
6162
visibility: Visibility::Visible,
6263
}
6364
};
65+
66+
if let Some(grouping_id) = &self.bind_context.aggregate_info.grouping_id_column {
67+
if grouping_id.index != column_binding.index {
68+
column_binding.data_type = Box::new(column_binding.data_type.wrap_nullable());
69+
}
70+
}
71+
6472
return Ok(BoundColumnRef {
6573
column: column_binding,
6674
}

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,18 @@ impl<'a> TypeChecker<'a> {
991991
Self::rewrite_substring(&mut args);
992992
}
993993

994+
if func_name == "grouping" {
995+
// `grouping` will be rewritten again after resolving grouping sets.
996+
return Ok(Box::new((
997+
ScalarExpr::FunctionCall(FunctionCall {
998+
params: vec![],
999+
arguments: args,
1000+
func_name: "grouping".to_string(),
1001+
}),
1002+
DataType::Number(NumberDataType::UInt32),
1003+
)));
1004+
}
1005+
9941006
// rewrite_collation
9951007
let func_name = if self.function_need_collation(func_name, &args)?
9961008
&& self.ctx.get_settings().get_collation()? == "utf8"

0 commit comments

Comments
 (0)