Skip to content

Commit a00dbb7

Browse files
committed
more partition_by prep
1 parent 15c87ea commit a00dbb7

File tree

6 files changed

+116
-6
lines changed

6 files changed

+116
-6
lines changed

narwhals/_plan/_expansion.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
IntoFrozenSchema,
7373
freeze_schema,
7474
)
75+
from narwhals._utils import check_column_names_are_unique
7576
from narwhals.dtypes import DType
7677
from narwhals.exceptions import ComputeError, InvalidOperationError
7778

@@ -156,7 +157,7 @@ def with_multiple_columns(self) -> ExpansionFlags:
156157
def prepare_projection(
157158
exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema
158159
) -> tuple[Seq[NamedIR], FrozenSchema]:
159-
"""Expand IRs into named column selections.
160+
"""Expand IRs into named column projections.
160161
161162
**Primary entry-point**, for `select`, `with_columns`,
162163
and any other context that requires resolving expression names.
@@ -173,13 +174,33 @@ def prepare_projection(
173174
return named_irs, frozen_schema
174175

175176

177+
def expand_selector_irs_names(
    selectors: Sequence[SelectorIR],
    /,
    keys: GroupByKeys = (),
    *,
    schema: IntoFrozenSchema,
) -> OutputNames:
    """Resolve selector-only input to the names of the columns it matches.

    A narrower sibling of `prepare_projection`, meant for contexts such as
    `DataFrame.{drop,sort,partition_by}` that accept every `Selector` but only
    a subset of `Expr`.

    Arguments:
        selectors: IRs that **only** contain subclasses of `SelectorIR`.
        keys: Names of `group_by` columns.
        schema: Scope to expand multi-column selectors in.

    Returns:
        The expanded names, after validation by `_ensure_valid_output_names`
        against the frozen schema.
    """
    frozen = freeze_schema(schema)
    expanded = _iter_expand_selector_names(selectors, keys, schema=frozen)
    return _ensure_valid_output_names(tuple(expanded), frozen)
197+
198+
176199
def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]:
177200
if len(exprs) != len(names):
178201
msg = f"zip length mismatch: {len(exprs)} != {len(names)}"
179202
raise ValueError(msg)
180-
return tuple(
181-
NamedIR(expr=remove_alias(ir), name=name) for ir, name in zip(exprs, names)
182-
)
203+
return tuple(ir.named_ir(name, remove_alias(e)) for e, name in zip(exprs, names))
183204

184205

185206
def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames:
@@ -191,13 +212,40 @@ def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames:
191212
return output_names
192213

193214

215+
def _ensure_valid_output_names(names: Seq[str], schema: FrozenSchema) -> OutputNames:
    """Validate already-expanded names (selector-only `ensure_valid_exprs`).

    `check_column_names_are_unique` runs first, then every name must be present
    in `schema.names`.

    Arguments:
        names: Expanded column names.
        schema: Schema that all of `names` must belong to.

    Raises:
        The error built by `column_not_found_error` when at least one name is
        missing from `schema`.
    """
    check_column_names_are_unique(names)
    if set(names).issubset(schema.names):
        return names
    raise column_not_found_error(names, schema)
222+
223+
194224
def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames:
    """Collect the output name of every expression, rejecting repeated names.

    Raises:
        The error built by `duplicate_error` when two or more expressions
        share an output name.
    """
    names = tuple(expr.meta.output_name() for expr in exprs)
    if len(set(names)) == len(names):
        return names
    raise duplicate_error(exprs)
199229

200230

231+
def _ensure_columns(expr: ExprIR, /) -> Columns:
    """Return `expr` unchanged when it is a `Columns` node.

    Raises:
        NotImplementedError: If `expr` is any other kind of expression.
    """
    if isinstance(expr, Columns):
        return expr
    msg = f"Expected only column selections here, but got {expr!r}"
    raise NotImplementedError(msg)
236+
237+
238+
def _iter_expand_selector_names(
    selectors: Iterable[SelectorIR], /, keys: GroupByKeys = (), *, schema: FrozenSchema
) -> Iterator[str]:
    """Yield the column names matched by each selector, in selector order.

    Each selector is replaced via `replace_selector` and must come back as a
    `Columns` node (enforced by `_ensure_columns`). Names found in `keys` are
    skipped; with no `keys`, every matched name is yielded.
    """
    for selector in selectors:
        columns = _ensure_columns(replace_selector(selector, schema=schema))
        if not keys:
            yield from columns.names
            continue
        for name in columns.names:
            if name not in keys:
                yield name
247+
248+
201249
# NOTE: Recursive for all `input` expressions which themselves contain `Seq[ExprIR]`
202250
def rewrite_projections(
203251
input: Seq[ExprIR], /, keys: GroupByKeys = (), *, schema: FrozenSchema

narwhals/_plan/_guards.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
from narwhals._plan.compliant.series import CompliantSeries
1616
from narwhals._plan.expr import Expr
1717
from narwhals._plan.series import Series
18-
from narwhals._plan.typing import IntoExprColumn, NativeSeriesT, Seq
18+
from narwhals._plan.typing import (
19+
ColumnNameOrSelector,
20+
IntoExprColumn,
21+
NativeSeriesT,
22+
Seq,
23+
)
1924
from narwhals.typing import NonNestedLiteral
2025

2126
T = TypeVar("T")
@@ -75,6 +80,13 @@ def is_into_expr_column(obj: Any) -> TypeIs[IntoExprColumn]:
7580
return isinstance(obj, (str, _expr().Expr, _series().Series))
7681

7782

83+
def is_column_name_or_selector(
    obj: Any, *, allow_expr: bool = False
) -> TypeIs[ColumnNameOrSelector]:
    """Check whether `obj` is a `str` or a `Selector`.

    Arguments:
        obj: Object to test.
        allow_expr: When true, the check is widened from `Selector` to `Expr`.
    """
    expr_module = _expr()
    accepted = expr_module.Expr if allow_expr else expr_module.Selector
    return isinstance(obj, (str, accepted))
88+
89+
7890
def is_compliant_series(
7991
obj: CompliantSeries[NativeSeriesT] | Any,
8092
) -> TypeIs[CompliantSeries[NativeSeriesT]]:

narwhals/_plan/_parse.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import TYPE_CHECKING
88

99
from narwhals._plan._guards import (
10+
is_column_name_or_selector,
1011
is_expr,
1112
is_into_expr_column,
1213
is_iterable_reject,
@@ -132,6 +133,8 @@ def parse_into_expr_ir(
132133
return expr._ir
133134

134135

136+
# NOTE: Might need to add `require_all`, since selectors are created indirectly from `str`
137+
# here, but use set semantics
135138
def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR:
136139
if is_selector(input):
137140
selector = input
@@ -198,6 +201,34 @@ def _parse_sort_by_into_iter_expr_ir(
198201
yield e
199202

200203

204+
def parse_into_seq_of_selector_ir(
    first_input: OneOrIterable[ColumnNameOrSelector], *more_inputs: ColumnNameOrSelector
) -> Seq[SelectorIR]:
    """Parse column names and/or selectors into a tuple of `SelectorIR`.

    Eagerly collects everything produced by `_parse_into_iter_selector_ir`.
    """
    parsed = _parse_into_iter_selector_ir(first_input, more_inputs)
    return tuple(parsed)
208+
209+
210+
def _parse_into_iter_selector_ir(
    first_input: OneOrIterable[ColumnNameOrSelector],
    more_inputs: tuple[ColumnNameOrSelector, ...],
    /,
) -> Iterator[SelectorIR]:
    """Yield one `SelectorIR` per input, flattening a single iterable argument.

    Handled call shapes:
    - a lone column name or selector, with nothing in `more_inputs`
    - a single non-`str` iterable of inputs, with nothing in `more_inputs`
    - individual inputs spread across `first_input` and `more_inputs`

    Raises:
        The error built by `invalid_into_expr_error` when `first_input` is a
        non-`str` iterable (not treated as empty by `_is_empty_sequence`) and
        `more_inputs` is also provided.
    """
    # Fast path: a lone name/selector needs no further inspection.
    if is_column_name_or_selector(first_input) and not more_inputs:
        yield parse_into_selector_ir(first_input)
        return

    if not _is_empty_sequence(first_input):
        if not _is_iterable(first_input) or isinstance(first_input, str):
            yield parse_into_selector_ir(first_input)
        elif more_inputs:
            raise invalid_into_expr_error(first_input, more_inputs, {})
        else:
            for item in first_input:  # type: ignore[var-annotated]
                yield parse_into_selector_ir(item)
    for extra in more_inputs:
        yield parse_into_selector_ir(extra)
230+
231+
201232
def _parse_into_iter_expr_ir(
202233
first_input: OneOrIterable[IntoExpr],
203234
*more_inputs: IntoExpr | list[Any],

narwhals/_plan/arrow/dataframe.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,8 @@ def filter(self, predicate: NamedIR) -> Self:
171171
else:
172172
mask = acero.lit(resolved.native)
173173
return self._with_native(self.native.filter(mask))
174+
175+
def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[Self]:
    """Placeholder: not implemented for the Arrow backend yet.

    Implementation to review before writing this:
    https://github.com/pola-rs/polars/blob/870f0e01811b8b0cf9b846ded9d97685f143d27c/crates/polars-core/src/frame/mod.rs#L3225-L3284

    Raises:
        NotImplementedError: Unconditionally.
    """
    msg = "TODO: `ArrowDataFrame.partition_by`"
    raise NotImplementedError(msg)

narwhals/_plan/compliant/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def join(
129129
suffix: str = "_right",
130130
) -> Self: ...
131131
def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
132+
def partition_by(
    self, by: Sequence[str], *, include_key: bool = True
) -> list[Self]:
    """Backend hook for partitioning; takes resolved column names.

    NOTE(review): presumably returns one frame per distinct combination of
    the `by` columns, with `include_key` controlling whether those columns
    are kept - confirm against the backend implementations.
    """
    ...
132135
def row(self, index: int) -> tuple[Any, ...]: ...
133136
@overload
134137
def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ...

narwhals/_plan/dataframe.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, get_args, overload
44

55
from narwhals._plan import _parse
6-
from narwhals._plan._expansion import prepare_projection
6+
from narwhals._plan._expansion import expand_selector_irs_names, prepare_projection
77
from narwhals._plan.common import ensure_seq_str, temp
88
from narwhals._plan.group_by import GroupBy, Grouped
99
from narwhals._plan.options import SortMultipleOptions
@@ -258,6 +258,17 @@ def filter(
258258
raise ValueError(msg)
259259
return self._with_compliant(self._compliant.filter(named_irs[0]))
260260

261+
def partition_by(
    self,
    by: OneOrIterable[ColumnNameOrSelector],
    *more_by: ColumnNameOrSelector,
    include_key: bool = True,
) -> list[Self]:
    """Partition this frame by column names and/or selectors.

    `by` and `more_by` are parsed into selector IRs, expanded to concrete
    column names against this frame's schema, and handed to the compliant
    frame's `partition_by`. Each resulting partition is re-wrapped with
    `_with_compliant`.

    Arguments:
        by: A column name or selector, or an iterable of them.
        *more_by: Additional column names or selectors.
        include_key: Forwarded unchanged to the compliant implementation.
    """
    selector_irs = _parse.parse_into_seq_of_selector_ir(by, *more_by)
    key_names = expand_selector_irs_names(selector_irs, schema=self)
    frames = self._compliant.partition_by(key_names, include_key=include_key)
    return list(map(self._with_compliant, frames))
271+
261272

262273
def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]:
    """Return whether `obj` is one of the recognised join strategy names.

    Non-string input is rejected before the membership test, so unhashable
    values (e.g. a `list` or `dict`) yield `False` instead of raising
    `TypeError` from the set lookup.
    """
    strategies = {"inner", "left", "full", "cross", "anti", "semi"}
    return isinstance(obj, str) and obj in strategies

0 commit comments

Comments
 (0)