From b4b297e9f73add0c4c81c1cd7ecce77449d46875 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Thu, 24 Jul 2025 21:55:50 +0800 Subject: [PATCH 1/6] pivot --- datafusion/core/src/dataframe/mod.rs | 48 ++++++ datafusion/expr/src/logical_plan/builder.rs | 118 +++++++++++++++++++- datafusion/sql/src/expr/identifier.rs | 2 +- datafusion/sql/src/relation/mod.rs | 73 +++++++++++- 4 files changed, 236 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index a19e6f558162..06b3a7c9c3da 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2359,6 +2359,54 @@ impl DataFrame { let df = ctx.read_batch(batch)?; Ok(df) } + + /// Pivot the DataFrame, transforming rows into columns: each entry of `value_source` becomes one set of output columns holding the aggregates computed over the rows that match it. + /// + /// # Arguments + /// * `aggregate_functions` - Aggregation expressions to apply (e.g., sum, count). + /// * `value_column` - The pivot column(s) whose values are matched against the entries of `value_source`. + /// * `value_source` - The pivot values; each entry produces one set of new columns in the output. + /// * `default_on_null` - Optional default expressions (one per aggregate function) used when a pivoted aggregate is null. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::arrow::array::{Int32Array, StringArray}; + /// # use std::sync::Arc; + /// # let ctx = SessionContext::new(); + /// # let batch = RecordBatch::try_new( + /// # Arc::new(Schema::new(vec![ + /// # Field::new("category", DataType::Utf8, false), + /// # Field::new("value", DataType::Int32, false), + /// # ])), + /// # vec![ + /// # Arc::new(StringArray::from(vec!["A", "B", "A"])), + /// # Arc::new(Int32Array::from(vec![1, 2, 3])) + /// # ] + /// # ).unwrap(); + /// # let df = ctx.read_batch(batch).unwrap(); + /// // Pivot the DataFrame so each unique category becomes a column + /// let pivoted = df.pivot( + /// vec![sum(col("value"))], + /// vec![col("category")], + /// vec![col("value")], + /// None + /// ).unwrap(); + /// ``` + pub fn pivot(self, + aggregate_functions: Vec<Expr>, + value_column: Vec<Column>, + value_source: Vec<Expr>, + default_on_null: Option<Vec<Expr>>) -> Result<DataFrame> { + let plan = LogicalPlanBuilder::from(self.plan) + .pivot(aggregate_functions, value_column, value_source, default_on_null)? + .build()?; + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: self.projection_requires_validation, + }) + } } /// Macro for creating DataFrame. 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 1ab5ffa75842..9e8b507acac9 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -43,15 +43,16 @@ use crate::utils::{ group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, lit, DmlStatement, ExplainOption, Expr, ExprSchemable, Operator, - RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp, + and, binary_expr, lit, when, DmlStatement, ExplainOption, Expr, ExprSchemable, Literal, Operator, RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp }; use super::dml::InsertOp; +use arrow::array::Array; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ exec_err, get_target_functional_dependencies, not_impl_err, plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, NullEquality, @@ -1480,6 +1481,119 @@ impl LogicalPlanBuilder { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } + + pub fn pivot(self, + aggregate_functions: Vec<Expr>, + value_column: Vec<Column>, + value_source: Vec<Expr>, + default_on_null: Option<Vec<Expr>> + ) -> Result<Self> { + match default_on_null { + Some(default_values) if default_values.len() != aggregate_functions.len() => { + return plan_err!("Number of default values must match the number of aggregate functions"); + } + _ => {} + } + let mut used_columns = HashSet::new(); + used_columns.extend(value_column.iter().cloned()); + for agg in aggregate_functions.iter() { + expr_to_columns(agg, &mut used_columns)?; + } + + let used_columns = used_columns + .iter() + .map(|c| c.name.clone()) + .collect::<HashSet<String>>(); + + // Extract group by columns (all columns not involved in aggregation or pivot) + let schema = self.schema(); + + let group_by_columns = schema.fields().iter().filter_map(|f| { + if used_columns.contains(f.name()) { + None // Skip columns that are used in aggregation or pivot + } else { + Some(Expr::Column(Column::from_name(f.name()))) + } + }).collect::<Vec<_>>(); + + // Create filtered aggregate expressions for each value in value_source + let mut aggr_exprs = Vec::new(); + + for value in &value_source { + let (value, value_alias) = if let Expr::Alias(Alias{expr, name, ..}) = value { + (expr.as_ref(), name) + } else { + (value, &value.to_string()) + }; + let condition = match value_column.len() { + 0 => return plan_err!("Pivot requires at least one value column"), + 1 => binary_expr( + Expr::Column(value_column[0].clone()), + Operator::IsNotDistinctFrom, + value.clone() + ), + _ => { + let Expr::Literal(ScalarValue::List(list), _) = value else { + return plan_err!("Pivot value must be a list of values if multiple value columns are provided"); + }; + if list.len() != value_column.len() { + return plan_err!("Pivot value list length must match value column count"); + } + let mut condition: Option<Expr> = None; + for (idx, col) in value_column.iter().enumerate() { + let single_condition = binary_expr( + Expr::Column(col.clone()), + Operator::IsNotDistinctFrom, + ScalarValue::try_from_array(list.as_ref(), idx)?.lit() + ); + condition = match condition { + None => Some(single_condition), + Some(prev) => Some(and(prev, single_condition)) + }; + } + match condition { + None => return 
plan_err!("Pivot value condition cannot be empty"), + Some(cond) => cond, + } + }, + }; + + for (i, agg_func) in aggregate_functions.iter().enumerate() { + let Expr::Alias(Alias { expr, name, metadata, .. }) = agg_func else { + return plan_err!("Aggregate function must has an alias"); + }; + let expr = expr.clone().transform(|nested_expr| { + match &nested_expr { + Expr::AggregateFunction(func) => { + let filter = match &func.params.filter { + Some(filter) => and(filter.as_ref().clone(), condition.clone()), + None => condition.clone(), + }; + let mut func = func.clone(); + func.params.filter = Some(Box::new(filter)); + Ok(Transformed::yes(Expr::AggregateFunction(func))) + } + _ => { + Ok(Transformed::no(nested_expr)) + } + } + })?.data; + + let expr = match default_on_null.as_ref() { + Some(default_values) => { + when(expr.clone().is_null(), default_values[i].clone()).otherwise(expr)? + }, + None => expr + }; + let pivot_col_name = format!("{}_{}", value_alias.replace("\"", "").replace("'", ""), name); + aggr_exprs.push(expr.alias_with_metadata(pivot_col_name, metadata.clone())); + } + } + + let aggregate_plan = self.aggregate(group_by_columns, aggr_exprs)?; + + Ok(aggregate_plan) + } } impl From for LogicalPlanBuilder { diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 434ac50bce50..71d093143e20 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -28,7 +28,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_expr::UNNAMED_TABLE; impl SqlToRel<'_, S> { - pub(super) fn sql_identifier_to_expr( + pub(crate) fn sql_identifier_to_expr( &self, id: Ident, schema: &DFSchema, diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index aa37d74fd4d8..e0c419f52466 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -19,14 +19,15 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use arrow::array::Array; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference, + not_impl_err, plan_err, DFSchema, Diagnostic, Result, ScalarValue, Span, Spans, TableReference }; use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; -use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor}; +use sqlparser::ast::{FunctionArg, FunctionArgExpr, PivotValueSource, Spanned, TableFactor}; mod join; @@ -183,6 +184,74 @@ impl SqlToRel<'_, S> { .build()?; (plan, alias) } + TableFactor::Pivot { table, aggregate_functions, value_column, value_source, default_on_null, alias } => { + let plan = self.create_relation(table.as_ref().clone(), planner_context)?; + let schema = plan.schema(); + let aggregate_functions = aggregate_functions + .into_iter() + .map(|func| { + self.sql_expr_to_logical_expr( + func.expr, + schema, + planner_context, + ).map(|expr| + match func.alias { + Some(name) => expr.alias(name.value), + None => expr, + } + ) + }) + .collect::>>()?; + let value_column = value_column.into_iter().map(|id| { + match self.sql_identifier_to_expr(id, schema, planner_context)? 
{ + Expr::Column(col) => Ok(col), + expr => plan_err!( + "Expected column identifier, found: {expr:?} in pivot value column" + ), + } + }).collect::>>()?; + + let PivotValueSource::List(source) = value_source else { + // Dynamic pivot: the output schema is determined by the data in the source table at runtime. + return plan_err!("Dynamic pivot is not supported yet"); + }; + let value_source = source.into_iter().map(|expr_with_alias| { + self.sql_expr_to_logical_expr( + expr_with_alias.expr, + schema, + planner_context, + ).map(|expr| + match expr_with_alias.alias { + Some(name) => expr.alias(name.value), + None => expr, + } + ) + }).collect::>>()?; + + let default_on_null = default_on_null + .map(|expr| { + let expr = self.sql_expr_to_logical_expr( + expr, + schema, + planner_context, + )?; + match expr { + Expr::Literal(ScalarValue::List(list), _) => { + (0..list.len()).map(|idx| Ok(Expr::Literal(ScalarValue::try_from_array(list.values(), idx)?, None))).collect::>>() + } + _ => return plan_err!("Pivot default value cannot be NULL"), + } + }) + .transpose()?; + + let plan = LogicalPlanBuilder::from(plan).pivot( + aggregate_functions, + value_column, + value_source, + default_on_null, + )?.build()?; + (plan, alias) + } // @todo Support TableFactory::TableFunction? _ => { return not_impl_err!( From cc84b7ecfc9ac961bc5b76ff5bb9558e453fad26 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sun, 5 Oct 2025 17:23:04 +0800 Subject: [PATCH 2/6] chore: upgrade sqlparser --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- datafusion/sql/src/expr/function.rs | 5 +++++ datafusion/sql/src/expr/mod.rs | 6 +++-- datafusion/sql/src/planner.rs | 18 ++++++++++++--- datafusion/sql/src/query.rs | 2 +- datafusion/sql/src/relation/join.rs | 4 +++- datafusion/sql/src/statement.rs | 34 ++++++++++++++++++++++++++++- datafusion/sql/src/unparser/expr.rs | 12 +++++++--- datafusion/sql/src/unparser/plan.rs | 2 +- 10 files changed, 74 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a2939f425712..e8ef3336074f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6001,9 +6001,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.58.0" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec4b661c54b1e4b603b37873a18c59920e4c51ea8ea2cf527d925424dbd4437c" +checksum = "4591acadbcf52f0af60eafbb2c003232b2b4cd8de5f0e9437cb8b1b59046cc0f" dependencies = [ "log", "recursive", diff --git a/Cargo.toml b/Cargo.toml index e5fda30f944c..41a24c948f96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -172,7 +172,7 @@ recursive = "0.1.1" regex = "1.11" rstest = "0.25.0" serde_json = "1" -sqlparser = { version = "0.58.0", default-features = false, features = ["std", "visitor"] } +sqlparser = { version = "0.59.0", default-features = false, features = ["std", "visitor"] } tempfile = "3" testcontainers = { version = "0.24", features = ["default"] } testcontainers-modules = { version = "0.12" } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index a61967ed6957..eabf645a5eaf 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -182,6 +182,11 @@ impl FunctionArgs { "Calling {name}: JSON NULL clause not supported in function arguments: {jn}" ) } + FunctionArgumentClause::JsonReturningClause(jr) => { + return not_impl_err!( + "Calling {name}: JSON RETURNING clause not supported in function arguments: {jr}" + ) + }, } } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 
ae4cddc61f54..d6c98b353fd2 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -22,7 +22,7 @@ use datafusion_expr::planner::{ use sqlparser::ast::{ AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, - StructField, Subscript, TrimWhereField, Value, ValueWithSpan, + StructField, Subscript, TrimWhereField, TypedString, Value, ValueWithSpan, }; use datafusion_common::{ @@ -291,7 +291,9 @@ impl SqlToRel<'_, S> { ))) } - SQLExpr::TypedString { data_type, value } => Ok(Expr::Cast(Cast::new( + SQLExpr::TypedString(TypedString { + data_type, value, .. + }) => Ok(Expr::Cast(Cast::new( Box::new(lit(value.into_string().unwrap())), self.convert_data_type(&data_type)?, ))), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 79e8cd8e123f..e93c5e066b66 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -720,10 +720,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { (Some(precision), Some(scale)) } }; - make_decimal_type(precision, scale) + make_decimal_type(precision, scale.map(|s| s as u64)) } SQLDataType::Bytea => Ok(DataType::Binary), - SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + SQLDataType::Interval { fields, precision } => { + if fields.is_some() || precision.is_some() { + return not_impl_err!("Unsupported SQL type {sql_type}"); + } + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } SQLDataType::Struct(fields, _) => { let fields = fields .iter() @@ -818,7 +823,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::NamedTable { .. } | SQLDataType::TsVector | SQLDataType::TsQuery - | SQLDataType::GeometricType(_) => { + | SQLDataType::GeometricType(_) + | SQLDataType::DecimalUnsigned(_) // deprecated mysql type + | SQLDataType::FloatUnsigned(_) // deprecated mysql type + | SQLDataType::RealUnsigned // deprecated mysql type + | SQLDataType::DecUnsigned(_) // deprecated mysql type + | SQLDataType::DoubleUnsigned(_) // deprecated mysql type + | SQLDataType::DoublePrecisionUnsigned // deprecated mysql type + => { not_impl_err!("Unsupported SQL type {sql_type}") } } diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 633d933eb845..54c301625322 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -390,7 +390,7 @@ pub(crate) fn to_order_by_exprs_with_select( quote_style: None, span: Span::empty(), }), - options: order_by_options.clone(), + options: order_by_options, with_fill: None, }), // TODO: Support other types of expressions diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 10491963e3ce..54f8b0d7ec87 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -95,7 +95,9 @@ impl SqlToRel<'_, S> { JoinOperator::FullOuter(constraint) => { self.parse_join(left, right, constraint, JoinType::Full, planner_context) } - JoinOperator::CrossJoin => self.parse_cross_join(left, right), + JoinOperator::CrossJoin(JoinConstraint::None) => { + self.parse_cross_join(left, right) + } other => not_impl_err!("Unsupported JOIN operator {other:?}"), } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 44e924614208..7e746e896384 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -302,6 +302,13 @@ impl SqlToRel<'_, S> { storage_serialization_policy, inherits, table_options: 
CreateTableOptions::None, + dynamic, + version, + target_lag, + warehouse, + refresh_mode, + initialize, + require_user, }) => { if temporary { return not_impl_err!("Temporary tables not supported")?; @@ -428,7 +435,27 @@ impl SqlToRel<'_, S> { if inherits.is_some() { return not_impl_err!("Table inheritance not supported")?; } - + if dynamic { + return not_impl_err!("Dynamic tables not supported")?; + } + if version.is_some() { + return not_impl_err!("Version not supported")?; + } + if target_lag.is_some() { + return not_impl_err!("Target lag not supported")?; + } + if warehouse.is_some() { + return not_impl_err!("Warehouse not supported")?; + } + if refresh_mode.is_some() { + return not_impl_err!("Refresh mode not supported")?; + } + if initialize.is_some() { + return not_impl_err!("Initialize not supported")?; + } + if require_user { + return not_impl_err!("Require user not supported")?; + } // Merge inline constraints and existing constraints let mut all_constraints = constraints; let inline_constraints = calc_inline_constraints_from_columns(&columns); @@ -534,6 +561,8 @@ impl SqlToRel<'_, S> { to, params, or_alter, + secure, + name_before_not_exists, } => { if materialized { return not_impl_err!("Materialized views not supported")?; @@ -571,6 +600,8 @@ impl SqlToRel<'_, S> { to, params, or_alter, + secure, + name_before_not_exists, }; let sql = stmt.to_string(); let Statement::CreateView { @@ -975,6 +1006,7 @@ impl SqlToRel<'_, S> { selection, returning, or, + limit: _, } => { let from_clauses = from.map(|update_table_from_kind| match update_table_from_kind { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8f5b9cef089e..a7fe8efa153c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1689,7 +1689,7 @@ impl Unparser<'_> { DataType::Float16 => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Float32 => Ok(ast::DataType::Float(None)), + DataType::Float32 => Ok(ast::DataType::Float(ast::ExactNumberInfo::None)), DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), DataType::Timestamp(time_unit, tz) => { Ok(self.dialect.timestamp_cast_dtype(time_unit, tz)) @@ -1705,7 +1705,10 @@ impl Unparser<'_> { DataType::Duration(_) => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } - DataType::Interval(_) => Ok(ast::DataType::Interval), + DataType::Interval(_) => Ok(ast::DataType::Interval { + fields: None, + precision: None, + }), DataType::Binary => { not_impl_err!("Unsupported DataType: conversion: {data_type}") } @@ -1755,7 +1758,10 @@ impl Unparser<'_> { } Ok(ast::DataType::Decimal( - ast::ExactNumberInfo::PrecisionAndScale(new_precision, new_scale), + ast::ExactNumberInfo::PrecisionAndScale( + new_precision, + new_scale as i64, + ), )) } DataType::Map(_, _) => { diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 3826ef9feab2..b6c65614995a 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -1251,7 +1251,7 @@ impl Unparser<'_> { ast::JoinConstraint::None => { // Inner joins with no conditions or filters are not valid SQL in most systems, // return a CROSS JOIN instead - ast::JoinOperator::CrossJoin + ast::JoinOperator::CrossJoin(constraint) } }, JoinType::Left => ast::JoinOperator::LeftOuter(constraint), From b988225c7b6f6f983f1a3db194accd4c5aaa2c4f Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Sun, 5 Oct 2025 18:01:24 +0800 Subject: [PATCH 3/6] chore: upgrade sqlparser 
--- datafusion/sql/src/parser.rs | 8 +++++--- datafusion/sql/src/statement.rs | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 06418fa4ec96..271ad8a856b4 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -1076,7 +1076,9 @@ mod tests { use super::*; use datafusion_common::assert_contains; use sqlparser::ast::Expr::Identifier; - use sqlparser::ast::{BinaryOperator, DataType, Expr, Ident, ValueWithSpan}; + use sqlparser::ast::{ + BinaryOperator, DataType, ExactNumberInfo, Expr, Ident, ValueWithSpan, + }; use sqlparser::dialect::SnowflakeDialect; use sqlparser::tokenizer::Span; @@ -1573,7 +1575,7 @@ mod tests { name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), - make_column_def("c2", DataType::Float(None)), + make_column_def("c2", DataType::Float(ExactNumberInfo::None)), ], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), @@ -1641,7 +1643,7 @@ mod tests { name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), - make_column_def("c2", DataType::Float(None)), + make_column_def("c2", DataType::Float(ExactNumberInfo::None)), ], file_type: "PARQUET".to_string(), location: "foo.parquet".into(), diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 7e746e896384..0e868e8c2689 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1006,7 +1006,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> { selection, returning, or, - limit: _, + limit, } => { let from_clauses = from.map(|update_table_from_kind| match update_table_from_kind { @@ -1024,6 +1024,9 @@ if or.is_some() { plan_err!("ON conflict not supported")?; } + if limit.is_some() { + return not_impl_err!("Update-limit clause not supported")?; + } self.update_to_plan(table, assignments, update_from, selection) } From 9becd14e4e84913870e5ca0aaa048f5a1b629a1a Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 6 Oct 2025 08:31:01 +0800 Subject: [PATCH 4/6] feat: support pivot --- datafusion/core/src/dataframe/mod.rs | 17 ++- datafusion/expr/src/logical_plan/builder.rs | 124 ++++++++++++------- datafusion/sql/src/relation/mod.rs | 89 ++++++++----- datafusion/sqllogictest/test_files/pivot.slt | 56 +++++++++ 4 files changed, 200 insertions(+), 86 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/pivot.slt diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 1ad1b2d7dca6..c4cc20973980 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2402,13 +2402,20 @@ impl DataFrame { /// None /// ).unwrap(); /// ``` - pub fn pivot(self, - aggregate_functions: Vec<Expr>, - value_column: Vec<Column>, value_source: Vec<Expr>, - default_on_null: Option<Vec<Expr>>) -> Result<DataFrame> { + pub fn pivot( + self, + aggregate_functions: Vec<Expr>, + value_column: Vec<Column>, + value_source: Vec<Expr>, + default_on_null: Option<Vec<Expr>>, + ) -> Result<DataFrame> { let plan = LogicalPlanBuilder::from(self.plan) - .pivot(aggregate_functions, value_column, value_source, default_on_null)? + .pivot( + aggregate_functions, + value_column, + value_source, + default_on_null, + )? 
.build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 8e27eb4b4f26..4beba715d92a 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -25,7 +25,9 @@ use std::iter::once; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::{Alias, FieldMetadata, PlannedReplaceSelectItem, Sort as SortExpr}; +use crate::expr::{ + Alias, FieldMetadata, PlannedReplaceSelectItem, ScalarFunction, Sort as SortExpr, +}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, @@ -44,11 +46,12 @@ use crate::utils::{ group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, lit, when, DmlStatement, ExplainOption, Expr, ExprSchemable, Literal, Operator, RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp + and, binary_expr, lit, when, DmlStatement, ExplainOption, Expr, ExprSchemable, + Operator, RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, + WriteOp, }; use super::dml::InsertOp; -use arrow::array::Array; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; @@ -1495,14 +1498,15 @@ impl LogicalPlanBuilder { .map(Self::new) } - pub fn pivot(self, - aggregate_functions: Vec<Expr>, - value_column: Vec<Column>, - value_source: Vec<Expr>, - default_on_null: Option<Vec<Expr>> + pub fn pivot( + self, + aggregate_functions: Vec<Expr>, + value_column: Vec<Column>, + value_source: Vec<Expr>, + default_on_null: Option<Vec<Expr>>, ) -> Result<Self> { match default_on_null { - Some(default_values) if default_values.len() != aggregate_functions.len() => { + Some(default_values) if default_values.len() != aggregate_functions.len() => { return plan_err!("Number of default values must match the number of aggregate functions"); } _ => {} @@ -1521,90 +1525,116 @@ impl LogicalPlanBuilder { // Extract group by columns (all columns not involved in aggregation or pivot) let schema = self.schema(); - let group_by_columns = schema.fields().iter().filter_map(|f| { - if used_columns.contains(f.name()) { - None // Skip columns that are used in aggregation or pivot - } else { - Some(Expr::Column(Column::from_name(f.name()))) - } - }).collect::<Vec<_>>(); + let group_by_columns = schema + .fields() + .iter() + .filter_map(|f| { + if used_columns.contains(f.name()) { + None // Skip columns that are used in aggregation or pivot + } else { + Some(Expr::Column(Column::from_name(f.name()))) + } + }) + .collect::<Vec<_>>(); // Create filtered aggregate expressions for each value in value_source let mut aggr_exprs = Vec::new(); - + for value in &value_source { - let (value, value_alias) = if let Expr::Alias(Alias{expr, name, ..}) = value { - (expr.as_ref(), name) - } else { - (value, &value.to_string()) - }; + let (value, value_alias) = + if let Expr::Alias(Alias { expr, name, .. 
}) = value { + (expr.as_ref(), name) + } else { + (value, &value.to_string()) + }; let condition = match value_column.len() { 0 => return plan_err!("Pivot requires at least one value column"), 1 => binary_expr( Expr::Column(value_column[0].clone()), Operator::IsNotDistinctFrom, - value.clone() + value.clone(), ), _ => { - let Expr::Literal(ScalarValue::List(list), _) = value else { - return plan_err!("Pivot value must be a list of values if multiple value columns are provided"); + let Expr::ScalarFunction(ScalarFunction { func, args }) = value + else { + return plan_err!("Pivot value must be struct(literals) if multiple value columns are provided"); }; - if list.len() != value_column.len() { - return plan_err!("Pivot value list length must match value column count"); + if func.name() != "struct" { + return plan_err!("Pivot value must be struct(literals) if multiple value columns are provided"); + } + if args.len() != value_column.len() { + return plan_err!( + "Pivot value list length must match value column count" + ); } let mut condition: Option<Expr> = None; for (idx, col) in value_column.iter().enumerate() { let single_condition = binary_expr( Expr::Column(col.clone()), Operator::IsNotDistinctFrom, - ScalarValue::try_from_array(list.as_ref(), idx)?.lit() + args[idx].clone(), ); condition = match condition { None => Some(single_condition), - Some(prev) => Some(and(prev, single_condition)) + Some(prev) => Some(and(prev, single_condition)), }; } match condition { - None => return plan_err!("Pivot value condition cannot be empty"), + None => { + return plan_err!("Pivot value condition cannot be empty") + } Some(cond) => cond, - }, + } }; for (i, agg_func) in aggregate_functions.iter().enumerate() { - let Expr::Alias(Alias { expr, name, metadata, .. }) = agg_func else { + let Expr::Alias(Alias { + expr, + name, + metadata, + .. + }) = agg_func + else { return plan_err!("Aggregate function must have an alias"); }; - let expr = expr.clone().transform(|nested_expr| { - match &nested_expr { + let expr = expr + .clone() + .transform(|nested_expr| match &nested_expr { Expr::AggregateFunction(func) => { let filter = match &func.params.filter { - Some(filter) => and(filter.as_ref().clone(), condition.clone()), + Some(filter) => { + and(filter.as_ref().clone(), condition.clone()) + } None => condition.clone(), }; let mut func = func.clone(); func.params.filter = Some(Box::new(filter)); Ok(Transformed::yes(Expr::AggregateFunction(func))) } - _ => { - Ok(Transformed::no(nested_expr)) - } - } - })?.data; + _ => Ok(Transformed::no(nested_expr)), + })? + .data; let expr = match default_on_null.as_ref() { Some(default_values) => { - when(expr.clone().is_null(), default_values[i].clone()).otherwise(expr)? - }, - None => expr + when(expr.clone().is_null(), default_values[i].clone()) + .otherwise(expr)? 
+ } + None => expr, }; - let pivot_col_name = format!("{}_{}", value_alias.replace("\"", "").replace("'", ""), name); - aggr_exprs.push(expr.alias_with_metadata(pivot_col_name, metadata.clone())); + let pivot_col_name = format!( + "{}_{}", + value_alias.replace("\"", "").replace("'", ""), + name + ); + aggr_exprs + .push(expr.alias_with_metadata(pivot_col_name, metadata.clone())); } } - + let aggregate_plan = self.aggregate(group_by_columns, aggr_exprs)?; - + Ok(aggregate_plan) } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index a980b6f6f549..0c4f3f861682 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -22,12 +22,15 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow::array::Array; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, Diagnostic, Result, ScalarValue, Span, Spans, TableReference + not_impl_err, plan_err, DFSchema, Diagnostic, Result, ScalarValue, Span, Spans, + TableReference, }; use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::{Subquery, SubqueryAlias}; -use sqlparser::ast::{FunctionArg, FunctionArgExpr, PivotValueSource, Spanned, TableFactor}; +use sqlparser::ast::{ + Expr as SqlExpr, FunctionArg, FunctionArgExpr, PivotValueSource, Spanned, TableFactor, +}; mod join; @@ -184,26 +187,36 @@ impl SqlToRel<'_, S> { .build()?; (plan, alias) } - TableFactor::Pivot { table, aggregate_functions, value_column, value_source, default_on_null, alias } => { - let plan = self.create_relation(table.as_ref().clone(), planner_context)?; + TableFactor::Pivot { + table, + aggregate_functions, + value_column, + value_source, + default_on_null, + alias, + } => { + let plan = + self.create_relation(table.as_ref().clone(), planner_context)?; let schema = plan.schema(); let aggregate_functions = aggregate_functions .into_iter() .map(|func| { - self.sql_expr_to_logical_expr( - func.expr, - schema, - planner_context, - ).map(|expr| - match func.alias { + self.sql_expr_to_logical_expr(func.expr, schema, planner_context) + .map(|expr| match func.alias { Some(name) => expr.alias(name.value), None => expr, - } - ) + }) }) .collect::>>()?; - let value_column = value_column.into_iter().map(|id| { - match self.sql_identifier_to_expr(id, schema, planner_context)? { + let value_column = value_column.into_iter().map(|column| { + let expr = match column { + SqlExpr::Identifier(id) => self.sql_identifier_to_expr(id, schema, planner_context)?, + SqlExpr::CompoundIdentifier(idents) => self.sql_compound_identifier_to_expr(idents, schema, planner_context)?, + expr => return plan_err!( + "Expected column identifier, found: {expr:?} in pivot value column" + ), + }; + match expr { Expr::Column(col) => Ok(col), expr => plan_err!( "Expected column identifier, found: {expr:?} in pivot value column" @@ -215,41 +228,49 @@ impl SqlToRel<'_, S> { // Dynamic pivot: the output schema is determined by the data in the source table at runtime. 
return plan_err!("Dynamic pivot is not supported yet"); }; - let value_source = source.into_iter().map(|expr_with_alias| { - self.sql_expr_to_logical_expr( + let value_source = source + .into_iter() + .map(|expr_with_alias| { + self.sql_expr_to_logical_expr( expr_with_alias.expr, schema, planner_context, - ).map(|expr| + ) + .map(|expr| { match expr_with_alias.alias { Some(name) => expr.alias(name.value), None => expr, } - ) - }).collect::>>()?; + }) + }) + .collect::>>()?; let default_on_null = default_on_null .map(|expr| { - let expr = self.sql_expr_to_logical_expr( - expr, - schema, - planner_context, - )?; + let expr = + self.sql_expr_to_logical_expr(expr, schema, planner_context)?; match expr { - Expr::Literal(ScalarValue::List(list), _) => { - (0..list.len()).map(|idx| Ok(Expr::Literal(ScalarValue::try_from_array(list.values(), idx)?, None))).collect::>>() - } + Expr::Literal(ScalarValue::List(list), _) => (0..list.len()) + .map(|idx| { + Ok(Expr::Literal( + ScalarValue::try_from_array(list.values(), idx)?, + None, + )) + }) + .collect::>>(), _ => return plan_err!("Pivot default value cannot be NULL"), } }) .transpose()?; - - let plan = LogicalPlanBuilder::from(plan).pivot( - aggregate_functions, - value_column, - value_source, - default_on_null, - )?.build()?; + + let plan = LogicalPlanBuilder::from(plan) + .pivot( + aggregate_functions, + value_column, + value_source, + default_on_null, + )? + .build()?; (plan, alias) } // @todo Support TableFactory::TableFunction? diff --git a/datafusion/sqllogictest/test_files/pivot.slt b/datafusion/sqllogictest/test_files/pivot.slt new file mode 100644 index 000000000000..5e42ff2d04ce --- /dev/null +++ b/datafusion/sqllogictest/test_files/pivot.slt @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# PIVOT tests + +# Setup test table +statement ok +CREATE TABLE person (id INT, name STRING, age INT, class INT, address STRING) AS VALUES + (100, 'John', 30, 1, 'Street 1'), + (200, 'Mary', NULL, 1, 'Street 2'), + (300, 'Mike', 80, 3, 'Street 3'), + (400, 'Dan', 50, 4, 'Street 4'); + +# Test basic PIVOT with single column +query ITIRIR rowsort +SELECT * FROM person + PIVOT ( + SUM(age) AS a, AVG(class) AS c + FOR name IN ('John' AS john, 'Mike' AS mike) + ); +---- +100 Street 1 30 1 NULL NULL +200 Street 2 NULL NULL NULL NULL +300 Street 3 NULL NULL 80 3 +400 Street 4 NULL NULL NULL NULL + +# Test PIVOT with multiple columns in FOR clause +query ITIRIR rowsort +SELECT * FROM person + PIVOT ( + SUM(age) AS a, AVG(class) AS c + FOR (name, age) IN (('John', 30) AS c1, ('Mike', 40) AS c2) + ); +---- +100 Street 1 30 1 NULL NULL +200 Street 2 NULL NULL NULL NULL +300 Street 3 NULL NULL NULL NULL +400 Street 4 NULL NULL NULL NULL + +# Cleanup +statement ok +DROP TABLE person; From 0b95a56b09984db0e813ff02e2e0eef19b822c0c Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 6 Oct 2025 17:16:48 +0800 Subject: [PATCH 5/6] feat: pivot & unpivot --- datafusion/core/src/dataframe/mod.rs | 74 ++++ datafusion/core/tests/dataframe/mod.rs | 330 ++++++++++++++++++ datafusion/expr/src/logical_plan/builder.rs | 177 +++++++++- datafusion/sql/src/relation/mod.rs | 118 ++++++- .../sqllogictest/test_files/unpivot.slt | 83 +++++ 5 files changed, 771 insertions(+), 11 deletions(-) create mode 100644 datafusion/sqllogictest/test_files/unpivot.slt diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index c4cc20973980..0c12610aca96 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2423,6 +2423,80 @@ impl DataFrame { projection_requires_validation: self.projection_requires_validation, }) } + + /// Unpivot the DataFrame, transforming columns into rows. 
+ /// + /// # Arguments + /// * `value_column_names` - Names for the value columns in the output + /// * `name_column` - Name for the column that will contain the original column names + /// * `unpivot_columns` - List of (column_names, optional_alias) tuples to unpivot + /// * `id_columns` - Optional list of columns to preserve (if None, all non-unpivoted columns are preserved) + /// * `include_nulls` - Whether to include rows with NULL values (default: false excludes NULLs) + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// // Assume we have a DataFrame with columns: id, jan, feb, mar + /// // We want to unpivot jan, feb, mar into month/value columns + /// # Ok(()) + /// # } + /// ``` + pub fn unpivot( + self, + value_column_names: Vec<String>, + name_column: String, + unpivot_columns: Vec<(Vec<String>, Option<String>)>, + id_columns: Option<Vec<String>>, + include_nulls: bool, + ) -> Result<DataFrame> { + // Get required UDF functions from the session state + let named_struct_fn = self + .session_state + .scalar_functions() + .get("named_struct") + .ok_or_else(|| { + DataFusionError::Plan("named_struct function not found".to_string()) + })?; + + let make_array_fn = self + .session_state + .scalar_functions() + .get("make_array") + .ok_or_else(|| { + DataFusionError::Plan("make_array function not found".to_string()) + })?; + + let get_field_fn = self + .session_state + .scalar_functions() + .get("get_field") + .ok_or_else(|| { + DataFusionError::Plan("get_field function not found".to_string()) + })?; + + let plan = LogicalPlanBuilder::from(self.plan) + .unpivot( + value_column_names, + name_column, + unpivot_columns, + id_columns, + include_nulls, + named_struct_fn, + make_array_fn, + get_field_fn, + )? + .build()?; + + Ok(DataFrame { + session_state: self.session_state, + plan, + projection_requires_validation: true, + }) + } } /// Macro for creating DataFrame. 
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index e9b531723fd6..288a26a085e5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -6479,3 +6479,333 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_pivot() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a test table for pivot + let schema = Arc::new(Schema::new(vec![ + Field::new("category", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from(vec!["A", "B", "A", "B", "C"])), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])), + ], + )?; + + ctx.register_batch("pivot_test", batch)?; + let df = ctx.table("pivot_test").await?; + + // Pivot the DataFrame so each unique category becomes a column + let pivoted = df.pivot( + vec![sum(col("value"))], + vec![datafusion_common::Column::from_name("category")], + vec![lit("A"), lit("B"), lit("C")], + None, + )?; + + let results = pivoted.collect().await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +--------------------+--------------------+--------------------+ + | Utf8(A)_sum(value) | Utf8(B)_sum(value) | Utf8(C)_sum(value) | + +--------------------+--------------------+--------------------+ + | 4 | 6 | 5 | + +--------------------+--------------------+--------------------+ + "### + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_pivot_with_default() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("category", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1, 2, 2])), + Arc::new(StringArray::from(vec!["A", "B", "A", "C"])), + Arc::new(Int32Array::from(vec![10, 20, 30, 40])), + ], + )?; + + ctx.register_batch("pivot_default_test", batch)?; + let df = ctx.table("pivot_default_test").await?; + + // Pivot with default values + let pivoted = df.pivot( + vec![sum(col("value"))], + vec![datafusion_common::Column::from_name("category")], + vec![lit("A"), lit("B"), lit("C")], + Some(vec![lit(0)]), + )?; + + let results = pivoted.collect().await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +----+--------------------+--------------------+--------------------+ + | id | Utf8(A)_sum(value) | Utf8(B)_sum(value) | Utf8(C)_sum(value) | + +----+--------------------+--------------------+--------------------+ + | 1 | 10 | 20 | 0 | + | 2 | 30 | 0 | 40 | + +----+--------------------+--------------------+--------------------+ + "### + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_unpivot() -> Result<()> { + let ctx = SessionContext::new(); + + // Create a test table for unpivot + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("jan", DataType::Int32, true), + Field::new("feb", DataType::Int32, true), + Field::new("mar", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Int32Array::from(vec![100, 110])), + Arc::new(Int32Array::from(vec![200, 210])), + Arc::new(Int32Array::from(vec![300, 310])), + ], + )?; + + ctx.register_batch("unpivot_test", batch)?; + let df = 
ctx.table("unpivot_test").await?; + + // Unpivot jan, feb, mar into month/value columns + let unpivoted = df.unpivot( + vec!["value".to_string()], + "month".to_string(), + vec![ + (vec!["jan".to_string()], Some("jan".to_string())), + (vec!["feb".to_string()], Some("feb".to_string())), + (vec!["mar".to_string()], Some("mar".to_string())), + ], + Some(vec!["id".to_string()]), + false, + )?; + + let results = unpivoted + .sort(vec![ + col("id").sort(true, true), + col("month").sort(true, true), + ])? + .collect() + .await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +----+-------+-------+ + | id | month | value | + +----+-------+-------+ + | 1 | feb | 200 | + | 1 | jan | 100 | + | 1 | mar | 300 | + | 2 | feb | 210 | + | 2 | jan | 110 | + | 2 | mar | 310 | + +----+-------+-------+ + "### + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_unpivot_with_nulls() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("col1", DataType::Int32, true), + Field::new("col2", DataType::Int32, true), + Field::new("col3", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![Some(10), None, Some(30)])), + Arc::new(Int32Array::from(vec![Some(20), Some(25), None])), + Arc::new(Int32Array::from(vec![None, Some(35), Some(40)])), + ], + )?; + + ctx.register_batch("unpivot_nulls_test", batch)?; + let df = ctx.table("unpivot_nulls_test").await?; + + // Test with include_nulls = false (default behavior) + let unpivoted = df.clone().unpivot( + vec!["value".to_string()], + "column".to_string(), + vec![ + (vec!["col1".to_string()], Some("col1".to_string())), + (vec!["col2".to_string()], Some("col2".to_string())), + (vec!["col3".to_string()], Some("col3".to_string())), + ], + Some(vec!["id".to_string()]), + false, // exclude nulls + )?; + + let results = unpivoted + .sort(vec![ + col("id").sort(true, true), + col("column").sort(true, true), + ])? + .collect() + .await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +----+--------+-------+ + | id | column | value | + +----+--------+-------+ + | 1 | col1 | 10 | + | 1 | col2 | 20 | + | 2 | col2 | 25 | + | 2 | col3 | 35 | + | 3 | col1 | 30 | + | 3 | col3 | 40 | + +----+--------+-------+ + "### + ); + + // Test with include_nulls = true + let unpivoted_with_nulls = df.unpivot( + vec!["value".to_string()], + "column".to_string(), + vec![ + (vec!["col1".to_string()], Some("col1".to_string())), + (vec!["col2".to_string()], Some("col2".to_string())), + (vec!["col3".to_string()], Some("col3".to_string())), + ], + Some(vec!["id".to_string()]), + true, // include nulls + )?; + + let results_with_nulls = unpivoted_with_nulls + .sort(vec![ + col("id").sort(true, true), + col("column").sort(true, true), + ])? 
+ .collect() + .await?; + + assert_snapshot!( + batches_to_sort_string(&results_with_nulls), + @r###" + +----+--------+-------+ + | id | column | value | + +----+--------+-------+ + | 1 | col1 | 10 | + | 1 | col2 | 20 | + | 1 | col3 | | + | 2 | col1 | | + | 2 | col2 | 25 | + | 2 | col3 | 35 | + | 3 | col1 | 30 | + | 3 | col2 | | + | 3 | col3 | 40 | + +----+--------+-------+ + "### + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_unpivot_multiple_values() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("q1_sales", DataType::Int32, true), + Field::new("q1_profit", DataType::Int32, true), + Field::new("q2_sales", DataType::Int32, true), + Field::new("q2_profit", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Int32Array::from(vec![100, 200])), + Arc::new(Int32Array::from(vec![10, 20])), + Arc::new(Int32Array::from(vec![150, 250])), + Arc::new(Int32Array::from(vec![15, 25])), + ], + )?; + + ctx.register_batch("unpivot_multi_test", batch)?; + let df = ctx.table("unpivot_multi_test").await?; + + // Unpivot with multiple value columns + let unpivoted = df.unpivot( + vec!["sales".to_string(), "profit".to_string()], + "quarter".to_string(), + vec![ + ( + vec!["q1_sales".to_string(), "q1_profit".to_string()], + Some("Q1".to_string()), + ), + ( + vec!["q2_sales".to_string(), "q2_profit".to_string()], + Some("Q2".to_string()), + ), + ], + Some(vec!["id".to_string()]), + false, + )?; + + let results = unpivoted + .sort(vec![ + col("id").sort(true, true), + col("quarter").sort(true, true), + ])? + .collect() + .await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +----+---------+-------+--------+ + | id | quarter | sales | profit | + +----+---------+-------+--------+ + | 1 | Q1 | 100 | 10 | + | 1 | Q2 | 150 | 15 | + | 2 | Q1 | 200 | 20 | + | 2 | Q2 | 250 | 25 | + +----+---------+-------+--------+ + "### + ); + + Ok(()) +} diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4beba715d92a..9a9befbff53d 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -47,8 +47,8 @@ use crate::utils::{ }; use crate::{ and, binary_expr, lit, when, DmlStatement, ExplainOption, Expr, ExprSchemable, - Operator, RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, - WriteOp, + Operator, RecursiveQuery, ScalarUDF, Statement, TableProviderFilterPushDown, + TableSource, WriteOp, }; use super::dml::InsertOp; @@ -1589,14 +1589,23 @@ impl LogicalPlanBuilder { for (i, agg_func) in aggregate_functions.iter().enumerate() { - let Expr::Alias(Alias { - expr, - name, - metadata, - .. - }) = agg_func - else { - return plan_err!("Aggregate function must have an alias"); + let (expr, name, metadata) = match agg_func { + Expr::Alias(Alias { + expr, + name, + metadata, + .. 
+ }) if matches!(expr.as_ref(), Expr::AggregateFunction(_)) => { + (expr.as_ref(), name, metadata) + } + Expr::AggregateFunction(_) => { + (agg_func, &agg_func.to_string(), &None) + } + _ => { + return plan_err!( + "Pivot aggregate function must be either an alias or an aggregate function expression, but got: {agg_func:?}" + ); + } }; let expr = expr .clone() .transform(|nested_expr| match &nested_expr { @@ -1637,6 +1646,154 @@ impl LogicalPlanBuilder { Ok(aggregate_plan) } + + #[allow(clippy::too_many_arguments)] + pub fn unpivot( + self, + value_column_names: Vec<String>, + name_column: String, + unpivot_columns: Vec<(Vec<String>, Option<String>)>, + id_columns: Option<Vec<String>>, + include_nulls: bool, + named_struct_fn: &Arc<ScalarUDF>, + make_array_fn: &Arc<ScalarUDF>, + get_field_fn: &Arc<ScalarUDF>, + ) -> Result<Self> { + let schema = self.schema(); + let num_value_columns = value_column_names.len(); + + // Validate that all unpivot columns have the same number of columns + for (cols, _) in &unpivot_columns { + if cols.len() != num_value_columns { + return plan_err!( + "All unpivot columns must have {} column(s), but found {}", + num_value_columns, + cols.len() + ); + } + } + + // Get the list of columns that should be preserved (not unpivoted) + let unpivot_col_set: HashSet<String> = unpivot_columns + .iter() + .flat_map(|(cols, _)| cols.iter().cloned()) + .collect(); + + let preserved_columns: Vec<Expr> = if let Some(id_columns) = id_columns { + id_columns + .iter() + .map(|col| Expr::Column(Column::from_name(col))) + .collect() + } else { + schema + .iter() + .filter_map(|(q, f)| { + if !unpivot_col_set.contains(f.name()) { + Some(Expr::Column(Column::new(q.cloned(), f.name()))) + } else { + None + } + }) + .collect() + }; + + // Build array of structs: array[struct(name_val, col1_val, col2_val, ...), ...] + let mut struct_exprs = Vec::new(); + + for (col_names, alias_opt) in unpivot_columns { + // Build named_struct arguments as key/value pairs: [name_column, name_value, value1_name, value1_col, value2_name, value2_col, ...] + let mut struct_fields = Vec::new(); + + // Add name field + let name_value = alias_opt.unwrap_or_else(|| col_names[0].clone()); + struct_fields.push(lit(name_column.clone())); + struct_fields.push(lit(name_value)); + + // Add value fields + for (i, col_name) in col_names.iter().enumerate() { + struct_fields.push(lit(value_column_names[i].clone())); + struct_fields.push(Expr::Column(Column::from_qualified_name(col_name))); + } + + // Create struct expression + let struct_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::clone(named_struct_fn), + struct_fields, + )); + + struct_exprs.push(struct_expr); + } + + let unpivot_array_column = "__unpivot_array"; + + // Create array expression + let array_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::clone(make_array_fn), + struct_exprs, + )) + .alias(unpivot_array_column); + + // Project: preserved_columns + array + let mut projection = preserved_columns.clone(); + projection.push(array_expr); + + let plan = self.project(projection)?.build()?; + + // Unnest the array + let plan = LogicalPlanBuilder::from(plan) + .unnest_column(Column::from_name(unpivot_array_column))? 
+ .build()?; + + // Extract fields from the unnested struct + let mut final_projection = preserved_columns.clone(); + final_projection.push( + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::clone(get_field_fn), + vec![ + Expr::Column(Column::from_name(unpivot_array_column)), + lit(name_column.clone()), + ], + )) + .alias(&name_column), + ); + + // Add value columns from struct fields + for value_col_name in &value_column_names { + final_projection.push( + Expr::ScalarFunction(ScalarFunction::new_udf( + Arc::clone(get_field_fn), + vec![ + Expr::Column(Column::from_name(unpivot_array_column)), + lit(value_col_name.clone()), + ], + )) + .alias(value_col_name), + ); + } + + let mut plan_builder = + LogicalPlanBuilder::from(plan).project(final_projection)?; + + // Add filter to exclude NULLs if needed + if !include_nulls { + // Create a condition that checks if any of the value columns is NOT NULL + let mut not_null_condition: Option<Expr> = None; + for value_col_name in &value_column_names { + let is_not_null = + Expr::Column(Column::from_name(value_col_name)).is_not_null(); + not_null_condition = match not_null_condition { + None => Some(is_not_null), + Some(prev) => Some(prev.or(is_not_null)), + }; + } + + if let Some(condition) = not_null_condition { + plan_builder = plan_builder.filter(condition)?; + } + } + + Ok(plan_builder) + } } impl From<LogicalPlan> for LogicalPlanBuilder { diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 0c4f3f861682..2a68bd9011fd 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -258,7 +258,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> { )) }) .collect::<Result<Vec<_>>>(), - _ => return plan_err!("Pivot default value cannot be NULL"), + _ => plan_err!("Pivot default value cannot be NULL"), } }) .transpose()?; @@ -273,6 +273,122 @@ impl<S: ContextProvider> SqlToRel<'_, S> { .build()?; (plan, alias) } + TableFactor::Unpivot { + table, + value, + name, + columns, + null_inclusion, + alias, + } => { + let plan = + self.create_relation(table.as_ref().clone(), planner_context)?; + + // Parse value expression(s) + let value_columns = match value { + SqlExpr::Tuple(exprs) => exprs, + single_expr => vec![single_expr], + }; + let value_column_names: Vec<String> = value_columns + .into_iter() + .map(|expr| match expr { + SqlExpr::Identifier(id) => Ok(id.value), + _ => plan_err!("Expected identifier in UNPIVOT value clause"), + }) + .collect::<Result<Vec<_>>>()?; + + // Parse name column + let name_column = name.value; + + // Parse columns to unpivot + let unpivot_columns: Vec<(Vec<String>, Option<String>)> = columns + .into_iter() + .map(|col_with_alias| { + let column_names = match &col_with_alias.expr { + SqlExpr::Tuple(exprs) => exprs + .iter() + .map(|e| match e { + SqlExpr::Identifier(id) => Ok(id.value.clone()), + SqlExpr::CompoundIdentifier(ids) => Ok(ids + .iter() + .map(|i| i.value.as_str()) + .collect::<Vec<_>>() + .join(".")), + _ => plan_err!( + "Expected identifier in UNPIVOT IN clause" + ), + }) + .collect::<Result<Vec<_>>>()?, + SqlExpr::Identifier(id) => vec![id.value.clone()], + SqlExpr::CompoundIdentifier(ids) => { + vec![ids + .iter() + .map(|i| i.value.as_str()) + .collect::<Vec<_>>() + .join(".")] + } + _ => { + return plan_err!( + "Expected identifier or tuple in UNPIVOT IN clause" + ) + } + }; + + let alias = col_with_alias.alias.map(|alias| alias.value); + Ok((column_names, alias)) + }) + .collect::<Result<Vec<_>>>()?; + + // Determine if nulls should be included (default is EXCLUDE) + let include_nulls = matches!( + null_inclusion, + Some(sqlparser::ast::NullInclusion::IncludeNulls) + ); + + // Get named_struct 
and make_array functions from context + let named_struct_fn = self + .context_provider + .get_function_meta("named_struct") + .ok_or_else(|| { + datafusion_common::DataFusionError::Plan( + "named_struct function not found".to_string(), + ) + })?; + + let make_array_fn = self + .context_provider + .get_function_meta("make_array") + .ok_or_else(|| { + datafusion_common::DataFusionError::Plan( + "make_array function not found".to_string(), + ) + })?; + + let get_field_fn = self + .context_provider + .get_function_meta("get_field") + .ok_or_else(|| { + datafusion_common::DataFusionError::Plan( + "get_field function not found".to_string(), + ) + })?; + + // Call the unpivot function from LogicalPlanBuilder + let plan = LogicalPlanBuilder::from(plan) + .unpivot( + value_column_names, + name_column, + unpivot_columns, + None, // id_columns - will be inferred from schema + include_nulls, + &named_struct_fn, + &make_array_fn, + &get_field_fn, + )? + .build()?; + + (plan, alias) + } // @todo Support TableFactory::TableFunction? _ => { return not_impl_err!( diff --git a/datafusion/sqllogictest/test_files/unpivot.slt b/datafusion/sqllogictest/test_files/unpivot.slt new file mode 100644 index 000000000000..2d090517643e --- /dev/null +++ b/datafusion/sqllogictest/test_files/unpivot.slt @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# UNPIVOT tests + +# Setup test table +statement ok +CREATE TABLE sales_quarterly (year INT, q1 INT, q2 INT, q3 INT, q4 INT) AS VALUES + (2020, null, 1000, 2000, 2500), + (2021, 2250, 3200, 4200, 5900), + (2022, 4200, 3100, null, null); + +# Test basic UNPIVOT (NULL values are excluded by default) +query ITI rowsort +SELECT * FROM sales_quarterly + UNPIVOT ( + sales FOR quarter IN (q1, q2, q3, q4) + ); +---- +2020 q2 1000 +2020 q3 2000 +2020 q4 2500 +2021 q1 2250 +2021 q2 3200 +2021 q3 4200 +2021 q4 5900 +2022 q1 4200 +2022 q2 3100 + +# Test UNPIVOT with INCLUDE NULLS +query ITI rowsort +SELECT up.* FROM sales_quarterly + UNPIVOT INCLUDE NULLS ( + sales FOR quarter IN (q1 AS Q1, q2 AS Q2, q3 AS Q3, q4 AS Q4) + ) AS up; +---- +2020 Q1 NULL +2020 Q2 1000 +2020 Q3 2000 +2020 Q4 2500 +2021 Q1 2250 +2021 Q2 3200 +2021 Q3 4200 +2021 Q4 5900 +2022 Q1 4200 +2022 Q2 3100 +2022 Q3 NULL +2022 Q4 NULL + +# Test UNPIVOT with multiple value columns +query ITII rowsort +SELECT * FROM sales_quarterly + UNPIVOT EXCLUDE NULLS ( + (first_quarter, second_quarter) + FOR half_of_the_year IN ( + (q1, q2) AS H1, + (q3, q4) AS H2 + ) + ); +---- +2020 H1 NULL 1000 +2020 H2 2000 2500 +2021 H1 2250 3200 +2021 H2 4200 5900 +2022 H1 4200 3100 + +# Cleanup +statement ok +DROP TABLE sales_quarterly; From 1f471f7cc3aba0c0292e2b01d7c526918a489c3f Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Tue, 7 Oct 2025 08:14:12 +0800 Subject: [PATCH 6/6] update doc --- datafusion/core/src/dataframe/mod.rs | 36 +++++++++++++++------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 0e1de04b5825..d6f29c99fb00 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -2380,24 +2380,16 @@ impl DataFrame { /// # Example /// ``` /// # use datafusion::prelude::*; - /// # use datafusion::arrow::array::{Int32Array, StringArray}; + /// # use arrow::array::{ArrayRef, Int32Array, StringArray}; + /// # use datafusion::functions_aggregate::expr_fn::sum; /// # use std::sync::Arc; /// # let ctx = SessionContext::new(); - /// # let batch = RecordBatch::try_new( - /// # Arc::new(Schema::new(vec![ - /// # Field::new("category", DataType::Utf8, false), - /// # Field::new("value", DataType::Int32, false), - /// # ])), - /// # vec![ - /// # Arc::new(StringArray::from(vec!["A", "B", "A"])), - /// # Arc::new(Int32Array::from(vec![1, 2, 3])) - /// # ] - /// # ).unwrap(); - /// # let df = ctx.read_batch(batch).unwrap(); - /// // Pivot the DataFrame so each unique category becomes a column + /// let value: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); + /// let category: ArrayRef = Arc::new(StringArray::from(vec!["A", "B", "A"])); + /// let df = DataFrame::from_columns(vec![("value", value), ("category", category)]).unwrap(); /// let pivoted = df.pivot( /// vec![sum(col("value"))], - /// vec![col("category")], + /// vec![Column::from("category")], /// vec![col("value")], /// None /// ).unwrap(); @@ -2435,13 +2427,25 @@ impl DataFrame { /// /// # Example /// ``` + /// # use std::sync::Arc; + /// # use arrow::array::{ArrayRef, Int32Array}; /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// // Assume we have a DataFrame with columns: id, jan, feb, mar - /// // We want to unpivot jan, feb, mar into month/value columns + /// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + /// let 
jan: ArrayRef = Arc::new(Int32Array::from(vec![100, 110])); + /// let feb: ArrayRef = Arc::new(Int32Array::from(vec![200, 210])); + /// let mar: ArrayRef = Arc::new(Int32Array::from(vec![300, 310])); + /// let df = DataFrame::from_columns(vec![("id", id), ("jan", jan), ("feb", feb), ("mar", mar)]).unwrap(); + /// // Fold jan/feb/mar into a "month" name column and a single "value" column + /// let unpivoted = df.unpivot( + /// vec!["value".to_string()], + /// "month".to_string(), + /// vec![ + /// (vec!["jan".to_string()], None), + /// (vec!["feb".to_string()], None), + /// (vec!["mar".to_string()], None), + /// ], + /// None, + /// false + /// ).unwrap(); /// # Ok(()) /// # } /// ```
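For reviewers, a rough hand-expansion of what the multi-column FOR case plans to, using the person table from pivot.slt. Each tuple in the IN list arrives at the builder as a struct(...) literal, and LogicalPlanBuilder::pivot lowers it to an IS NOT DISTINCT FROM conjunction attached to each aggregate's FILTER clause; output names follow the "{value alias}_{aggregate alias}" scheme in builder.rs:

SELECT * FROM person
  PIVOT (SUM(age) AS a, AVG(class) AS c
    FOR (name, age) IN (('John', 30) AS c1, ('Mike', 40) AS c2));

-- is planned roughly as:

SELECT id, address,
  sum(age)   FILTER (WHERE name IS NOT DISTINCT FROM 'John' AND age IS NOT DISTINCT FROM 30) AS c1_a,
  avg(class) FILTER (WHERE name IS NOT DISTINCT FROM 'John' AND age IS NOT DISTINCT FROM 30) AS c1_c,
  sum(age)   FILTER (WHERE name IS NOT DISTINCT FROM 'Mike' AND age IS NOT DISTINCT FROM 40) AS c2_a,
  avg(class) FILTER (WHERE name IS NOT DISTINCT FROM 'Mike' AND age IS NOT DISTINCT FROM 40) AS c2_c
FROM person
GROUP BY id, address;

UNPIVOT takes the reverse route in this series: each IN entry is packed with named_struct, the entries are collected with make_array, unnested, and unpacked with get_field, with an IS NOT NULL filter appended unless INCLUDE NULLS is given.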