diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 287a133273d8..8e3bcb255151 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -589,33 +589,16 @@ impl DataFrame { group_expr: Vec, aggr_expr: Vec, ) -> Result { - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); - let aggr_expr_len = aggr_expr.len(); let options = LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); let plan = LogicalPlanBuilder::from(self.plan) .with_options(options) .aggregate(group_expr, aggr_expr)? .build()?; - let plan = if is_grouping_set { - let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len; - // For grouping sets we do a project to not expose the internal grouping id - let exprs = plan - .schema() - .columns() - .into_iter() - .enumerate() - .filter(|(idx, _)| *idx != grouping_id_pos) - .map(|(_, column)| Expr::Column(column)) - .collect::>(); - LogicalPlanBuilder::from(plan).project(exprs)?.build()? - } else { - plan - }; Ok(DataFrame { session_state: self.session_state, plan, - projection_requires_validation: !is_grouping_set, + projection_requires_validation: true, }) } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6b4d2592f608..7cce5c7d171b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -77,13 +77,16 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; +use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; -use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; -use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::aggregate::{ + AggregateExpr, AggregateExprBuilder, GroupingExpr, +}; +use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ create_physical_sort_exprs, LexOrdering, PhysicalSortExpr, }; @@ -722,6 +725,12 @@ impl DefaultPhysicalPlanner { session_state, )?; + let group_by_expr = if groups.is_single() { + None + } else { + Some(group_expr_to_bitmap_index(group_expr)?) + }; + let agg_filter = aggr_expr .iter() .map(|e| { @@ -730,12 +739,32 @@ impl DefaultPhysicalPlanner { logical_input_schema, &physical_input_schema, session_state.execution_props(), + group_by_expr.as_ref(), ) }) .collect::>>()?; - let (mut aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = - multiunzip(agg_filter); + let no_grouping_agg = agg_filter + .iter() + .filter_map(|(e, filters, order_bys)| { + if matches!(e, AggregateExpr::AggregateFunctionExpr(_)) { + Some((e.clone(), filters.clone(), order_bys.clone())) + } else { + None + } + }) + .collect::>(); + + let (aggregates, filters, _order_bys): (Vec<_>, Vec<_>, Vec<_>) = + multiunzip(no_grouping_agg); + + let mut aggregates = aggregates + .into_iter() + .map(|e| match e { + AggregateExpr::AggregateFunctionExpr(e) => e, + _ => unreachable!(), + }) + .collect::>(); let mut async_exprs = Vec::new(); let num_input_columns = physical_input_schema.fields().len(); @@ -813,22 +842,72 @@ impl DefaultPhysicalPlanner { let final_grouping_set = initial_aggr.group_expr().as_final(); - Arc::new(AggregateExec::try_new( + let final_agg = Arc::new(AggregateExec::try_new( next_partition_mode, final_grouping_set, updated_aggregates, filters, initial_aggr, Arc::clone(&physical_input_schema), - )?) + )?); + + if groups.is_single() + && !agg_filter + .iter() + .any(|(e, _, _)| matches!(e, AggregateExpr::GroupingExpr(_))) + { + final_agg + } else { + // Need to project out __grouping_id column and compute GROUPING expressions + let mut proj_exprs = Vec::new(); + let num_group_exprs = groups.expr().len(); + + let schema = final_agg.schema(); + + // Add group columns + for i in 0..num_group_exprs { + let field = schema.field(i); + proj_exprs.push(ProjectionExpr { + expr: Arc::new(Column::new(field.name(), i)), + alias: field.name().to_string(), + }); + } + + // Skip __grouping_id at position num_group_exprs + // Add aggregate expressions (either computed GROUPING or column references) + let mut agg_col_idx = num_group_exprs + 1; // Start after __grouping_id + + for (agg_expr, _, _) in &agg_filter { + match agg_expr { + AggregateExpr::GroupingExpr(grouping_expr) => { + // Use the GroupingExpr directly as a physical expression + proj_exprs.push(ProjectionExpr { + expr: Arc::clone(grouping_expr) + as Arc, + alias: agg_expr.name().to_string(), + }); + } + AggregateExpr::AggregateFunctionExpr(_) => { + // Reference the aggregate function column + let field = schema.field(agg_col_idx); + proj_exprs.push(ProjectionExpr { + expr: Arc::new(Column::new( + field.name(), + agg_col_idx, + )), + alias: field.name().to_string(), + }); + agg_col_idx += 1; + } + } + } + Arc::new(ProjectionExec::try_new(proj_exprs, final_agg)?) + } + } + LogicalPlan::Projection(Projection { input, expr, .. }) => { + let child = children.one()?; + self.create_project_physical_exec(session_state, child, input, expr)? } - LogicalPlan::Projection(Projection { input, expr, .. }) => self - .create_project_physical_exec( - session_state, - children.one()?, - input, - expr, - )?, LogicalPlan::Filter(Filter { predicate, input, .. }) => { @@ -1749,23 +1828,68 @@ pub fn create_window_expr( } type AggregateExprWithOptionalArgs = ( - Arc, + AggregateExpr, // The filter clause, if any Option>, // Expressions in the ORDER BY clause Vec, ); +/// Create a map from grouping expr to index in the internal grouping id. +/// +/// For more details on how the grouping id bitmap works the documentation for +/// [[datafusion_physical_expr::aggregate::INTERNAL_GROUPING_ID]] +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { + Ok(grouping_set_to_exprlist(group_expr)? + .into_iter() + .rev() + .enumerate() + .map(|(idx, v)| (v, idx)) + .collect::>()) +} + /// Create an aggregate expression with a name from a logical expression pub fn create_aggregate_expr_with_name_and_maybe_filter( e: &Expr, name: Option, - human_displan: String, + human_display: String, logical_input_schema: &DFSchema, physical_input_schema: &Schema, execution_props: &ExecutionProps, + group_by_expr: Option<&HashMap<&Expr, usize>>, ) -> Result { + let name = if let Some(name) = name { + name + } else { + physical_name(e)? + }; match e { + Expr::AggregateFunction(AggregateFunction { func, params }) + if func.name() == "grouping" => + { + match group_by_expr { + Some(group_by_expr) => { + let indices = params + .args + .iter() + .map(|expr| match group_by_expr.get(expr) { + Some(idx) => Ok(*idx as i32), + None => plan_err!( + "Grouping function argument {} not in grouping columns", + expr + ), + }) + .collect::>>()?; + let grouping_expr = + GroupingExpr::new(name, human_display, Some(indices)); + Ok((Arc::new(grouping_expr).into(), None, Vec::new())) + } + None => { + let grouping_expr = GroupingExpr::new(name, human_display, None); + Ok((Arc::new(grouping_expr).into(), None, Vec::new())) + } + } + } Expr::AggregateFunction(AggregateFunction { func, params: @@ -1777,12 +1901,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( null_treatment, }, }) => { - let name = if let Some(name) = name { - name - } else { - physical_name(e)? - }; - let physical_args = create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { @@ -1809,7 +1927,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( .order_by(order_bys.clone()) .schema(Arc::new(physical_input_schema.to_owned())) .alias(name) - .human_display(human_displan) + .human_display(human_display) .with_ignore_nulls(ignore_nulls) .with_distinct(*distinct) .build() @@ -1818,7 +1936,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, order_bys) }; - Ok((agg_expr, filter, order_bys)) + Ok(( + AggregateExpr::AggregateFunctionExpr(agg_expr), + filter, + order_bys, + )) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } @@ -1830,6 +1952,7 @@ pub fn create_aggregate_expr_and_maybe_filter( logical_input_schema: &DFSchema, physical_input_schema: &Schema, execution_props: &ExecutionProps, + group_by_expr: Option<&HashMap<&Expr, usize>>, ) -> Result { // Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total" // Some functions like `count_all()` create internal aliases, @@ -1854,6 +1977,7 @@ pub fn create_aggregate_expr_and_maybe_filter( logical_input_schema, physical_input_schema, execution_props, + group_by_expr, ) } @@ -2796,13 +2920,20 @@ mod tests { .build()?; let execution_plan = plan(&logical_plan).await?; - let final_hash_agg = execution_plan + let projection = execution_plan + .as_any() + .downcast_ref::() + .expect("projection"); + + let final_hash_agg = projection + .input() .as_any() .downcast_ref::() .expect("hash aggregate"); + assert_eq!( "sum(aggregate_test_100.c3)", - final_hash_agg.schema().field(3).name() + projection.schema().field(2).name() ); // we need access to the input to the partial aggregate so that other projects can // implement serde diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3cc032277405..5b70a73ce20a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -21,7 +21,7 @@ use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; use super::dml::CopyTo; use super::invariants::{ @@ -3420,22 +3420,11 @@ impl Aggregate { let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?; let mut qualified_fields = exprlist_to_fields(grouping_expr, &input)?; - - // Even columns that cannot be null will become nullable when used in a grouping set. if is_grouping_set { qualified_fields = qualified_fields .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); - qualified_fields.push(( - None, - Field::new( - Self::INTERNAL_GROUPING_ID, - Self::grouping_id_type(qualified_fields.len()), - false, - ) - .into(), - )); } qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?); @@ -3476,6 +3465,25 @@ impl Aggregate { ); } + let group_expr_set = grouping_set_to_exprlist(&group_expr)? + .into_iter() + .collect::>(); + // Validate GROUPING function arguments are all columns and in group_expr + for expr in &aggr_expr { + if let Expr::AggregateFunction(agg_func) = expr { + if agg_func.func.name() == "grouping" { + for arg in &agg_func.params.args { + if !group_expr_set.contains(arg) { + return plan_err!( + "GROUPING function argument {} not in grouping columns", + arg + ); + } + } + } + } + } + let aggregate_func_dependencies = calc_func_dependencies_for_aggregate(&group_expr, &input, &schema)?; let new_schema = schema.as_ref().clone(); @@ -3490,19 +3498,9 @@ impl Aggregate { }) } - fn is_grouping_set(&self) -> bool { - matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) - } - /// Get the output expressions. fn output_expressions(&self) -> Result> { - static INTERNAL_ID_EXPR: LazyLock = LazyLock::new(|| { - Expr::Column(Column::from_name(Aggregate::INTERNAL_GROUPING_ID)) - }); let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; - if self.is_grouping_set() { - exprs.push(&INTERNAL_ID_EXPR); - } exprs.extend(self.aggr_expr.iter()); debug_assert!(exprs.len() == self.schema.fields().len()); Ok(exprs) @@ -3530,25 +3528,6 @@ impl Aggregate { DataType::UInt64 } } - - /// Internal column used when the aggregation is a grouping set. - /// - /// This column contains a bitmask where each bit represents a grouping - /// expression. The least significant bit corresponds to the rightmost - /// grouping expression. A bit value of 0 indicates that the corresponding - /// column is included in the grouping set, while a value of 1 means it is excluded. - /// - /// For example, for the grouping expressions CUBE(a, b), the grouping ID - /// column will have the following values: - /// 0b00: Both `a` and `b` are included - /// 0b01: `b` is excluded - /// 0b10: `a` is excluded - /// 0b11: Both `a` and `b` are excluded - /// - /// This internal column is necessary because excluded columns are replaced - /// with `NULL` values. To handle these cases correctly, we must distinguish - /// between an actual `NULL` value in a column and a column being excluded from the set. - pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; } // Manual implementation needed because of `schema` field. Comparison excludes this field. diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b91db4527b3a..d614f6270573 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -60,7 +60,7 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { ); } // Groupings sets have an additional integral column for the grouping id - Ok(grouping_set.distinct_expr().len() + 1) + Ok(grouping_set.distinct_expr().len()) } else { grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len()) } diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs index 4d1da1dad594..c970e501d9d3 100644 --- a/datafusion/functions-aggregate/src/grouping.rs +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -35,25 +35,25 @@ make_udaf_expr_and_func!( Grouping, grouping, expression, - "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", grouping_udaf ); #[user_doc( doc_section(label = "General Functions"), - description = "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + description = "Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn).", syntax_example = "grouping(expression)", sql_example = r#"```sql > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ```"#, argument( name = "expression", diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 272692f98368..a6cdbbe480ce 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -28,14 +28,12 @@ use datafusion_common::Result; use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{InvariantLevel, LogicalPlan}; -use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use self::function_rewrite::ApplyFunctionRewrites; pub mod function_rewrite; -pub mod resolve_grouping_function; pub mod type_coercion; /// [`AnalyzerRule`]s transform [`LogicalPlan`]s in some way to make @@ -85,10 +83,8 @@ impl Default for Analyzer { impl Analyzer { /// Create a new analyzer using the recommended list of rules pub fn new() -> Self { - let rules: Vec> = vec![ - Arc::new(ResolveGroupingFunction::new()), - Arc::new(TypeCoercion::new()), - ]; + let rules: Vec> = + vec![Arc::new(TypeCoercion::new())]; Self::with_rules(rules) } diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs deleted file mode 100644 index fa7ff1b8b19d..000000000000 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ /dev/null @@ -1,248 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Analyzed rule to replace TableScan references -//! such as DataFrames and Views and inlines the LogicalPlan. - -use std::cmp::Ordering; -use std::collections::HashMap; -use std::sync::Arc; - -use crate::analyzer::AnalyzerRule; - -use arrow::datatypes::DataType; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{ - internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, -}; -use datafusion_expr::expr::{AggregateFunction, Alias}; -use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::utils::grouping_set_to_exprlist; -use datafusion_expr::{ - bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, - Expr, Projection, -}; -use itertools::Itertools; - -/// Replaces grouping aggregation function with value derived from internal grouping id -#[derive(Default, Debug)] -pub struct ResolveGroupingFunction; - -impl ResolveGroupingFunction { - pub fn new() -> Self { - Self {} - } -} - -impl AnalyzerRule for ResolveGroupingFunction { - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(analyze_internal).data() - } - - fn name(&self) -> &str { - "resolve_grouping_function" - } -} - -/// Create a map from grouping expr to index in the internal grouping id. -/// -/// For more details on how the grouping id bitmap works the documentation for -/// [[Aggregate::INTERNAL_GROUPING_ID]] -fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { - Ok(grouping_set_to_exprlist(group_expr)? - .into_iter() - .rev() - .enumerate() - .map(|(idx, v)| (v, idx)) - .collect::>()) -} - -fn replace_grouping_exprs( - input: Arc, - schema: DFSchemaRef, - group_expr: Vec, - aggr_expr: Vec, -) -> Result { - // Create HashMap from Expr to index in the grouping_id bitmap - let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); - let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; - let columns = schema.columns(); - let mut new_agg_expr = Vec::new(); - let mut projection_exprs = Vec::new(); - let grouping_id_len = if is_grouping_set { 1 } else { 0 }; - let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; - projection_exprs.extend( - columns - .iter() - .take(group_expr_len) - .map(|column| Expr::Column(column.clone())), - ); - for (expr, column) in aggr_expr - .into_iter() - .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) - { - match expr { - Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { - let grouping_expr = grouping_function_on_id( - function, - &group_expr_to_bitmap_index, - is_grouping_set, - )?; - projection_exprs.push(Expr::Alias(Alias::new( - grouping_expr, - column.relation, - column.name, - ))); - } - _ => { - projection_exprs.push(Expr::Column(column)); - new_agg_expr.push(expr); - } - } - } - // Recreate aggregate without grouping functions - let new_aggregate = - LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); - // Create projection with grouping functions calculations - let projection = LogicalPlan::Projection(Projection::try_new( - projection_exprs, - new_aggregate.into(), - )?); - Ok(projection) -} - -fn analyze_internal(plan: LogicalPlan) -> Result> { - // rewrite any subqueries in the plan first - let transformed_plan = - plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; - - let transformed_plan = transformed_plan.transform_data(|plan| match plan { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - .. - }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( - replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, - )), - _ => Ok(Transformed::no(plan)), - })?; - - Ok(transformed_plan) -} - -fn is_grouping_function(expr: &Expr) -> bool { - // TODO: Do something better than name here should grouping be a built - // in expression? - matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") -} - -fn contains_grouping_function(exprs: &[Expr]) -> bool { - exprs.iter().any(is_grouping_function) -} - -/// Validate that the arguments to the grouping function are in the group by clause. -fn validate_args( - function: &AggregateFunction, - group_by_expr: &HashMap<&Expr, usize>, -) -> Result<()> { - let expr_not_in_group_by = function - .params - .args - .iter() - .find(|expr| !group_by_expr.contains_key(expr)); - if let Some(expr) = expr_not_in_group_by { - plan_err!( - "Argument {} to grouping function is not in grouping columns {}", - expr, - group_by_expr.keys().map(|e| e.to_string()).join(", ") - ) - } else { - Ok(()) - } -} - -fn grouping_function_on_id( - function: &AggregateFunction, - group_by_expr: &HashMap<&Expr, usize>, - is_grouping_set: bool, -) -> Result { - validate_args(function, group_by_expr)?; - let args = &function.params.args; - - // Postgres allows grouping function for group by without grouping sets, the result is then - // always 0 - if !is_grouping_set { - return Ok(Expr::Literal(ScalarValue::from(0i32), None)); - } - - let group_by_expr_count = group_by_expr.len(); - let literal = |value: usize| { - if group_by_expr_count < 8 { - Expr::Literal(ScalarValue::from(value as u8), None) - } else if group_by_expr_count < 16 { - Expr::Literal(ScalarValue::from(value as u16), None) - } else if group_by_expr_count < 32 { - Expr::Literal(ScalarValue::from(value as u32), None) - } else { - Expr::Literal(ScalarValue::from(value as u64), None) - } - }; - - let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); - // The grouping call is exactly our internal grouping id - if args.len() == group_by_expr_count - && args - .iter() - .rev() - .enumerate() - .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) - { - return Ok(cast(grouping_id_column, DataType::Int32)); - } - - args.iter() - .rev() - .enumerate() - .map(|(arg_idx, expr)| { - group_by_expr.get(expr).map(|group_by_idx| { - let group_by_bit = - bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); - match group_by_idx.cmp(&arg_idx) { - Ordering::Less => { - bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) - } - Ordering::Greater => { - bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) - } - Ordering::Equal => group_by_bit, - } - }) - }) - .collect::>>() - .and_then(|bit_exprs| { - bit_exprs - .into_iter() - .reduce(bitwise_or) - .map(|expr| cast(expr, DataType::Int32)) - }) - .ok_or_else(|| { - internal_datafusion_err!("Grouping sets should contains at least one element") - }) -} diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index e9a23c7c4dc5..b5d34a7747e6 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -371,7 +371,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64] TableScan: test [a:UInt32, b:UInt32, c:UInt32] " ) @@ -392,7 +392,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64] TableScan: test [a:UInt32, b:UInt32, c:UInt32] " ) @@ -414,7 +414,7 @@ mod tests { assert_optimized_plan_equal!( plan, @r" - Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64] + Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64] TableScan: test [a:UInt32, b:UInt32, c:UInt32] " ) diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs index 19d2ecc924dd..c2b66e71b4aa 100644 --- a/datafusion/physical-expr/src/aggregate.rs +++ b/datafusion/physical-expr/src/aggregate.rs @@ -34,15 +34,22 @@ pub mod utils { }; } +use std::any::Any; use std::fmt::Debug; +use std::ops::Shr; use std::sync::Arc; use crate::expressions::Column; -use arrow::compute::SortOptions; -use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef}; +use arrow::array::{ArrowPrimitiveType, Int32Array, PrimitiveArray, RecordBatch}; +use arrow::compute::{unary, SortOptions}; +use arrow::datatypes::{ + ArrowNativeType, DataType, FieldRef, Int32Type, Schema, SchemaRef, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; -use datafusion_expr::{AggregateUDF, ReversedUDAF, SetMonotonicity}; +use datafusion_expr::{AggregateUDF, ColumnarValue, ReversedUDAF, SetMonotonicity}; use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::GroupsAccumulator; use datafusion_expr_common::type_coercion::aggregates::check_arg_count; @@ -744,3 +751,375 @@ fn replace_order_by_clause(order_by: &mut String) { fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) { *aggr_name = aggr_name.replace(fn_name_old, fn_name_new); } + +/// Represents a GROUPING physical expression +/// +/// The GROUPING function returns a bitmask indicating which columns are aggregated +/// in a GROUP BY GROUPING SETS, ROLLUP, or CUBE query. +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct GroupingExpr { + name: String, + /// A human readable name + human_display: String, + /// The grouping id value + indices: Option>, +} + +impl GroupingExpr { + /// Create a new GROUPING physical expression + pub fn new(name: String, human_display: String, indices: Option>) -> Self { + Self { + name, + human_display, + indices, + } + } + + /// Get the indices + pub fn indices(&self) -> &Option> { + &self.indices + } + + fn grouping( + &self, + grouping_id_col: &PrimitiveArray, + indices: &[i32], + ) -> PrimitiveArray + where + T::Native: Shr, + { + unary::<_, _, Int32Type>(grouping_id_col, |grouping_id| { + let mut result = 0i32; + for index in indices.iter() { + let bit = (grouping_id >> *index).as_usize() & 1; + result = (result << 1) | (bit as i32); + } + result + }) + } + + pub fn evaluate(&self, batch: &RecordBatch) -> Result { + PhysicalExpr::evaluate(self, batch) + } +} + +impl std::fmt::Display for GroupingExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + if let Some(indices) = &self.indices { + write!( + f, + "GROUPING({})", + indices + .iter() + .map(|i| i.to_string()) + .collect::>() + .join(",") + )?; + } else { + write!(f, "0")?; + } + + write!(f, " AS {}", self.name) + } +} + +impl PhysicalExpr for GroupingExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Int32) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let Some(indices) = &self.indices else { + return Ok(ColumnarValue::Array(Arc::new(Int32Array::from( + std::vec![0; batch.num_rows()], + )))); + }; + // Get the grouping_id column from the batch + let Some(grouping_id_col) = batch.column_by_name(INTERNAL_GROUPING_ID) else { + return internal_err!( + "GROUPING expression requires {} column in the schema", + INTERNAL_GROUPING_ID + ); + }; + + match grouping_id_col.data_type() { + DataType::UInt8 => { + let result = self + .grouping(as_primitive_array::(grouping_id_col)?, indices); + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::UInt16 => { + let result = self.grouping( + as_primitive_array::(grouping_id_col)?, + indices, + ); + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::UInt32 => { + let result = self.grouping( + as_primitive_array::(grouping_id_col)?, + indices, + ); + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::UInt64 => { + let result = self.grouping( + as_primitive_array::(grouping_id_col)?, + indices, + ); + Ok(ColumnarValue::Array(Arc::new(result))) + } + _ => { + internal_err!( + "GROUPING expression requires a primitive array, but got {}", + grouping_id_col.data_type() + ) + } + } + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{self}")) + } + + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(Arc::new(arrow::datatypes::Field::new( + std::format!("{self}"), + self.data_type(input_schema)?, + self.nullable(input_schema)?, + ))) + } + + fn evaluate_selection( + &self, + batch: &RecordBatch, + selection: &arrow::array::BooleanArray, + ) -> Result { + let tmp_batch = arrow::compute::filter_record_batch(batch, selection)?; + + let tmp_result = self.evaluate(&tmp_batch)?; + + if batch.num_rows() == tmp_batch.num_rows() { + // All values from the `selection` filter are true. + Ok(tmp_result) + } else if let ColumnarValue::Array(a) = tmp_result { + datafusion_physical_expr_common::utils::scatter(selection, a.as_ref()) + .map(ColumnarValue::Array) + } else if let ColumnarValue::Scalar(ScalarValue::Boolean(value)) = &tmp_result { + // When the scalar is true or false, skip the scatter process + if let Some(v) = value { + if *v { + Ok(ColumnarValue::from( + Arc::new(selection.clone()) as arrow::array::ArrayRef + )) + } else { + Ok(tmp_result) + } + } else { + let array = + arrow::array::BooleanArray::from(std::vec![None; batch.num_rows()]); + datafusion_physical_expr_common::utils::scatter(selection, &array) + .map(ColumnarValue::Array) + } + } else { + Ok(tmp_result) + } + } + + fn evaluate_bounds( + &self, + _children: &[&datafusion_expr::interval_arithmetic::Interval], + ) -> Result { + not_impl_err!("Not implemented for {self}") + } + + fn propagate_constraints( + &self, + _interval: &datafusion_expr::interval_arithmetic::Interval, + _children: &[&datafusion_expr::interval_arithmetic::Interval], + ) -> Result>> { + Ok(Some(std::vec![])) + } + + fn evaluate_statistics( + &self, + children: &[&datafusion_expr::statistics::Distribution], + ) -> Result { + let children_ranges = children + .iter() + .map(|c| c.range()) + .collect::>>()?; + let children_ranges_refs = children_ranges.iter().collect::>(); + let output_interval = self.evaluate_bounds(children_ranges_refs.as_slice())?; + let dt = output_interval.data_type(); + if dt.eq(&DataType::Boolean) { + let p = if output_interval + .eq(&datafusion_expr::interval_arithmetic::Interval::CERTAINLY_TRUE) + { + ScalarValue::new_one(&dt) + } else if output_interval + .eq(&datafusion_expr::interval_arithmetic::Interval::CERTAINLY_FALSE) + { + ScalarValue::new_zero(&dt) + } else { + ScalarValue::try_from(&dt) + }?; + datafusion_expr::statistics::Distribution::new_bernoulli(p) + } else { + datafusion_expr::statistics::Distribution::new_from_interval(output_interval) + } + } + + fn propagate_statistics( + &self, + parent: &datafusion_expr::statistics::Distribution, + children: &[&datafusion_expr::statistics::Distribution], + ) -> Result>> { + let children_ranges = children + .iter() + .map(|c| c.range()) + .collect::>>()?; + let children_ranges_refs = children_ranges.iter().collect::>(); + let parent_range = parent.range()?; + let Some(propagated_children) = + self.propagate_constraints(&parent_range, children_ranges_refs.as_slice())? + else { + return Ok(None); + }; + itertools::izip!(propagated_children.into_iter(), children_ranges, children) + .map(|(new_interval, old_interval, child)| { + if new_interval == old_interval { + // We weren't able to narrow the range, preserve the old statistics. + Ok((*child).clone()) + } else if new_interval.data_type().eq(&DataType::Boolean) { + let dt = old_interval.data_type(); + let p = if new_interval.eq(&datafusion_expr::interval_arithmetic::Interval::CERTAINLY_TRUE) { + ScalarValue::new_one(&dt) + } else if new_interval.eq(&datafusion_expr::interval_arithmetic::Interval::CERTAINLY_FALSE) { + ScalarValue::new_zero(&dt) + } else { + std::unreachable!("Given that we have a range reduction for a boolean interval, we should have certainty") + }?; + datafusion_expr::statistics::Distribution::new_bernoulli(p) + } else { + datafusion_expr::statistics::Distribution::new_from_interval(new_interval) + } + }) + .collect::>() + .map(Some) + } + + fn get_properties( + &self, + _children: &[datafusion_expr::sort_properties::ExprProperties], + ) -> Result { + Ok(datafusion_expr::sort_properties::ExprProperties::new_unknown()) + } + + fn snapshot(&self) -> Result>> { + // By default, we return None to indicate that this PhysicalExpr does not + // have any dynamic references or state. + // This is a safe default behavior. + Ok(None) + } + + fn snapshot_generation(&self) -> u64 { + // By default, we return 0 to indicate that this PhysicalExpr does not + // have any dynamic references or state. + // Since the recursive algorithm XORs the generations of all children the overall + // generation will be 0 if no children have a non-zero generation, meaning that + // static expressions will always return 0. + 0 + } + + fn is_volatile_node(&self) -> bool { + false + } +} + +#[derive(Debug, PartialEq, Clone)] +pub enum AggregateExpr { + GroupingExpr(Arc), + AggregateFunctionExpr(Arc), +} + +impl AggregateExpr { + pub fn get_minmax_desc(&self) -> Option<(FieldRef, bool)> { + match self { + AggregateExpr::AggregateFunctionExpr(expr) => expr.get_minmax_desc(), + _ => None, + } + } + + pub fn order_bys(&self) -> &[PhysicalSortExpr] { + match self { + AggregateExpr::AggregateFunctionExpr(expr) => expr.order_bys(), + _ => &[], + } + } + + pub fn name(&self) -> &str { + match self { + AggregateExpr::GroupingExpr(expr) => &expr.name, + AggregateExpr::AggregateFunctionExpr(expr) => &expr.name, + } + } + + pub fn human_display(&self) -> &str { + match self { + AggregateExpr::GroupingExpr(expr) => &expr.human_display, + AggregateExpr::AggregateFunctionExpr(expr) => &expr.human_display, + } + } +} + +impl From> for AggregateExpr { + fn from(expr: Arc) -> Self { + AggregateExpr::AggregateFunctionExpr(expr) + } +} + +impl From> for AggregateExpr { + fn from(expr: Arc) -> Self { + AggregateExpr::GroupingExpr(expr) + } +} + +/// Internal column used when the aggregation is a grouping set. +/// +/// This column contains a bitmask where each bit represents a grouping +/// expression. The least significant bit corresponds to the rightmost +/// grouping expression. A bit value of 0 indicates that the corresponding +/// column is included in the grouping set, while a value of 1 means it is excluded. +/// +/// For example, for the grouping expressions CUBE(a, b), the grouping ID +/// column will have the following values: +/// 0b00: Both `a` and `b` are included +/// 0b01: `b` is excluded +/// 0b10: `a` is excluded +/// 0b11: Both `a` and `b` are excluded +/// +/// This internal column is necessary because excluded columns are replaced +/// with `NULL` values. To handle these cases correctly, we must distinguish +/// between an actual `NULL` value in a column and a column being excluded from the set. +pub const INTERNAL_GROUPING_ID: &str = "__grouping_id"; diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs index 3666ff3798b6..92c3312757bc 100644 --- a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::ConfigOptions; @@ -136,7 +137,9 @@ impl LimitedDistinctAggregation { } } } - rewrite_applicable = false; + if plan.as_any().downcast_ref::().is_none() { + rewrite_applicable = false; + } Ok(Transformed::no(plan)) }; let child = child.to_owned().transform_down(closure).data().ok()?; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 878bccc1d177..35fe1edeb814 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -41,7 +41,7 @@ use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Constraint, Constraints, Result}; use datafusion_execution::TaskContext; use datafusion_expr::{Accumulator, Aggregate}; -use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::aggregate::{AggregateFunctionExpr, INTERNAL_GROUPING_ID}; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{ @@ -260,10 +260,8 @@ impl PhysicalGroupBy { .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _), ); if !self.is_single() { - output_exprs.push(Arc::new(Column::new( - Aggregate::INTERNAL_GROUPING_ID, - self.expr.len(), - )) as _); + output_exprs + .push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _); } output_exprs } @@ -300,7 +298,7 @@ impl PhysicalGroupBy { if !self.is_single() { fields.push( Field::new( - Aggregate::INTERNAL_GROUPING_ID, + INTERNAL_GROUPING_ID, Aggregate::grouping_id_type(self.expr.len()), false, ) @@ -323,15 +321,16 @@ impl PhysicalGroupBy { /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial /// aggregation. pub fn as_final(&self) -> PhysicalGroupBy { - let expr: Vec<_> = - self.output_exprs() - .into_iter() - .zip( - self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once( - Aggregate::INTERNAL_GROUPING_ID.to_owned(), - )), - ) - .collect(); + let expr: Vec<_> = self + .output_exprs() + .into_iter() + .zip( + self.expr + .iter() + .map(|t| t.1.clone()) + .chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())), + ) + .collect(); let num_exprs = expr.len(); let groups = if self.expr.is_empty() { // No GROUP BY expressions - should have no groups diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 8b3791017a8a..f7086c702009 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -243,14 +243,7 @@ fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result Ok(grouping_expr.into_iter().nth(index)), - Ordering::Equal => { - internal_err!( - "Tried to unproject column referring to internal grouping id" - ) - } - Ordering::Greater => { - Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1)) - } + _ => Ok(agg.aggr_expr.get(index - grouping_expr.len())), } } else { Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)) diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 7aa982dcf3dd..4e4da6473434 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -1924,11 +1924,11 @@ fn test_aggregation_to_sql() { rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1, rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2 FROM person - GROUP BY id, first_name"#; + GROUP BY id, first_name, age"#; let statement = generate_round_trip_statement(GenericDialect {}, sql); assert_snapshot!( statement, - @"SELECT person.id, person.first_name, sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, sum(person.id) FILTER (WHERE ((person.id > 50) AND (person.first_name = 'John'))) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS filtered_sum, max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 FROM person GROUP BY person.id, person.first_name", + @"SELECT person.id, person.first_name, sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, sum(person.id) FILTER (WHERE ((person.id > 50) AND (person.first_name = 'John'))) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS filtered_sum, max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 FROM person GROUP BY person.id, person.first_name, person.age", ); } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 9d6c7b11add6..2813a1a68f9e 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5753,10 +5753,9 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -01)Projection: aggregate_test_100.c2, aggregate_test_100.c3 -02)--Limit: skip=0, fetch=3 -03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -04)------TableScan: aggregate_test_100 projection=[c2, c3] +01)Limit: skip=0, fetch=3 +02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +03)----TableScan: aggregate_test_100 projection=[c2, c3] physical_plan 01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3] 02)--GlobalLimitExec: skip=0, fetch=3 diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index a3b6d40aea2d..1532197ded4e 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -174,7 +174,6 @@ EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test initial_logical_plan 01)Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c 02)--TableScan: simple_explain_test -logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE logical_plan after eliminate_nested_union SAME TEXT AS ABOVE @@ -538,7 +537,6 @@ EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test initial_logical_plan 01)Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c 02)--TableScan: simple_explain_test -logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE logical_plan after eliminate_nested_union SAME TEXT AS ABOVE diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 39e4984ab9f7..45d90cdd30af 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -311,9 +311,8 @@ async fn aggregate_grouping_rollup() -> Result<()> { assert_snapshot!( plan, @r#" - Projection: data.a, data.c, data.e, avg(data.b) - Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]] - TableScan: data projection=[a, b, c, e] + Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]] + TableScan: data projection=[a, b, c, e] "# ); Ok(()) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 205962031b1d..57f441d8cec4 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -267,7 +267,7 @@ first_value(expression [ORDER BY expression]) ### `grouping` -Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. +Returns the level of grouping, equals to (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + … + grouping(cn). ```sql grouping(expression) @@ -283,13 +283,13 @@ grouping(expression) > SELECT column_name, GROUPING(column_name) AS group_column FROM table_name GROUP BY GROUPING SETS ((column_name), ()); -+-------------+-------------+ ++-------------+--------------+ | column_name | group_column | -+-------------+-------------+ -| value1 | 0 | -| value2 | 0 | -| NULL | 1 | -+-------------+-------------+ ++-------------+--------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+--------------+ ``` ### `last_value`