Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2368,6 +2368,139 @@ impl DataFrame {
let df = ctx.read_batch(batch)?;
Ok(df)
}

/// Pivot the DataFrame, transforming rows into columns based on the specified value columns and aggregation functions.
///
/// # Arguments
/// * `aggregate_functions` - Aggregation expressions to apply (e.g., sum, count).
/// * `value_column` - Columns whose unique values will become new columns in the output.
/// * `value_source` - Columns to use as values for the pivoted columns.
/// * `default_on_null` - Optional expressions to use as default values when a pivoted value is null.
///
/// # Example
/// ```
/// # use datafusion::prelude::*;
/// # use arrow::array::{ArrayRef, Int32Array, StringArray};
/// # use datafusion::functions_aggregate::expr_fn::sum;
/// # use std::sync::Arc;
/// # let ctx = SessionContext::new();
/// let value: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
/// let category: ArrayRef = Arc::new(StringArray::from(vec!["A", "B", "A"]));
/// let df = DataFrame::from_columns(vec![("value", value), ("category", category)]).unwrap();
/// let pivoted = df.pivot(
/// vec![sum(col("value"))],
/// vec![Column::from("category")],
/// vec![col("value")],
/// None
/// ).unwrap();
/// ```
pub fn pivot(
self,
aggregate_functions: Vec<Expr>,
value_column: Vec<Column>,
value_source: Vec<Expr>,
default_on_null: Option<Vec<Expr>>,
) -> Result<Self> {
let plan = LogicalPlanBuilder::from(self.plan)
.pivot(
aggregate_functions,
value_column,
value_source,
default_on_null,
)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: self.projection_requires_validation,
})
}

/// Unpivot the DataFrame, transforming columns into rows.
///
/// # Arguments
/// * `value_column_names` - Names for the value columns in the output
/// * `name_column` - Name for the column that will contain the original column names
/// * `unpivot_columns` - List of (column_names, optional_alias) tuples to unpivot
/// * `id_columns` - Optional list of columns to preserve (if None, all non-unpivoted columns are preserved)
/// * `include_nulls` - Whether to include rows with NULL values (default: false excludes NULLs)
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use arrow::array::{ArrayRef, Int32Array};
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let id: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let jan: ArrayRef = Arc::new(Int32Array::from(vec![100, 110]));
/// let feb: ArrayRef = Arc::new(Int32Array::from(vec![200, 210]));
/// let mar: ArrayRef = Arc::new(Int32Array::from(vec![300, 310]));
/// let df = DataFrame::from_columns(vec![("id", id), ("jan", jan), ("feb", feb), ("mar", mar)]).unwrap();
/// let unpivoted = df.unpivot(
/// vec!["jan".to_string(), "feb".to_string(), "mar".to_string()],
/// "month".to_string(),
/// vec![(vec!["jan".to_string(), "feb".to_string(), "mar".to_string()], None)],
/// None,
/// false
/// ).unwrap();
/// # Ok(())
/// # }
/// ```
pub fn unpivot(
self,
value_column_names: Vec<String>,
name_column: String,
unpivot_columns: Vec<(Vec<String>, Option<String>)>,
id_columns: Option<Vec<String>>,
include_nulls: bool,
) -> Result<Self> {
// Get required UDF functions from the session state
let named_struct_fn = self
.session_state
.scalar_functions()
.get("named_struct")
.ok_or_else(|| {
DataFusionError::Plan("named_struct function not found".to_string())
})?;

let make_array_fn = self
.session_state
.scalar_functions()
.get("make_array")
.ok_or_else(|| {
DataFusionError::Plan("make_array function not found".to_string())
})?;

let get_field_fn = self
.session_state
.scalar_functions()
.get("get_field")
.ok_or_else(|| {
DataFusionError::Plan("get_field function not found".to_string())
})?;

let plan = LogicalPlanBuilder::from(self.plan)
.unpivot(
value_column_names,
name_column,
unpivot_columns,
id_columns,
include_nulls,
named_struct_fn,
make_array_fn,
get_field_fn,
)?
.build()?;

Ok(DataFrame {
session_state: self.session_state,
plan,
projection_requires_validation: true,
})
}
}

/// Macro for creating DataFrame.
Expand Down
Loading