
Commit e0d1a00

feat(expr-ir): Implement ArrowDataFrame.partition_by
Supports selector input for partitions
1 parent a00dbb7 · commit e0d1a00

2 files changed: +78 −4 lines


narwhals/_plan/arrow/dataframe.py

Lines changed: 23 additions & 4 deletions
@@ -9,10 +9,11 @@
 import pyarrow.compute as pc  # ignore-banned-import

 from narwhals._arrow.utils import native_to_narwhals_dtype
-from narwhals._plan.arrow import acero, functions as fn
+from narwhals._plan.arrow import acero, functions as fn, group_by
 from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
 from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy
 from narwhals._plan.arrow.series import ArrowSeries as Series
+from narwhals._plan.common import temp
 from narwhals._plan.compliant.dataframe import EagerDataFrame
 from narwhals._plan.compliant.typing import namespace
 from narwhals._plan.expressions import NamedIR
@@ -172,7 +173,25 @@ def filter(self, predicate: NamedIR) -> Self:
         mask = acero.lit(resolved.native)
         return self._with_native(self.native.filter(mask))

+    # TODO @dangotbanned: Clean this up after getting more tests in place
     def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[Self]:
-        """Review https://github.com/pola-rs/polars/blob/870f0e01811b8b0cf9b846ded9d97685f143d27c/crates/polars-core/src/frame/mod.rs#L3225-L3284."""
-        msg = "TODO: `ArrowDataFrame.partition_by`"
-        raise NotImplementedError(msg)
+        original_names = self.columns
+        temp_name = temp.column_name(original_names)
+        native = self.native
+        composite_values = group_by.concat_str(acero.select_names_table(native, by))
+        re_keyed = native.add_column(0, temp_name, composite_values)
+        source = acero.table_source(re_keyed)
+        if include_key:
+            keep = original_names
+        else:
+            ignore = {*by, temp_name}
+            keep = [name for name in original_names if name not in ignore]
+        select = acero.select_names(keep)
+        key = acero.col(temp_name)
+        # Need to iterate over the whole thing, so py_list first should be faster
+        partitions = (
+            acero.declare(source, acero.filter(key == v), select)
+            for v in composite_values.unique().to_pylist()
+        )
+        from_native = self._with_native
+        return [from_native(decl.to_table()) for decl in partitions]
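
The new method builds a single composite string key from the `by` columns, prepends it as a temporary column, and then materialises one Acero filter declaration per distinct key value. Below is a minimal standalone sketch of the same composite-key idea written against plain PyArrow rather than narwhals' internal acero/group_by/temp helpers; the helper name partition_by_composite_key and the "\x1f" separator are illustrative assumptions, not part of this commit.

import pyarrow as pa
import pyarrow.compute as pc


def partition_by_composite_key(
    table: pa.Table, by: list[str], *, include_key: bool = True
) -> list[pa.Table]:
    # Cast each key column to string and join them row-wise into one composite key.
    parts = [pc.cast(table[name], pa.string()) for name in by]
    composite = pc.binary_join_element_wise(*parts, "\x1f")
    keep = (
        table.column_names
        if include_key
        else [name for name in table.column_names if name not in set(by)]
    )
    # One filter pass per distinct composite value, mirroring the per-key
    # filter declarations in the implementation above.
    return [
        table.filter(pc.equal(composite, value)).select(keep)
        for value in pc.unique(composite).to_pylist()
    ]


tbl = pa.table({"a": ["a", "b", "a", "b", "c"], "b": [1, 2, 1, 3, 3], "c": [5, 4, 3, 2, 1]})
for part in partition_by_composite_key(tbl, ["a"]):
    print(part.to_pydict())
# {'a': ['a', 'a'], 'b': [1, 1], 'c': [5, 3]}
# {'a': ['b', 'b'], 'b': [2, 3], 'c': [4, 2]}
# {'a': ['c'], 'b': [3], 'c': [1]}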
New test file (path not shown) · Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import pytest
+
+import narwhals as nw
+from narwhals._plan import selectors as ncs
+from narwhals._utils import zip_strict
+from tests.plan.utils import assert_equal_data, dataframe
+
+if TYPE_CHECKING:
+    from narwhals._plan.typing import ColumnNameOrSelector
+    from tests.conftest import Data
+
+
+@pytest.fixture
+def data() -> Data:
+    return {"a": ["a", "b", "a", "b", "c"], "b": [1, 2, 1, 3, 3], "c": [5, 4, 3, 2, 1]}
+
+
+@pytest.mark.parametrize(
+    ("include_key", "expected"),
+    [
+        (
+            True,
+            [
+                {"a": ["a", "a"], "b": [1, 1], "c": [5, 3]},
+                {"a": ["b", "b"], "b": [2, 3], "c": [4, 2]},
+                {"a": ["c"], "b": [3], "c": [1]},
+            ],
+        ),
+        (
+            False,
+            [
+                {"b": [1, 1], "c": [5, 3]},
+                {"b": [2, 3], "c": [4, 2]},
+                {"b": [3], "c": [1]},
+            ],
+        ),
+    ],
+    ids=["include_key", "exclude_key"],
+)
+@pytest.mark.parametrize(
+    "by",
+    ["a", ncs.string(), ncs.matches("a"), ncs.by_name("a"), ncs.by_dtype(nw.String)],
+    ids=["str", "ncs.string", "ncs.matches", "ncs.by_name", "ncs.by_dtype"],
+)
+def test_partition_by_single(
+    data: Data, by: ColumnNameOrSelector, *, include_key: bool, expected: Any
+) -> None:
+    df = dataframe(data)
+    results = df.partition_by(by, include_key=include_key)
+    for df, expect in zip_strict(results, expected):
+        assert_equal_data(df, expect)
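
The parametrized cases above only exercise a single key column (hence test_partition_by_single); the composite key built in partition_by also spans several `by` columns. A quick illustration of the multi-key case, reusing the standalone partition_by_composite_key sketch shown after the dataframe.py diff (plain PyArrow, not narwhals' API) on the same fixture data:

import pyarrow as pa

tbl = pa.table({"a": ["a", "b", "a", "b", "c"], "b": [1, 2, 1, 3, 3], "c": [5, 4, 3, 2, 1]})
# Partition on both "a" and "b": rows sharing the same ("a", "b") pair stay together.
for part in partition_by_composite_key(tbl, ["a", "b"], include_key=False):
    print(part.to_pydict())
# {'c': [5, 3]}
# {'c': [4]}
# {'c': [2]}
# {'c': [1]}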
