diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 287a133273d8..d6f29c99fb00 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2368,6 +2368,139 @@ impl DataFrame {
         let df = ctx.read_batch(batch)?;
         Ok(df)
     }
+
+    /// Pivot the DataFrame, transforming rows into columns based on the
+    /// specified pivot values and aggregation functions.
+    ///
+    /// # Arguments
+    /// * `aggregate_functions` - Aggregation expressions to apply (e.g., sum, count).
+    /// * `value_column` - Columns whose values are compared against the pivot values.
+    /// * `value_source` - Pivot values; each one produces a new output column per aggregate function.
+    /// * `default_on_null` - Optional expressions to use as default values when a pivoted value is null.
+    ///
+    /// # Example
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use arrow::array::{ArrayRef, Int32Array, StringArray};
+    /// # use datafusion::common::Column;
+    /// # use datafusion::functions_aggregate::expr_fn::sum;
+    /// # use std::sync::Arc;
+    /// # let ctx = SessionContext::new();
+    /// let value: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
+    /// let category: ArrayRef = Arc::new(StringArray::from(vec!["A", "B", "A"]));
+    /// let df = DataFrame::from_columns(vec![("value", value), ("category", category)]).unwrap();
+    /// let pivoted = df.pivot(
+    ///     vec![sum(col("value"))],
+    ///     vec![Column::from_name("category")],
+    ///     vec![lit("A"), lit("B")],
+    ///     None
+    /// ).unwrap();
+    /// ```
+    pub fn pivot(
+        self,
+        aggregate_functions: Vec<Expr>,
+        value_column: Vec<Column>,
+        value_source: Vec<Expr>,
+        default_on_null: Option<Vec<Expr>>,
+    ) -> Result<DataFrame> {
+        let plan = LogicalPlanBuilder::from(self.plan)
+            .pivot(
+                aggregate_functions,
+                value_column,
+                value_source,
+                default_on_null,
+            )?
+            .build()?;
+        Ok(DataFrame {
+            session_state: self.session_state,
+            plan,
+            projection_requires_validation: self.projection_requires_validation,
+        })
+    }
+
+    /// Unpivot the DataFrame, transforming columns into rows.
+    ///
+    /// # Arguments
+    /// * `value_column_names` - Names for the value columns in the output
+    /// * `name_column` - Name for the column that will contain the original column names
+    /// * `unpivot_columns` - List of (column_names, optional_alias) tuples to unpivot
+    /// * `id_columns` - Optional list of columns to preserve (if None, all non-unpivoted columns are preserved)
+    /// * `include_nulls` - Whether to include rows with NULL values (default: false excludes NULLs)
+    ///
+    /// # Example
+    /// ```
+    /// # use std::sync::Arc;
+    /// # use arrow::array::{ArrayRef, Int32Array};
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
+    /// let jan: ArrayRef = Arc::new(Int32Array::from(vec![100, 110]));
+    /// let feb: ArrayRef = Arc::new(Int32Array::from(vec![200, 210]));
+    /// let mar: ArrayRef = Arc::new(Int32Array::from(vec![300, 310]));
+    /// let df = DataFrame::from_columns(vec![("id", id), ("jan", jan), ("feb", feb), ("mar", mar)]).unwrap();
+    /// let unpivoted = df.unpivot(
+    ///     vec!["value".to_string()],
+    ///     "month".to_string(),
+    ///     vec![
+    ///         (vec!["jan".to_string()], None),
+    ///         (vec!["feb".to_string()], None),
+    ///         (vec!["mar".to_string()], None),
+    ///     ],
+    ///     None,
+    ///     false
+    /// ).unwrap();
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn unpivot(
+        self,
+        value_column_names: Vec<String>,
+        name_column: String,
+        unpivot_columns: Vec<(Vec<String>, Option<String>)>,
+        id_columns: Option<Vec<String>>,
+        include_nulls: bool,
+    ) -> Result<DataFrame> {
+        // Get required UDF functions from the session state
+        let named_struct_fn = self
+            .session_state
+            .scalar_functions()
+            .get("named_struct")
+            .ok_or_else(|| {
+                DataFusionError::Plan("named_struct function not found".to_string())
+            })?;
+
+        let make_array_fn = self
+            .session_state
+            .scalar_functions()
+            .get("make_array")
+            .ok_or_else(|| {
+                DataFusionError::Plan("make_array function not found".to_string())
+            })?;
+
+        let get_field_fn = self
+            .session_state
+            .scalar_functions()
+            .get("get_field")
+            .ok_or_else(|| {
+                DataFusionError::Plan("get_field function not found".to_string())
+            })?;
+
+        let plan = LogicalPlanBuilder::from(self.plan)
+            .unpivot(
+                value_column_names,
+                name_column,
+                unpivot_columns,
+                id_columns,
+                include_nulls,
+                named_struct_fn,
+                make_array_fn,
+                get_field_fn,
+            )?
+            .build()?;
+
+        Ok(DataFrame {
+            session_state: self.session_state,
+            plan,
+            projection_requires_validation: true,
+        })
+    }
 }
 
 /// Macro for creating DataFrame.
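The two methods are designed to round-trip: `unpivot` melts a wide table into name/value rows, and `pivot` widens it back out. Below is a minimal end-to-end sketch (not part of the patch) of how they compose; it assumes this branch, where `DataFrame::from_columns`, `DataFrame::pivot`, and `DataFrame::unpivot` are all available:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use datafusion::common::Column;
use datafusion::error::Result;
use datafusion::functions_aggregate::expr_fn::sum;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    // A wide table: one column per month.
    let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    let jan: ArrayRef = Arc::new(Int32Array::from(vec![100, 110]));
    let feb: ArrayRef = Arc::new(Int32Array::from(vec![200, 210]));
    let df = DataFrame::from_columns(vec![("id", id), ("jan", jan), ("feb", feb)])?;

    // Wide -> long: one (id, month, value) row per month column.
    let long = df.unpivot(
        vec!["value".to_string()],
        "month".to_string(),
        vec![
            (vec!["jan".to_string()], None),
            (vec!["feb".to_string()], None),
        ],
        None,  // preserve every non-unpivoted column (here just `id`)
        false, // drop rows whose value is NULL
    )?;

    // Long -> wide again: one summed column per pivot value.
    let wide = long.pivot(
        vec![sum(col("value"))],
        vec![Column::from_name("month")],
        vec![lit("jan"), lit("feb")],
        None,
    )?;
    wide.show().await?;
    Ok(())
}
```

Per the naming scheme in `pivot`, the widened columns come out as `Utf8(jan)_sum(value)` and `Utf8(feb)_sum(value)`, matching the tests below.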
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index aa538f6dee81..3330093ae451 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -6479,3 +6479,333 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_dataframe_pivot() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    // Create a test table for pivot
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("category", DataType::Utf8, false),
+        Field::new("value", DataType::Int32, false),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(StringArray::from(vec!["A", "B", "A", "B", "C"])),
+            Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
+        ],
+    )?;
+
+    ctx.register_batch("pivot_test", batch)?;
+    let df = ctx.table("pivot_test").await?;
+
+    // Pivot the DataFrame so each unique category becomes a column
+    let pivoted = df.pivot(
+        vec![sum(col("value"))],
+        vec![datafusion_common::Column::from_name("category")],
+        vec![lit("A"), lit("B"), lit("C")],
+        None,
+    )?;
+
+    let results = pivoted.collect().await?;
+
+    assert_snapshot!(
+        batches_to_sort_string(&results),
+        @r###"
+    +--------------------+--------------------+--------------------+
+    | Utf8(A)_sum(value) | Utf8(B)_sum(value) | Utf8(C)_sum(value) |
+    +--------------------+--------------------+--------------------+
+    | 4                  | 6                  | 5                  |
+    +--------------------+--------------------+--------------------+
+    "###
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_dataframe_pivot_with_default() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("category", DataType::Utf8, false),
+        Field::new("value", DataType::Int32, false),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 2, 2])),
+            Arc::new(StringArray::from(vec!["A", "B", "A", "C"])),
+            Arc::new(Int32Array::from(vec![10, 20, 30, 40])),
+        ],
+    )?;
+
+    ctx.register_batch("pivot_default_test", batch)?;
+    let df = ctx.table("pivot_default_test").await?;
+
+    // Pivot with default values
+    let pivoted = df.pivot(
+        vec![sum(col("value"))],
+        vec![datafusion_common::Column::from_name("category")],
+        vec![lit("A"), lit("B"), lit("C")],
+        Some(vec![lit(0)]),
+    )?;
+
+    let results = pivoted.collect().await?;
+
+    assert_snapshot!(
+        batches_to_sort_string(&results),
+        @r###"
+    +----+--------------------+--------------------+--------------------+
+    | id | Utf8(A)_sum(value) | Utf8(B)_sum(value) | Utf8(C)_sum(value) |
+    +----+--------------------+--------------------+--------------------+
+    | 1  | 10                 | 20                 | 0                  |
+    | 2  | 30                 | 0                  | 40                 |
+    +----+--------------------+--------------------+--------------------+
+    "###
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_dataframe_unpivot() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    // Create a test table for unpivot
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("jan", DataType::Int32, true),
+        Field::new("feb", DataType::Int32, true),
+        Field::new("mar", DataType::Int32, true),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2])),
+            Arc::new(Int32Array::from(vec![100, 110])),
+            Arc::new(Int32Array::from(vec![200, 210])),
+            Arc::new(Int32Array::from(vec![300, 310])),
+        ],
+    )?;
+
+    ctx.register_batch("unpivot_test", batch)?;
+    let df = ctx.table("unpivot_test").await?;
ctx.table("unpivot_test").await?; + + // Unpivot jan, feb, mar into month/value columns + let unpivoted = df.unpivot( + vec!["value".to_string()], + "month".to_string(), + vec![ + (vec!["jan".to_string()], Some("jan".to_string())), + (vec!["feb".to_string()], Some("feb".to_string())), + (vec!["mar".to_string()], Some("mar".to_string())), + ], + Some(vec!["id".to_string()]), + false, + )?; + + let results = unpivoted + .sort(vec![ + col("id").sort(true, true), + col("month").sort(true, true), + ])? + .collect() + .await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +----+-------+-------+ + | id | month | value | + +----+-------+-------+ + | 1 | feb | 200 | + | 1 | jan | 100 | + | 1 | mar | 300 | + | 2 | feb | 210 | + | 2 | jan | 110 | + | 2 | mar | 310 | + +----+-------+-------+ + "### + ); + + Ok(()) +} + +#[tokio::test] +async fn test_dataframe_unpivot_with_nulls() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("col1", DataType::Int32, true), + Field::new("col2", DataType::Int32, true), + Field::new("col3", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![Some(10), None, Some(30)])), + Arc::new(Int32Array::from(vec![Some(20), Some(25), None])), + Arc::new(Int32Array::from(vec![None, Some(35), Some(40)])), + ], + )?; + + ctx.register_batch("unpivot_nulls_test", batch)?; + let df = ctx.table("unpivot_nulls_test").await?; + + // Test with include_nulls = false (default behavior) + let unpivoted = df.clone().unpivot( + vec!["value".to_string()], + "column".to_string(), + vec![ + (vec!["col1".to_string()], Some("col1".to_string())), + (vec!["col2".to_string()], Some("col2".to_string())), + (vec!["col3".to_string()], Some("col3".to_string())), + ], + Some(vec!["id".to_string()]), + false, // exclude nulls + )?; + + let results = unpivoted + .sort(vec![ + col("id").sort(true, true), + col("column").sort(true, true), + ])? + .collect() + .await?; + + assert_snapshot!( + batches_to_sort_string(&results), + @r###" + +----+--------+-------+ + | id | column | value | + +----+--------+-------+ + | 1 | col1 | 10 | + | 1 | col2 | 20 | + | 2 | col2 | 25 | + | 2 | col3 | 35 | + | 3 | col1 | 30 | + | 3 | col3 | 40 | + +----+--------+-------+ + "### + ); + + // Test with include_nulls = true + let unpivoted_with_nulls = df.unpivot( + vec!["value".to_string()], + "column".to_string(), + vec![ + (vec!["col1".to_string()], Some("col1".to_string())), + (vec!["col2".to_string()], Some("col2".to_string())), + (vec!["col3".to_string()], Some("col3".to_string())), + ], + Some(vec!["id".to_string()]), + true, // include nulls + )?; + + let results_with_nulls = unpivoted_with_nulls + .sort(vec![ + col("id").sort(true, true), + col("column").sort(true, true), + ])? 
+        .collect()
+        .await?;
+
+    assert_snapshot!(
+        batches_to_sort_string(&results_with_nulls),
+        @r###"
+    +----+--------+-------+
+    | id | column | value |
+    +----+--------+-------+
+    | 1  | col1   | 10    |
+    | 1  | col2   | 20    |
+    | 1  | col3   |       |
+    | 2  | col1   |       |
+    | 2  | col2   | 25    |
+    | 2  | col3   | 35    |
+    | 3  | col1   | 30    |
+    | 3  | col2   |       |
+    | 3  | col3   | 40    |
+    +----+--------+-------+
+    "###
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_dataframe_unpivot_multiple_values() -> Result<()> {
+    let ctx = SessionContext::new();
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("q1_sales", DataType::Int32, true),
+        Field::new("q1_profit", DataType::Int32, true),
+        Field::new("q2_sales", DataType::Int32, true),
+        Field::new("q2_profit", DataType::Int32, true),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2])),
+            Arc::new(Int32Array::from(vec![100, 200])),
+            Arc::new(Int32Array::from(vec![10, 20])),
+            Arc::new(Int32Array::from(vec![150, 250])),
+            Arc::new(Int32Array::from(vec![15, 25])),
+        ],
+    )?;
+
+    ctx.register_batch("unpivot_multi_test", batch)?;
+    let df = ctx.table("unpivot_multi_test").await?;
+
+    // Unpivot with multiple value columns
+    let unpivoted = df.unpivot(
+        vec!["sales".to_string(), "profit".to_string()],
+        "quarter".to_string(),
+        vec![
+            (
+                vec!["q1_sales".to_string(), "q1_profit".to_string()],
+                Some("Q1".to_string()),
+            ),
+            (
+                vec!["q2_sales".to_string(), "q2_profit".to_string()],
+                Some("Q2".to_string()),
+            ),
+        ],
+        Some(vec!["id".to_string()]),
+        false,
+    )?;
+
+    let results = unpivoted
+        .sort(vec![
+            col("id").sort(true, true),
+            col("quarter").sort(true, true),
+        ])?
+        .collect()
+        .await?;
+
+    assert_snapshot!(
+        batches_to_sort_string(&results),
+        @r###"
+    +----+---------+-------+--------+
+    | id | quarter | sales | profit |
+    +----+---------+-------+--------+
+    | 1  | Q1      | 100   | 10     |
+    | 1  | Q2      | 150   | 15     |
+    | 2  | Q1      | 200   | 20     |
+    | 2  | Q2      | 250   | 25     |
+    +----+---------+-------+--------+
+    "###
+    );
+
+    Ok(())
+}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 7a283b0420d3..6642ad15c67d 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -25,7 +25,9 @@ use std::iter::once;
 use std::sync::Arc;
 
 use crate::dml::CopyTo;
-use crate::expr::{Alias, FieldMetadata, PlannedReplaceSelectItem, Sort as SortExpr};
+use crate::expr::{
+    Alias, FieldMetadata, PlannedReplaceSelectItem, ScalarFunction, Sort as SortExpr,
+};
 use crate::expr_rewriter::{
     coerce_plan_expr_for_schema, normalize_col,
     normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts,
@@ -44,8 +46,9 @@ use crate::utils::{
     group_window_expr_by_sort_keys,
 };
 use crate::{
-    and, binary_expr, lit, DmlStatement, ExplainOption, Expr, ExprSchemable, Operator,
-    RecursiveQuery, Statement, TableProviderFilterPushDown, TableSource, WriteOp,
+    and, binary_expr, lit, when, DmlStatement, ExplainOption, Expr, ExprSchemable,
+    Operator, RecursiveQuery, ScalarUDF, Statement, TableProviderFilterPushDown,
+    TableSource, WriteOp,
 };
 
 use super::dml::InsertOp;
@@ -53,6 +56,7 @@ use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
 use datafusion_common::display::ToStringifiedPlan;
 use datafusion_common::file_options::file_type::FileType;
+use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{
     exec_err,
    get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
     plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
@@ -1493,6 +1497,303 @@ impl LogicalPlanBuilder {
         unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options)
             .map(Self::new)
     }
+
+    /// Pivot rows into columns.
+    ///
+    /// Each value in `value_source` produces one output column per aggregate
+    /// function, computed as that aggregate filtered to rows where
+    /// `value_column` matches the value. All remaining columns become
+    /// group-by keys.
+    pub fn pivot(
+        self,
+        aggregate_functions: Vec<Expr>,
+        value_column: Vec<Column>,
+        value_source: Vec<Expr>,
+        default_on_null: Option<Vec<Expr>>,
+    ) -> Result<Self> {
+        match default_on_null {
+            Some(ref default_values) if default_values.len() != aggregate_functions.len() => {
+                return plan_err!("Number of default values must match the number of aggregate functions");
+            }
+            _ => {}
+        }
+        let mut used_columns = HashSet::new();
+        used_columns.extend(value_column.iter().cloned());
+        for agg in aggregate_functions.iter() {
+            expr_to_columns(agg, &mut used_columns)?;
+        }
+
+        let used_columns = used_columns
+            .iter()
+            .map(|c| c.name.clone())
+            .collect::<HashSet<_>>();
+
+        // Extract group by columns (all columns not involved in aggregation or pivot)
+        let schema = self.schema();
+
+        let group_by_columns = schema
+            .fields()
+            .iter()
+            .filter_map(|f| {
+                if used_columns.contains(f.name()) {
+                    None // Skip columns that are used in aggregation or pivot
+                } else {
+                    Some(Expr::Column(Column::from_name(f.name())))
+                }
+            })
+            .collect::<Vec<_>>();
+
+        // Create filtered aggregate expressions for each value in value_source
+        let mut aggr_exprs = Vec::new();
+
+        for value in &value_source {
+            let (value, value_alias) =
+                if let Expr::Alias(Alias { expr, name, .. }) = value {
+                    (expr.as_ref(), name.clone())
+                } else {
+                    (value, value.to_string())
+                };
+            let condition = match value_column.len() {
+                0 => return plan_err!("Pivot requires at least one value column"),
+                1 => binary_expr(
+                    Expr::Column(value_column[0].clone()),
+                    Operator::IsNotDistinctFrom,
+                    value.clone(),
+                ),
+                _ => {
+                    let Expr::ScalarFunction(ScalarFunction { func, args }) = value
+                    else {
+                        return plan_err!("Pivot value must be struct(literals) if multiple value columns are provided");
+                    };
+                    if func.name() != "struct" {
+                        return plan_err!("Pivot value must be struct(literals) if multiple value columns are provided");
+                    }
+                    if args.len() != value_column.len() {
+                        return plan_err!(
+                            "Pivot value list length must match value column count"
+                        );
+                    }
+                    let mut condition: Option<Expr> = None;
+                    for (idx, col) in value_column.iter().enumerate() {
+                        let single_condition = binary_expr(
+                            Expr::Column(col.clone()),
+                            Operator::IsNotDistinctFrom,
+                            args[idx].clone(),
+                        );
+                        condition = match condition {
+                            None => Some(single_condition),
+                            Some(prev) => Some(and(prev, single_condition)),
+                        };
+                    }
+                    match condition {
+                        None => {
+                            return plan_err!("Pivot value condition cannot be empty")
+                        }
+                        Some(cond) => cond,
+                    }
+                }
+            };
+
+            for (i, agg_func) in aggregate_functions.iter().enumerate() {
+                let (expr, name, metadata) = match agg_func {
+                    Expr::Alias(Alias {
+                        expr,
+                        name,
+                        metadata,
+                        ..
+                    }) if matches!(expr.as_ref(), Expr::AggregateFunction(_)) => {
+                        (expr.as_ref(), name.clone(), metadata.clone())
+                    }
+                    Expr::AggregateFunction(_) => {
+                        (agg_func, agg_func.to_string(), None)
+                    }
+                    _ => {
+                        return plan_err!(
+                            "Pivot aggregate function must be either an alias or an aggregate function expression, but got: {agg_func:?}"
+                        );
+                    }
+                };
+                let expr = expr
+                    .clone()
+                    .transform(|nested_expr| match &nested_expr {
+                        Expr::AggregateFunction(func) => {
+                            let filter = match &func.params.filter {
+                                Some(filter) => {
+                                    and(filter.as_ref().clone(), condition.clone())
+                                }
+                                None => condition.clone(),
+                            };
+                            let mut func = func.clone();
+                            func.params.filter = Some(Box::new(filter));
+                            Ok(Transformed::yes(Expr::AggregateFunction(func)))
+                        }
+                        _ => Ok(Transformed::no(nested_expr)),
+                    })?
+                    .data;
+
+                let expr = match default_on_null.as_ref() {
+                    Some(default_values) => {
+                        when(expr.clone().is_null(), default_values[i].clone())
+                            .otherwise(expr)?
+                    }
+                    None => expr,
+                };
+                let pivot_col_name = format!(
+                    "{}_{}",
+                    value_alias.replace("\"", "").replace("'", ""),
+                    name
+                );
+                aggr_exprs
+                    .push(expr.alias_with_metadata(pivot_col_name, metadata.clone()));
+            }
+        }
+
+        let aggregate_plan = self.aggregate(group_by_columns, aggr_exprs)?;
+
+        Ok(aggregate_plan)
+    }
+
+    /// Unpivot columns into rows.
+    ///
+    /// Rewrites the input as a projection of `named_struct` entries collected
+    /// into a `make_array`, unnests that array, and extracts the name/value
+    /// fields back out with `get_field`.
+    #[allow(clippy::too_many_arguments)]
+    pub fn unpivot(
+        self,
+        value_column_names: Vec<String>,
+        name_column: String,
+        unpivot_columns: Vec<(Vec<String>, Option<String>)>,
+        id_columns: Option<Vec<String>>,
+        include_nulls: bool,
+        named_struct_fn: &Arc<ScalarUDF>,
+        make_array_fn: &Arc<ScalarUDF>,
+        get_field_fn: &Arc<ScalarUDF>,
+    ) -> Result<Self> {
+        let schema = self.schema();
+        let num_value_columns = value_column_names.len();
+
+        // Validate that all unpivot columns have the same number of columns
+        for (cols, _) in &unpivot_columns {
+            if cols.len() != num_value_columns {
+                return plan_err!(
+                    "All unpivot columns must have {} column(s), but found {}",
+                    num_value_columns,
+                    cols.len()
+                );
+            }
+        }
+
+        // Get the list of columns that should be preserved (not unpivoted)
+        let unpivot_col_set: HashSet<String> = unpivot_columns
+            .iter()
+            .flat_map(|(cols, _)| cols.iter().cloned())
+            .collect();
+
+        let preserved_columns: Vec<Expr> = if let Some(id_columns) = id_columns {
+            id_columns
+                .iter()
+                .map(|col| Expr::Column(Column::from_name(col)))
+                .collect()
+        } else {
+            schema
+                .iter()
+                .filter_map(|(q, f)| {
+                    if !unpivot_col_set.contains(f.name()) {
+                        Some(Expr::Column(Column::new(q.cloned(), f.name())))
+                    } else {
+                        None
+                    }
+                })
+                .collect()
+        };
+
+        // Build array of structs: array[struct(name_val, col1_val, col2_val, ...), ...]
+        let mut struct_exprs = Vec::new();
+
+        for (col_names, alias_opt) in unpivot_columns {
+            // Build struct fields as (field name, field value) pairs:
+            // [name_column, name_value, value1_name, value1, value2_name, value2, ...]
+            let mut struct_fields = Vec::new();
+
+            // Add name field
+            let name_value = alias_opt.unwrap_or_else(|| col_names[0].clone());
+            struct_fields.push(lit(name_column.clone()));
+            struct_fields.push(lit(name_value));
+
+            // Add value fields
+            for (i, col_name) in col_names.iter().enumerate() {
+                struct_fields.push(lit(value_column_names[i].clone()));
+                struct_fields.push(Expr::Column(Column::from_qualified_name(col_name)));
+            }
+
+            // Create struct expression
+            let struct_expr = Expr::ScalarFunction(ScalarFunction::new_udf(
+                Arc::clone(named_struct_fn),
+                struct_fields,
+            ));
+
+            struct_exprs.push(struct_expr);
+        }
+
+        let unpivot_array_column = "__unpivot_array";
+
+        // Create array expression
+        let array_expr = Expr::ScalarFunction(ScalarFunction::new_udf(
+            Arc::clone(make_array_fn),
+            struct_exprs,
+        ))
+        .alias(unpivot_array_column);
+
+        // Project: preserved_columns + array
+        let mut projection = preserved_columns.clone();
+        projection.push(array_expr);
+
+        let plan = self.project(projection)?.build()?;
+
+        // Unnest the array
+        let plan = LogicalPlanBuilder::from(plan)
+            .unnest_column(Column::from_name(unpivot_array_column))?
+            .build()?;
+
+        // Extract fields from the unnested struct
+        let mut final_projection = preserved_columns.clone();
+        final_projection.push(
+            Expr::ScalarFunction(ScalarFunction::new_udf(
+                Arc::clone(get_field_fn),
+                vec![
+                    Expr::Column(Column::from_name(unpivot_array_column)),
+                    lit(name_column.clone()),
+                ],
+            ))
+            .alias(&name_column),
+        );
+
+        // Add value columns from struct fields
+        for value_col_name in &value_column_names {
+            final_projection.push(
+                Expr::ScalarFunction(ScalarFunction::new_udf(
+                    Arc::clone(get_field_fn),
+                    vec![
+                        Expr::Column(Column::from_name(unpivot_array_column)),
+                        lit(value_col_name.clone()),
+                    ],
+                ))
+                .alias(value_col_name),
+            );
+        }
+
+        let mut plan_builder =
+            LogicalPlanBuilder::from(plan).project(final_projection)?;
+
+        // Add filter to exclude NULLs if needed
+        if !include_nulls {
+            // Create a condition that checks if any of the value columns is NOT NULL
+            let mut not_null_condition: Option<Expr> = None;
+            for value_col_name in &value_column_names {
+                let is_not_null =
+                    Expr::Column(Column::from_name(value_col_name)).is_not_null();
+                not_null_condition = match not_null_condition {
+                    None => Some(is_not_null),
+                    Some(prev) => Some(prev.or(is_not_null)),
+                };
+            }
+
+            if let Some(condition) = not_null_condition {
+                plan_builder = plan_builder.filter(condition)?;
+            }
+        }
+
+        Ok(plan_builder)
+    }
 }
 
 impl From<LogicalPlan> for LogicalPlanBuilder {
diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs
index 3c57d195ade6..973710fd94c0 100644
--- a/datafusion/sql/src/expr/identifier.rs
+++ b/datafusion/sql/src/expr/identifier.rs
@@ -29,7 +29,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use datafusion_expr::UNNAMED_TABLE;
 
 impl<S: ContextProvider> SqlToRel<'_, S> {
-    pub(super) fn sql_identifier_to_expr(
+    pub(crate) fn sql_identifier_to_expr(
         &self,
         id: Ident,
         schema: &DFSchema,
diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs
index 9dfa078701d3..2a68bd9011fd 100644
--- a/datafusion/sql/src/relation/mod.rs
+++ b/datafusion/sql/src/relation/mod.rs
@@ -19,14 +19,18 @@ use std::sync::Arc;
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
+use arrow::array::Array;
 use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{
-    not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference,
+    not_impl_err, plan_err, DFSchema, Diagnostic, Result, ScalarValue, Span, Spans,
+    TableReference,
 };
 use datafusion_expr::builder::subquery_alias;
 use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder};
 use datafusion_expr::{Subquery, SubqueryAlias};
-use sqlparser::ast::{FunctionArg, FunctionArgExpr, Spanned, TableFactor};
+use sqlparser::ast::{
+    Expr as SqlExpr, FunctionArg, FunctionArgExpr, PivotValueSource, Spanned, TableFactor,
+};
 
 mod join;
 
@@ -183,6 +187,208 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                     .build()?;
                 (plan, alias)
             }
+            TableFactor::Pivot {
+                table,
+                aggregate_functions,
+                value_column,
+                value_source,
+                default_on_null,
+                alias,
+            } => {
+                let plan =
+                    self.create_relation(table.as_ref().clone(), planner_context)?;
+                let schema = plan.schema();
+                let aggregate_functions = aggregate_functions
+                    .into_iter()
+                    .map(|func| {
+                        self.sql_expr_to_logical_expr(func.expr, schema, planner_context)
+                            .map(|expr| match func.alias {
+                                Some(name) => expr.alias(name.value),
+                                None => expr,
+                            })
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let value_column = value_column.into_iter().map(|column| {
+                    let expr = match column {
+                        SqlExpr::Identifier(id) => self.sql_identifier_to_expr(id, schema, planner_context)?,
+                        SqlExpr::CompoundIdentifier(idents) => self.sql_compound_identifier_to_expr(idents, schema, planner_context)?,
+                        expr => return plan_err!(
+                            "Expected column identifier, found: {expr:?} in pivot value column"
+                        ),
+                    };
+                    match expr {
+                        Expr::Column(col) => Ok(col),
+                        expr => plan_err!(
+                            "Expected column identifier, found: {expr:?} in pivot value column"
+                        ),
+                    }
+                }).collect::<Result<Vec<_>>>()?;
+
+                let PivotValueSource::List(source) = value_source else {
+                    // Dynamic pivot: the output schema is determined by the data in the source table at runtime.
+                    return plan_err!("Dynamic pivot is not supported yet");
+                };
+                let value_source = source
+                    .into_iter()
+                    .map(|expr_with_alias| {
+                        self.sql_expr_to_logical_expr(
+                            expr_with_alias.expr,
+                            schema,
+                            planner_context,
+                        )
+                        .map(|expr| {
+                            match expr_with_alias.alias {
+                                Some(name) => expr.alias(name.value),
+                                None => expr,
+                            }
+                        })
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                let default_on_null = default_on_null
+                    .map(|expr| {
+                        let expr =
+                            self.sql_expr_to_logical_expr(expr, schema, planner_context)?;
+                        match expr {
+                            Expr::Literal(ScalarValue::List(list), _) => (0..list.len())
+                                .map(|idx| {
+                                    Ok(Expr::Literal(
+                                        ScalarValue::try_from_array(list.values(), idx)?,
+                                        None,
+                                    ))
+                                })
+                                .collect::<Result<Vec<_>>>(),
+                            _ => plan_err!(
+                                "PIVOT DEFAULT ON NULL requires a list of literal values"
+                            ),
+                        }
+                    })
+                    .transpose()?;
+
+                let plan = LogicalPlanBuilder::from(plan)
+                    .pivot(
+                        aggregate_functions,
+                        value_column,
+                        value_source,
+                        default_on_null,
+                    )?
+                    .build()?;
+                (plan, alias)
+            }
+            TableFactor::Unpivot {
+                table,
+                value,
+                name,
+                columns,
+                null_inclusion,
+                alias,
+            } => {
+                let plan =
+                    self.create_relation(table.as_ref().clone(), planner_context)?;
+
+                // Parse value expression(s)
+                let value_columns = match value {
+                    SqlExpr::Tuple(exprs) => exprs,
+                    single_expr => vec![single_expr],
+                };
+                let value_column_names: Vec<String> = value_columns
+                    .into_iter()
+                    .map(|expr| match expr {
+                        SqlExpr::Identifier(id) => Ok(id.value),
+                        _ => plan_err!("Expected identifier in UNPIVOT value clause"),
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                // Parse name column
+                let name_column = name.value;
+
+                // Parse columns to unpivot
+                let unpivot_columns: Vec<(Vec<String>, Option<String>)> = columns
+                    .into_iter()
+                    .map(|col_with_alias| {
+                        let column_names = match &col_with_alias.expr {
+                            SqlExpr::Tuple(exprs) => exprs
+                                .iter()
+                                .map(|e| match e {
+                                    SqlExpr::Identifier(id) => Ok(id.value.clone()),
+                                    SqlExpr::CompoundIdentifier(ids) => Ok(ids
+                                        .iter()
+                                        .map(|i| i.value.as_str())
+                                        .collect::<Vec<_>>()
+                                        .join(".")),
+                                    _ => plan_err!(
+                                        "Expected identifier in UNPIVOT IN clause"
+                                    ),
+                                })
+                                .collect::<Result<Vec<_>>>()?,
+                            SqlExpr::Identifier(id) => vec![id.value.clone()],
+                            SqlExpr::CompoundIdentifier(ids) => {
+                                vec![ids
+                                    .iter()
+                                    .map(|i| i.value.as_str())
+                                    .collect::<Vec<_>>()
+                                    .join(".")]
+                            }
+                            _ => {
+                                return plan_err!(
+                                    "Expected identifier or tuple in UNPIVOT IN clause"
+                                )
+                            }
+                        };
+
+                        let alias = col_with_alias.alias.map(|alias| alias.value);
+                        Ok((column_names, alias))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                // Determine if nulls should be included (default is EXCLUDE)
+                let include_nulls = matches!(
+                    null_inclusion,
+                    Some(sqlparser::ast::NullInclusion::IncludeNulls)
+                );
+
+                // Get named_struct and make_array functions from context
+                let named_struct_fn = self
+                    .context_provider
+                    .get_function_meta("named_struct")
+                    .ok_or_else(|| {
+                        datafusion_common::DataFusionError::Plan(
+                            "named_struct function not found".to_string(),
+                        )
+                    })?;
+
+                let make_array_fn = self
+                    .context_provider
+                    .get_function_meta("make_array")
+                    .ok_or_else(|| {
+                        datafusion_common::DataFusionError::Plan(
+                            "make_array function not found".to_string(),
+                        )
+                    })?;
+
+                let get_field_fn = self
+                    .context_provider
+                    .get_function_meta("get_field")
+                    .ok_or_else(|| {
+                        datafusion_common::DataFusionError::Plan(
+                            "get_field function not found".to_string(),
+                        )
+                    })?;
+
+                // Call the unpivot function from LogicalPlanBuilder
+                let plan = LogicalPlanBuilder::from(plan)
+                    .unpivot(
+                        value_column_names,
+                        name_column,
+                        unpivot_columns,
+                        None, // id_columns - will be inferred from schema
+                        include_nulls,
+                        &named_struct_fn,
+                        &make_array_fn,
+                        &get_field_fn,
+                    )?
+                    .build()?;
+
+                (plan, alias)
+            }
             // @todo Support TableFactory::TableFunction?
             _ => {
                 return not_impl_err!(
diff --git a/datafusion/sqllogictest/test_files/pivot.slt b/datafusion/sqllogictest/test_files/pivot.slt
new file mode 100644
index 000000000000..5e42ff2d04ce
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/pivot.slt
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# PIVOT tests
+
+# Setup test table
+statement ok
+CREATE TABLE person (id INT, name STRING, age INT, class INT, address STRING) AS VALUES
+  (100, 'John', 30, 1, 'Street 1'),
+  (200, 'Mary', NULL, 1, 'Street 2'),
+  (300, 'Mike', 80, 3, 'Street 3'),
+  (400, 'Dan', 50, 4, 'Street 4');
+
+# Test basic PIVOT with a single column in the FOR clause
+query ITIRIR rowsort
+SELECT * FROM person
+  PIVOT (
+    SUM(age) AS a, AVG(class) AS c
+    FOR name IN ('John' AS john, 'Mike' AS mike)
+  );
+----
+100 Street 1 30 1 NULL NULL
+200 Street 2 NULL NULL NULL NULL
+300 Street 3 NULL NULL 80 3
+400 Street 4 NULL NULL NULL NULL
+
+# Test PIVOT with multiple columns in the FOR clause
+query ITIRIR rowsort
+SELECT * FROM person
+  PIVOT (
+    SUM(age) AS a, AVG(class) AS c
+    FOR (name, age) IN (('John', 30) AS c1, ('Mike', 40) AS c2)
+  );
+----
+100 Street 1 30 1 NULL NULL
+200 Street 2 NULL NULL NULL NULL
+300 Street 3 NULL NULL NULL NULL
+400 Street 4 NULL NULL NULL NULL
+
+# Cleanup
+statement ok
+DROP TABLE person;
diff --git a/datafusion/sqllogictest/test_files/unpivot.slt b/datafusion/sqllogictest/test_files/unpivot.slt
new file mode 100644
index 000000000000..2d090517643e
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/unpivot.slt
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# UNPIVOT tests
+
+# Setup test table
+statement ok
+CREATE TABLE sales_quarterly (year INT, q1 INT, q2 INT, q3 INT, q4 INT) AS VALUES
+  (2020, null, 1000, 2000, 2500),
+  (2021, 2250, 3200, 4200, 5900),
+  (2022, 4200, 3100, null, null);
+
+# Test basic UNPIVOT (NULL values are excluded by default)
+query ITI rowsort
+SELECT * FROM sales_quarterly
+  UNPIVOT (
+    sales FOR quarter IN (q1, q2, q3, q4)
+  );
+----
+2020 q2 1000
+2020 q3 2000
+2020 q4 2500
+2021 q1 2250
+2021 q2 3200
+2021 q3 4200
+2021 q4 5900
+2022 q1 4200
+2022 q2 3100
+
+# Test UNPIVOT with INCLUDE NULLS
+query ITI rowsort
+SELECT up.* FROM sales_quarterly
+  UNPIVOT INCLUDE NULLS (
+    sales FOR quarter IN (q1 AS Q1, q2 AS Q2, q3 AS Q3, q4 AS Q4)
+  ) AS up;
+----
+2020 Q1 NULL
+2020 Q2 1000
+2020 Q3 2000
+2020 Q4 2500
+2021 Q1 2250
+2021 Q2 3200
+2021 Q3 4200
+2021 Q4 5900
+2022 Q1 4200
+2022 Q2 3100
+2022 Q3 NULL
+2022 Q4 NULL
+
+# Test UNPIVOT with multiple value columns
+query ITII rowsort
+SELECT * FROM sales_quarterly
+  UNPIVOT EXCLUDE NULLS (
+    (first_quarter, second_quarter)
+    FOR half_of_the_year IN (
+      (q1, q2) AS H1,
+      (q3, q4) AS H2
+    )
+  );
+----
+2020 H1 NULL 1000
+2020 H2 2000 2500
+2021 H1 2250 3200
+2021 H2 4200 5900
+2022 H1 4200 3100
+
+# Cleanup
+statement ok
+DROP TABLE sales_quarterly;
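Since the planner changes above route both table factors through `LogicalPlanBuilder`, the same behavior is reachable from the SQL API: PIVOT becomes one filtered aggregate per (pivot value, aggregate) pair, and UNPIVOT becomes a project/unnest/project pipeline with an implicit `IS NOT NULL` filter unless `INCLUDE NULLS` is given. A minimal sketch (not part of the patch, and assuming this branch) mirroring the slt fixtures above:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Same fixture shape as unpivot.slt: one column per quarter.
    ctx.sql(
        "CREATE TABLE sales_quarterly (year INT, q1 INT, q2 INT, q3 INT, q4 INT) \
         AS VALUES (2020, NULL, 1000, 2000, 2500), (2021, 2250, 3200, 4200, 5900)",
    )
    .await?;

    // NULL values are excluded by default, matching the slt expectations.
    let df = ctx
        .sql(
            "SELECT * FROM sales_quarterly \
             UNPIVOT (sales FOR quarter IN (q1, q2, q3, q4))",
        )
        .await?;
    df.show().await?;
    Ok(())
}
```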