Skip to content

Commit a02c999

Browse files
authored
chore(query): fix grouping function in type checker (#18502)
* chore(query): fix grouping function in type checker * chore(query): fix grouping function in type checker * chore(query): fix grouping function in type checker
1 parent 7a90a83 commit a02c999

File tree

5 files changed

+42
-28
lines changed

5 files changed

+42
-28
lines changed

src/query/expression/src/type_check.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ pub fn check_function<Index: ColumnIndex>(
316316

317317
// Do not check grouping
318318
if name == "grouping" {
319-
debug_assert!(candidates.len() == 1);
319+
debug_assert!(!candidates.is_empty());
320320
let (id, function) = candidates.into_iter().next().unwrap();
321321
let return_type = function.signature.return_type.clone();
322322
return Ok(Expr::FunctionCall(FunctionCall {

src/query/functions/src/scalars/other.rs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,29 @@ fn register_grouping(registry: &mut FunctionRegistry) {
391391
},
392392
}))
393393
}));
394-
registry.register_function_factory("grouping", grouping)
394+
registry.register_function_factory("grouping", grouping);
395+
396+
// dummy grouping
397+
// used in type_check before AggregateRewriter
398+
let dummy_grouping = FunctionFactory::Closure(Box::new(|_, arg_type: &[DataType]| {
399+
Some(Arc::new(Function {
400+
signature: FunctionSignature {
401+
name: "grouping".to_string(),
402+
args_type: vec![DataType::Generic(0); arg_type.len()],
403+
return_type: DataType::Number(NumberDataType::UInt32),
404+
},
405+
eval: FunctionEval::Scalar {
406+
calc_domain: Box::new(|_, _| FunctionDomain::Full),
407+
eval: Box::new(move |args, _| {
408+
unreachable!(
409+
"grouping function must be rewritten in type_checker, but got: {:?}",
410+
args
411+
)
412+
}),
413+
},
414+
}))
415+
}));
416+
registry.register_function_factory("grouping", dummy_grouping);
395417
}
396418

397419
fn register_num_to_char(registry: &mut FunctionRegistry) {

src/query/functions/tests/it/scalars/testdata/function_list.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1998,6 +1998,7 @@ Functions overloads:
19981998
0 great_circle_distance(Float64, Float64, Float64, Float64) :: Float32
19991999
1 great_circle_distance(Float64 NULL, Float64 NULL, Float64 NULL, Float64 NULL) :: Float32 NULL
20002000
0 grouping FACTORY
2001+
1 grouping FACTORY
20012002
0 gt(Variant, Variant) :: Boolean
20022003
1 gt(Variant NULL, Variant NULL) :: Boolean NULL
20032004
2 gt(String, String) :: Boolean

src/query/sql/src/planner/semantic/type_check.rs

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2906,30 +2906,6 @@ impl<'a> TypeChecker<'a> {
29062906
arg_types.push(arg_type);
29072907
}
29082908

2909-
// rewrite substr('xx', 0, xx) -> substr('xx', 1, xx)
2910-
if (func_name == "substr" || func_name == "substring")
2911-
&& self
2912-
.ctx
2913-
.get_settings()
2914-
.get_sql_dialect()
2915-
.unwrap()
2916-
.substr_index_zero_literal_as_one()
2917-
{
2918-
Self::rewrite_substring(&mut args);
2919-
}
2920-
2921-
if func_name == "grouping" {
2922-
// `grouping` will be rewritten again after resolving grouping sets.
2923-
return Ok(Box::new((
2924-
ScalarExpr::FunctionCall(FunctionCall {
2925-
span,
2926-
params: vec![],
2927-
arguments: args,
2928-
func_name: "grouping".to_string(),
2929-
}),
2930-
DataType::Number(NumberDataType::UInt32),
2931-
)));
2932-
}
29332909
if let Some(rewritten_variant_expr) =
29342910
self.try_rewrite_variant_function(span, func_name, &args, &arg_types)
29352911
{
@@ -2949,11 +2925,22 @@ impl<'a> TypeChecker<'a> {
29492925
span: Span,
29502926
func_name: &str,
29512927
mut params: Vec<Scalar>,
2952-
args: Vec<ScalarExpr>,
2928+
mut args: Vec<ScalarExpr>,
29532929
) -> Result<Box<(ScalarExpr, DataType)>> {
2930+
// rewrite substr('xx', 0, xx) -> substr('xx', 1, xx)
2931+
if (func_name == "substr" || func_name == "substring")
2932+
&& self
2933+
.ctx
2934+
.get_settings()
2935+
.get_sql_dialect()
2936+
.unwrap()
2937+
.substr_index_zero_literal_as_one()
2938+
{
2939+
Self::rewrite_substring(&mut args);
2940+
}
2941+
29542942
// Type check
29552943
let mut arguments = args.iter().map(|v| v.as_raw_expr()).collect::<Vec<_>>();
2956-
29572944
// inject the params
29582945
if ["round", "truncate"].contains(&func_name)
29592946
&& !args.is_empty()

tests/sqllogictests/suites/duckdb/sql/aggregate/group/group_by_grouping_sets.test

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ statement ok
1717
select sum(number), number % 3 a, grouping(number % 3)+grouping(number % 4) AS lochierarchy from numbers(10)
1818
group by rollup(number % 3, number % 4) order by grouping(number % 3)+grouping(number % 4) ;
1919

20+
2021
query TT
2122
select number % 2 as a, number % 3 as b from numbers(24) group by grouping sets ((a,b), (a), (b)) order by a,b;
2223
----
@@ -91,6 +92,9 @@ NULL B 10 0 1 2 1
9192
b NULL 11 1 0 1 2
9293
NULL NULL 18 1 1 3 3
9394

95+
statement ok
96+
select a, b, sum(c) as sc, grouping(a,b) + 3, grouping(b,a) > 2 from t group by grouping sets ((a,b),(),(b),(a)) order by sc;
97+
9498
query TTIIIII
9599
select a, b, sum(c) as sc, grouping(b), grouping(a), grouping(a,b), grouping(b,a) from t group by grouping sets ((1,2),(),(2),(1)) order by sc;
96100
----

0 commit comments

Comments
 (0)