Skip to content

Commit d47a7fa

Browse files
authored
feat(expr-ir): Add BaseFrame.unpivot (#3368)
1 parent 158eda2 commit d47a7fa

File tree

4 files changed

+144
-0
lines changed

4 files changed

+144
-0
lines changed

narwhals/_plan/arrow/dataframe.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,27 @@ def _unique(
153153
unique = _unique
154154
unique_by = _unique
155155

156+
def unpivot(
157+
self,
158+
on: Sequence[str] | None,
159+
index: Sequence[str] | None,
160+
*,
161+
variable_name: str = "variable",
162+
value_name: str = "value",
163+
) -> Self:
164+
n = len(self)
165+
index = [] if index is None else list(index)
166+
on_ = (c for c in self.columns if c not in index) if on is None else iter(on)
167+
index_cols = self.native.select(index)
168+
column = self.native.column
169+
tables = (
170+
index_cols.append_column(variable_name, fn.repeat(name, n)).append_column(
171+
value_name, column(name)
172+
)
173+
for name in on_
174+
)
175+
return self._with_native(fn.concat_tables(tables, "permissive"))
176+
156177
def with_row_index(self, name: str) -> Self:
157178
return self._with_native(self.native.add_column(0, name, fn.int_range(len(self))))
158179

narwhals/_plan/compliant/dataframe.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def unique_by(
7979
order_by: Sequence[str],
8080
keep: UniqueKeepStrategy = "any",
8181
) -> Self: ...
82+
def unpivot(
83+
self,
84+
on: Sequence[str] | None,
85+
index: Sequence[str] | None,
86+
*,
87+
variable_name: str = "variable",
88+
value_name: str = "value",
89+
) -> Self: ...
8290
def with_columns(self, irs: Seq[NamedIR]) -> Self: ...
8391
def with_row_index_by(
8492
self, name: str, order_by: Sequence[str], *, nulls_last: bool = False

narwhals/_plan/dataframe.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,29 @@ def rename(self, mapping: Mapping[str, str]) -> Self:
159159
def collect_schema(self) -> Schema:
160160
return self.schema
161161

162+
def unpivot(
163+
self,
164+
on: OneOrIterable[ColumnNameOrSelector] | None = None,
165+
*,
166+
index: OneOrIterable[ColumnNameOrSelector] | None = None,
167+
variable_name: str = "variable",
168+
value_name: str = "value",
169+
) -> Self:
170+
on_: Seq[str] | None = None
171+
index_: Seq[str] | None = None
172+
schema = self.schema
173+
if on is not None:
174+
s_irs = _parse.parse_into_seq_of_selector_ir(on)
175+
on_ = expand_selector_irs_names(s_irs, schema=schema, require_any=True)
176+
if index is not None:
177+
s_irs = _parse.parse_into_seq_of_selector_ir(index)
178+
index_ = expand_selector_irs_names(s_irs, schema=schema, require_any=True)
179+
return self._with_compliant(
180+
self._compliant.unpivot(
181+
on_, index_, variable_name=variable_name, value_name=value_name
182+
)
183+
)
184+
162185
def with_row_index(
163186
self, name: str = "index", *, order_by: OneOrIterable[ColumnNameOrSelector]
164187
) -> Self:

tests/plan/unpivot_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Final
4+
5+
import pytest
6+
7+
import narwhals as nw
8+
from narwhals import _plan as nwp
9+
from narwhals._plan import selectors as ncs
10+
from tests.plan.utils import assert_equal_data, dataframe
11+
from tests.utils import PYARROW_VERSION
12+
13+
if TYPE_CHECKING:
14+
from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable
15+
from tests.conftest import Data
16+
17+
18+
@pytest.fixture(scope="module")
19+
def data() -> Data:
20+
return {"a": [7, 8, 9], "b": [1, 3, 5], "c": [2, 4, 6]}
21+
22+
23+
A: Final = [7, 8, 9]
24+
B: Final = [1, 3, 5]
25+
C: Final = [2, 4, 6]
26+
27+
VAR = "variable"
28+
VALUE = "value"
29+
30+
a = ncs.first()
31+
b = ncs.by_name("b")
32+
c = ncs.last()
33+
34+
35+
@pytest.mark.parametrize(
36+
("on", "index", "expected"),
37+
[
38+
("b", [a], {"a": A, VAR: ["b", "b", "b"], VALUE: B}),
39+
(
40+
["b", c],
41+
a,
42+
{"a": [*A, *A], VAR: ["b", "b", "b", "c", "c", "c"], VALUE: [*B, *C]},
43+
),
44+
(
45+
None,
46+
["a"],
47+
{"a": [*A, *A], VAR: ["b", "b", "b", "c", "c", "c"], VALUE: [*B, *C]},
48+
),
49+
([b | c], None, {VAR: ["b", "b", "b", "c", "c", "c"], VALUE: [*B, *C]}),
50+
(
51+
None,
52+
None,
53+
{VAR: ["a", "a", "a", "b", "b", "b", "c", "c", "c"], VALUE: [*A, *B, *C]},
54+
),
55+
],
56+
)
57+
def test_unpivot(
58+
data: Data,
59+
on: OneOrIterable[ColumnNameOrSelector] | None,
60+
index: OneOrIterable[ColumnNameOrSelector] | None,
61+
expected: Data,
62+
) -> None:
63+
sort_columns = [VAR] if index is None else [VAR, "a"]
64+
result = dataframe(data).unpivot(on, index=index).sort(sort_columns)
65+
assert_equal_data(result, expected)
66+
67+
68+
@pytest.mark.parametrize(
69+
("variable_name", "value_name"),
70+
[
71+
("", "custom_value_name"),
72+
("custom_variable_name", ""),
73+
("custom_variable_name", "custom_value_name"),
74+
],
75+
)
76+
def test_unpivot_var_value_names(data: Data, variable_name: str, value_name: str) -> None:
77+
result = dataframe(data).unpivot(
78+
~ncs.first(), index=["a"], variable_name=variable_name, value_name=value_name
79+
)
80+
assert result.collect_schema().names()[-2:] == [variable_name, value_name]
81+
82+
83+
def test_unpivot_default_var_value_names(data: Data) -> None:
84+
result = dataframe(data).unpivot(nwp.nth(1, 2).meta.as_selector(), index=ncs.first())
85+
assert result.collect_schema().names()[-2:] == [VAR, VALUE]
86+
87+
88+
@pytest.mark.xfail(PYARROW_VERSION < (14, 0, 0), reason="pyarrow<14")
89+
def test_unpivot_mixed_types() -> None:
90+
df = dataframe({"idx": [0, 1], "a": [1, 2], "b": [1.5, 2.5]})
91+
result = df.unpivot(["a", "b"], index="idx")
92+
assert result.collect_schema().dtypes() == [nw.Int64(), nw.String(), nw.Float64()]

0 commit comments

Comments
 (0)