Skip to content
Merged
14 changes: 14 additions & 0 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ pub enum UniqueKeepStrategy {
Any,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[strum(serialize_all = "snake_case")]
/// Naming strategy for the results of a pivot.
pub enum PivotColumnNaming {
/// Always combine the values and on-column names.
AlwaysCombine,
/// Prefix the values column name only if there is more than one values
/// column.
#[default]
Auto,
}

impl DataFrame {
pub fn materialized_column_iter(&self) -> impl ExactSizeIterator<Item = &Series> {
self.columns().iter().map(Column::as_materialized_series)
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub use parquet::*;
use polars_compute::rolling::QuantileMethod;
use polars_core::POOL;
use polars_core::error::feature_gated;
#[cfg(feature = "pivot")]
use polars_core::frame::PivotColumnNaming;
use polars_core::prelude::*;
use polars_io::RowIndex;
use polars_mem_engine::scan_predicate::functions::apply_scan_predicate_to_scan_ir;
Expand Down Expand Up @@ -1875,6 +1877,7 @@ impl LazyFrame {
agg: Expr,
maintain_order: bool,
separator: PlSmallStr,
column_naming: PivotColumnNaming,
) -> LazyFrame {
let opt_state = self.get_opt_state();
let lp = self
Expand All @@ -1887,6 +1890,7 @@ impl LazyFrame {
agg,
maintain_order,
separator,
column_naming,
)
.build();
Self::from_logical_plan(lp, opt_state)
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/dsl-schema-hashes.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"Dimension": "68880cdb10230df6c8c1632b073c80bd8ceb5c56a368c0cb438431ca9f3d3b31",
"DistinctOptionsDSL": "41be5ec69ef9a614f2b36ac5deadfecdea5cca847ae1ada9d4bc626ff52a5b38",
"DslFunction": "221f1a46a043c8ed54f57be981bf24509f04f5f91f0f08e0acc180d96f842ebf",
"DslPlan": "541a3874e573615ee2a40704c6c27fca5825a2411602c7ebc1db901c41f68e81",
"DslPlan": "14caf5b73e69c4975ff3a57331891521ff5b78c96bbaf8d6cc9be57c82f3ea98",
"Duration": "44999d59023085cbb592ce94b30d34f9b983081fc72bd6435a49bdf0869c0074",
"Duration2": "f251cb1bee2955a17c6defe1573bce21ddbe6cdf6eb9324a19cd37932ab29347",
"DynListLiteralValue": "2266a553cb4a943f7097f24539eaa802453cf8742675996215235bd682dec0e8",
Expand Down Expand Up @@ -112,6 +112,7 @@
"ParquetWriteOptions": "04196fdf5e136dc18278b5d0ef1054fa398a0a1a60147e40085ecd0180e89637",
"PartitionStrategy": "0e4535031aa9acf22fdf96ab10483f76e2f6ae6d5e5cd756be9adca490e0d05b",
"PartitionedSinkOptions": "bc7885b2bb87dc5fad4c5cf96c5a9b381403f8f9db6edefc81899dc9b9227934",
"PivotColumnNaming": "8bd7dfb879cf09ee95b306f20ceddbfe0d00320918b40378ebd5f3b6a0b4ea48",
"PlCredentialProvider": "5bbddd4f899afa592c318b20bb8d0bdfe2877fa5bf1a63d9cd0da908ac3aec0e",
"PlRefPath": "0faaddc3196c89bd9dcf872bbc4304471855dff7f9d24107ef279bc06ef7cbb4",
"PlanCallback": "5bbddd4f899afa592c318b20bb8d0bdfe2877fa5bf1a63d9cd0da908ac3aec0e",
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/builder_dsl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::sync::Arc;

#[cfg(feature = "pivot")]
use polars_core::frame::PivotColumnNaming;
use polars_core::prelude::*;
#[cfg(feature = "csv")]
use polars_io::csv::read::CsvReadOptions;
Expand Down Expand Up @@ -327,6 +329,7 @@ impl DslBuilder {
agg: Expr,
maintain_order: bool,
separator: PlSmallStr,
column_naming: PivotColumnNaming,
) -> Self {
DslPlan::Pivot {
input: Arc::new(self.0),
Expand All @@ -337,6 +340,7 @@ impl DslBuilder {
agg,
maintain_order,
separator,
column_naming,
}
.into()
}
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/dsl/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::fmt;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};

#[cfg(feature = "pivot")]
use polars_core::frame::PivotColumnNaming;
use polars_utils::arena::Node;
#[cfg(feature = "serde")]
use polars_utils::pl_serialize;
Expand Down Expand Up @@ -115,6 +117,7 @@ pub enum DslPlan {
agg: Expr,
maintain_order: bool,
separator: PlSmallStr,
column_naming: PivotColumnNaming,
},
/// Remove duplicates from the table
Distinct {
Expand Down Expand Up @@ -206,7 +209,7 @@ impl Clone for DslPlan {
Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
Self::SinkMultiple { inputs } => Self::SinkMultiple { inputs: inputs.clone() },
#[cfg(feature = "pivot")]
Self::Pivot { input, on, on_columns, index, values, agg, separator, maintain_order } => Self::Pivot { input: input.clone(), on: on.clone(), on_columns: on_columns.clone(), index: index.clone(), values: values.clone(), agg: agg.clone(), separator: separator.clone(), maintain_order: *maintain_order },
Self::Pivot { input, on, on_columns, index, values, agg, separator, maintain_order, column_naming } => Self::Pivot { input: input.clone(), on: on.clone(), on_columns: on_columns.clone(), index: index.clone(), values: values.clone(), agg: agg.clone(), separator: separator.clone(), maintain_order: *maintain_order, column_naming: *column_naming },
#[cfg(feature = "merge_sorted")]
Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted { input_left: input_left.clone(), input_right: input_right.clone(), key: key.clone() },
Self::IR {node, dsl, version} => Self::IR {node: *node, dsl: dsl.clone(), version: *version},
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-plan/src/dsl/serializable_plan.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "pivot")]
use polars_core::frame::PivotColumnNaming;
use polars_utils::unique_id::UniqueId;
use recursive::recursive;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -99,6 +101,7 @@ pub(crate) enum SerializableDslPlanNode {
agg: Expr,
maintain_order: bool,
separator: PlSmallStr,
column_naming: PivotColumnNaming,
},
Distinct {
input: DslPlanKey,
Expand Down Expand Up @@ -284,6 +287,7 @@ fn convert_dsl_plan_to_serializable_plan(
agg,
maintain_order,
separator,
column_naming,
} => SP::Pivot {
input: dsl_plan_key(input, arenas),
on: on.clone(),
Expand All @@ -293,6 +297,7 @@ fn convert_dsl_plan_to_serializable_plan(
agg: agg.clone(),
maintain_order: *maintain_order,
separator: separator.clone(),
column_naming: *column_naming,
},
DP::Distinct { input, options } => SP::Distinct {
input: dsl_plan_key(input, arenas),
Expand Down Expand Up @@ -529,6 +534,7 @@ fn try_convert_serializable_plan_to_dsl_plan(
agg,
maintain_order,
separator,
column_naming,
} => Ok(DP::Pivot {
input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,
on: on.clone(),
Expand All @@ -538,6 +544,7 @@ fn try_convert_serializable_plan_to_dsl_plan(
agg: agg.clone(),
maintain_order: *maintain_order,
separator: separator.clone(),
column_naming: *column_naming,
}),
SP::Distinct { input, options } => Ok(DP::Distinct {
input: get_dsl_plan(*input, ser_dsl_plan, arenas)?,
Expand Down
9 changes: 8 additions & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,10 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
agg,
maintain_order,
separator,
column_naming,
} => {
use polars_core::frame::PivotColumnNaming;

let input =
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(unique)))?;
let input_schema = ctxt.lp_arena.get(input).schema(ctxt.lp_arena);
Expand Down Expand Up @@ -882,7 +885,11 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult

for i in 0..on_columns.height() {
let mut name = String::new();
if values.len() > 1 {
let combine = match column_naming {
PivotColumnNaming::AlwaysCombine => true,
PivotColumnNaming::Auto => values.len() > 1,
};
if combine {
name.push_str(value.as_str());
name.push_str(separator.as_str());
}
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-python/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ use std::hash::{Hash, Hasher};
pub use categorical::PyCategories;
#[cfg(feature = "object")]
use polars::chunked_array::object::PolarsObjectSafe;
#[cfg(feature = "pivot")]
use polars::frame::PivotColumnNaming;
use polars::frame::row::Row;
#[cfg(feature = "avro")]
use polars::io::avro::AvroCompression;
Expand Down Expand Up @@ -1265,6 +1267,24 @@ impl<'a, 'py> FromPyObject<'a, 'py> for Wrap<SearchSortedSide> {
}
}

#[cfg(feature = "pivot")]
impl<'a, 'py> FromPyObject<'a, 'py> for Wrap<PivotColumnNaming> {
type Error = PyErr;

fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
let parsed = match &*ob.extract::<PyBackedStr>()? {
"auto" => PivotColumnNaming::Auto,
"always_combine" => PivotColumnNaming::AlwaysCombine,
v => {
return Err(PyValueError::new_err(format!(
"`column_naming` must be one of {{'auto', 'combine'}}, got {v}",
)));
},
};
Ok(Wrap(parsed))
}
}

impl<'a, 'py> FromPyObject<'a, 'py> for Wrap<ClosedInterval> {
type Error = PyErr;

Expand Down
6 changes: 5 additions & 1 deletion crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::num::NonZeroUsize;
use arrow::ffi::export_iterator;
use either::Either;
use parking_lot::Mutex;
#[cfg(feature = "pivot")]
use polars::frame::PivotColumnNaming;
use polars::io::RowIndex;
use polars::time::*;
use polars_core::prelude::*;
Expand Down Expand Up @@ -1342,7 +1344,7 @@ impl PyLazyFrame {
}

#[cfg(feature = "pivot")]
#[pyo3(signature = (on, on_columns, index, values, agg, maintain_order, separator))]
#[pyo3(signature = (on, on_columns, index, values, agg, maintain_order, separator, column_naming))]
fn pivot(
&self,
on: PySelector,
Expand All @@ -1352,6 +1354,7 @@ impl PyLazyFrame {
agg: PyExpr,
maintain_order: bool,
separator: String,
column_naming: Wrap<PivotColumnNaming>,
) -> Self {
let ldf = self.ldf.read().clone();
ldf.pivot(
Expand All @@ -1362,6 +1365,7 @@ impl PyLazyFrame {
agg.inner,
maintain_order,
separator.into(),
column_naming.0,
)
.into()
}
Expand Down
3 changes: 3 additions & 0 deletions docs/source/src/rust/user-guide/transformations/pivot.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// --8<-- [start:setup]
use polars::frame::PivotColumnNaming;
use polars::prelude::*;
// --8<-- [end:setup]

Expand Down Expand Up @@ -36,6 +37,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}),
false,
"_".into(),
PivotColumnNaming::Auto,
)
.collect()?;
println!("{}", &out);
Expand Down Expand Up @@ -63,6 +65,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
}),
false,
"_".into(),
PivotColumnNaming::Auto,
);
let out = q2.collect()?;
println!("{}", &out);
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/polars/_plr.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ class PyLazyFrame:
agg: PyExpr,
maintain_order: bool,
separator: str,
column_naming: Literal["auto", "always_combine"],
) -> PyLazyFrame: ...
def unpivot(
self,
Expand Down
9 changes: 9 additions & 0 deletions py-polars/src/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9438,6 +9438,7 @@ def pivot(
maintain_order: bool = True,
sort_columns: bool = False,
separator: str = "_",
column_naming: Literal["auto", "always_combine"] = "auto",
) -> DataFrame:
"""
Create a spreadsheet-style pivot table as a DataFrame.
Expand Down Expand Up @@ -9480,6 +9481,13 @@ def pivot(
separator
Used as separator/delimiter in generated column names in case of multiple
`values` columns.
column_naming : {'auto', 'always_combine'}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Says always_combine rather than combine.

How resulting column names will be constructed.

* 'auto': The default; combine with separator if there are multiple
`values` columns, otherwise just use the `on_columns` names.
* 'always_combine': Always combine the `values` columns' names with
the `on_columns` names.

Returns
-------
Expand Down Expand Up @@ -9638,6 +9646,7 @@ def pivot(
aggregate_function=aggregate_function,
maintain_order=maintain_order,
separator=separator,
column_naming=column_naming,
)
.collect(optimizations=QueryOptFlags._eager())
)
Expand Down
9 changes: 9 additions & 0 deletions py-polars/src/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -8089,6 +8089,7 @@ def pivot(
aggregate_function: PivotAgg | Expr | None = None,
maintain_order: bool = False,
separator: str = "_",
column_naming: Literal["auto", "always_combine"] = "auto",
) -> LazyFrame:
"""
Create a spreadsheet-style pivot table as a DataFrame.
Expand Down Expand Up @@ -8123,6 +8124,13 @@ def pivot(
separator
Used as separator/delimiter in generated column names in case of multiple
`values` columns.
column_naming : {'auto', 'always_combine'}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

How resulting column names will be constructed.

* 'auto': The default; combine with separator if there are multiple
`values` columns, otherwise just use the `on_columns` names.
* 'always_combine': Always combine the `values` columns' names with
the `on_columns` names.

Returns
-------
Expand Down Expand Up @@ -8315,6 +8323,7 @@ def pivot(
agg=agg._pyexpr,
maintain_order=maintain_order,
separator=separator,
column_naming=column_naming,
)
)

Expand Down
25 changes: 22 additions & 3 deletions py-polars/tests/unit/operations/test_pivot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import date, timedelta
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import pytest

Expand Down Expand Up @@ -35,8 +35,25 @@ def test_pivot() -> None:
)
assert_frame_equal(result, expected)

# Next, with column naming that combines value column with on columns:
result = df.pivot(
"bar", values="N", aggregate_function=None, column_naming="always_combine"
)

def test_pivot_no_values() -> None:
expected = pl.DataFrame(
[
("A", 1, 2, None, None, None),
("B", None, None, 2, 4, None),
("C", None, None, None, None, 2),
],
schema=["foo", "N_k", "N_l", "N_m", "N_n", "N_o"],
orient="row",
)
assert_frame_equal(result, expected)


@pytest.mark.parametrize("column_naming", ["auto", "always_combine"])
def test_pivot_no_values(column_naming: Literal["auto", "always_combine"]) -> None:
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
Expand All @@ -45,7 +62,9 @@ def test_pivot_no_values() -> None:
"N2": [1, 2, 2, 4, 2],
}
)
result = df.pivot(on="bar", index="foo", aggregate_function=None)
result = df.pivot(
on="bar", index="foo", aggregate_function=None, column_naming=column_naming
)
expected = pl.DataFrame(
{
"foo": ["A", "B", "C"],
Expand Down
Loading