diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py
index 20a886c3a3..40210374ac 100644
--- a/narwhals/_plan/_expansion.py
+++ b/narwhals/_plan/_expansion.py
@@ -40,10 +40,10 @@
 from collections import deque
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

-from narwhals._plan import common, meta
-from narwhals._plan._guards import is_horizontal_reduction
+from narwhals._plan import common, expressions as ir, meta
+from narwhals._plan._guards import is_horizontal_reduction, is_window_expr
 from narwhals._plan._immutable import Immutable
 from narwhals._plan.exceptions import (
     column_index_error,
@@ -72,6 +72,7 @@
     IntoFrozenSchema,
     freeze_schema,
 )
+from narwhals._utils import check_column_names_are_unique
 from narwhals.dtypes import DType
 from narwhals.exceptions import ComputeError, InvalidOperationError
@@ -156,7 +157,7 @@ def with_multiple_columns(self) -> ExpansionFlags:
 def prepare_projection(
     exprs: Sequence[ExprIR], /, keys: GroupByKeys = (), *, schema: IntoFrozenSchema
 ) -> tuple[Seq[NamedIR], FrozenSchema]:
-    """Expand IRs into named column selections.
+    """Expand IRs into named column projections.

     **Primary entry-point**, for `select`, `with_columns`, and any other
     context that requires resolving expression names.
@@ -173,13 +174,33 @@
     return named_irs, frozen_schema


+def expand_selector_irs_names(
+    selectors: Sequence[SelectorIR],
+    /,
+    keys: GroupByKeys = (),
+    *,
+    schema: IntoFrozenSchema,
+) -> OutputNames:
+    """Expand selector-only input into the column names that match.
+
+    Similar to `prepare_projection`, but intended to allow a subset of `Expr` and all `Selector`s
+    to be used in more places like `DataFrame.{drop,sort,partition_by}`.
+
+    Arguments:
+        selectors: IRs that **only** contain subclasses of `SelectorIR`.
+        keys: Names of `group_by` columns.
+        schema: Scope to expand multi-column selectors in.
+ """ + frozen_schema = freeze_schema(schema) + names = tuple(_iter_expand_selector_names(selectors, keys, schema=frozen_schema)) + return _ensure_valid_output_names(names, frozen_schema) + + def into_named_irs(exprs: Seq[ExprIR], names: OutputNames) -> Seq[NamedIR]: if len(exprs) != len(names): msg = f"zip length mismatch: {len(exprs)} != {len(names)}" raise ValueError(msg) - return tuple( - NamedIR(expr=remove_alias(ir), name=name) for ir, name in zip(exprs, names) - ) + return tuple(ir.named_ir(name, remove_alias(e)) for e, name in zip(exprs, names)) def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: @@ -191,6 +212,15 @@ def ensure_valid_exprs(exprs: Seq[ExprIR], schema: FrozenSchema) -> OutputNames: return output_names +def _ensure_valid_output_names(names: Seq[str], schema: FrozenSchema) -> OutputNames: + """Selector-only variant of `ensure_valid_exprs`.""" + check_column_names_are_unique(names) + output_names = names + if not (set(schema.names).issuperset(output_names)): + raise column_not_found_error(output_names, schema) + return output_names + + def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: names = tuple(e.meta.output_name() for e in exprs) if len(names) != len(set(names)): @@ -198,12 +228,64 @@ def _ensure_output_names_unique(exprs: Seq[ExprIR]) -> OutputNames: return names -def expand_function_inputs(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: +def _ensure_columns(expr: ExprIR, /) -> Columns: + if not isinstance(expr, Columns): + msg = f"Expected only column selections here, but got {expr!r}" + raise NotImplementedError(msg) + return expr + + +def _iter_expand_selector_names( + selectors: Iterable[SelectorIR], /, keys: GroupByKeys = (), *, schema: FrozenSchema +) -> Iterator[str]: + for selector in selectors: + names = _ensure_columns(replace_selector(selector, schema=schema)).names + if keys: + yield from (name for name in names if name not in keys) + else: + yield from names + + +# NOTE: Recursive for all `input` expressions which themselves contain `Seq[ExprIR]` +def rewrite_projections( + input: Seq[ExprIR], /, keys: GroupByKeys = (), *, schema: FrozenSchema +) -> Seq[ExprIR]: + result: deque[ExprIR] = deque() + for expr in input: + expanded = _expand_nested_nodes(expr, schema=schema) + flags = ExpansionFlags.from_ir(expanded) + if flags.has_selector: + expanded = replace_selector(expanded, schema=schema) + flags = flags.with_multiple_columns() + result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags)) + return tuple(result) + + +def _expand_nested_nodes(origin: ExprIR, /, *, schema: FrozenSchema) -> ExprIR: + """Adapted from [`expand_function_inputs`]. + + Added additional cases for nodes that *also* need to be expanded in the same way. 
+
+    [`expand_function_inputs`]: https://github.com/pola-rs/polars/blob/df4d21c30c2b383b651e194f8263244f2afaeda3/crates/polars-plan/src/plans/conversion/expr_expansion.rs#L557-L581
+    """
+    rewrite = rewrite_projections
+
     def fn(child: ExprIR, /) -> ExprIR:
+        if not isinstance(child, (ir.FunctionExpr, ir.WindowExpr, ir.SortBy)):
+            return child
+        expanded: dict[str, Any] = {}
         if is_horizontal_reduction(child):
-            rewrites = rewrite_projections(child.input, schema=schema)
-            return common.replace(child, input=rewrites)
-        return child
+            expanded["input"] = rewrite(child.input, schema=schema)
+        elif is_window_expr(child):
+            if partition_by := child.partition_by:
+                expanded["partition_by"] = rewrite(partition_by, schema=schema)
+            if isinstance(child, ir.OrderedWindowExpr):
+                expanded["order_by"] = rewrite(child.order_by, schema=schema)
+        elif isinstance(child, ir.SortBy):
+            expanded["by"] = rewrite(child.by, schema=schema)
+        if not expanded:
+            return child
+        return common.replace(child, **expanded)

     return origin.map_ir(fn)
@@ -271,24 +353,6 @@ def expand_selector(selector: SelectorIR, schema: FrozenSchema) -> Columns:
     return cols(*(k for k, v in schema.items() if matches(selector, k, v)))


-def rewrite_projections(
-    input: Seq[ExprIR],  # `FunctionExpr.input`
-    /,
-    keys: GroupByKeys = (),
-    *,
-    schema: FrozenSchema,
-) -> Seq[ExprIR]:
-    result: deque[ExprIR] = deque()
-    for expr in input:
-        expanded = expand_function_inputs(expr, schema=schema)
-        flags = ExpansionFlags.from_ir(expanded)
-        if flags.has_selector:
-            expanded = replace_selector(expanded, schema=schema)
-            flags = flags.with_multiple_columns()
-        result.extend(iter_replace(expanded, keys, col_names=schema.names, flags=flags))
-    return tuple(result)
-
-
 def iter_replace(
     origin: ExprIR,
     /,
diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py
index 8f66f2d9fa..5ba9206f72 100644
--- a/narwhals/_plan/_expr_ir.py
+++ b/narwhals/_plan/_expr_ir.py
@@ -8,6 +8,7 @@
 from narwhals._plan.common import replace
 from narwhals._plan.options import ExprIROptions
 from narwhals._plan.typing import ExprIRT
+from narwhals.exceptions import InvalidOperationError
 from narwhals.utils import Version

 if TYPE_CHECKING:
@@ -59,6 +60,10 @@
         tp = expr.Expr if version is Version.MAIN else expr.ExprV1
         return tp._from_ir(self)

+    def to_selector_ir(self) -> SelectorIR:
+        msg = f"cannot turn `{self!r}` into a selector"
+        raise InvalidOperationError(msg)
+
     @property
     def is_scalar(self) -> bool:
         return False
@@ -201,6 +206,9 @@
         """
         raise NotImplementedError(type(self))

+    def to_selector_ir(self) -> Self:
+        return self
+

 class NamedIR(Immutable, Generic[ExprIRT]):
     """Post-projection expansion wrapper for `ExprIR`.
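> Annotation: `ExprIR.to_selector_ir` refuses by default and `SelectorIR.to_selector_ir` returns itself; the column nodes in `expressions/expr.py` further down opt in explicitly. A rough usage sketch of the new entry-point this enables — internal names are taken from this diff, and the exact shapes `IntoFrozenSchema` accepts are assumed, not verified:

```python
# Hedged sketch: selector IRs expanding against a schema via the new helper.
import narwhals as nw
from narwhals._plan import selectors as ncs
from narwhals._plan._expansion import expand_selector_irs_names

schema = {"a": nw.String(), "b": nw.Int64(), "c": nw.Int64()}  # assumed IntoFrozenSchema
irs = [ncs.string()._ir, ncs.by_name("c")._ir]
# Expands to the matching column names, validated for uniqueness/existence.
names = expand_selector_irs_names(irs, schema=schema)  # ("a", "c")
```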
diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py
index 780070f038..ad1ba5c931 100644
--- a/narwhals/_plan/_guards.py
+++ b/narwhals/_plan/_guards.py
@@ -11,11 +11,16 @@
 if TYPE_CHECKING:
     from typing_extensions import TypeIs

-    from narwhals._plan import expressions as ir
+    from narwhals._plan import expr, expressions as ir
     from narwhals._plan.compliant.series import CompliantSeries
     from narwhals._plan.expr import Expr
     from narwhals._plan.series import Series
-    from narwhals._plan.typing import IntoExprColumn, NativeSeriesT, Seq
+    from narwhals._plan.typing import (
+        ColumnNameOrSelector,
+        IntoExprColumn,
+        NativeSeriesT,
+        Seq,
+    )
     from narwhals.typing import NonNestedLiteral

 T = TypeVar("T")
@@ -58,6 +63,10 @@
     return isinstance(obj, _expr().Expr)


+def is_selector(obj: Any) -> TypeIs[expr.Selector]:
+    return isinstance(obj, _expr().Selector)
+
+
 def is_column(obj: Any) -> TypeIs[Expr]:
     """Indicate if the given object is a basic/unaliased column."""
     return is_expr(obj) and obj.meta.is_column()
@@ -71,6 +80,13 @@
     return isinstance(obj, (str, _expr().Expr, _series().Series))


+def is_column_name_or_selector(
+    obj: Any, *, allow_expr: bool = False
+) -> TypeIs[ColumnNameOrSelector]:
+    tps = (str, _expr().Selector) if not allow_expr else (str, _expr().Expr)
+    return isinstance(obj, tps)
+
+
 def is_compliant_series(
     obj: CompliantSeries[NativeSeriesT] | Any,
 ) -> TypeIs[CompliantSeries[NativeSeriesT]]:
diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py
index c2a5cc7c2f..db2abe15da 100644
--- a/narwhals/_plan/_parse.py
+++ b/narwhals/_plan/_parse.py
@@ -6,12 +6,19 @@
 from itertools import chain
 from typing import TYPE_CHECKING

-from narwhals._plan._guards import is_expr, is_into_expr_column, is_iterable_reject
+from narwhals._plan._guards import (
+    is_column_name_or_selector,
+    is_expr,
+    is_into_expr_column,
+    is_iterable_reject,
+    is_selector,
+)
 from narwhals._plan.exceptions import (
     invalid_into_expr_error,
     is_iterable_pandas_error,
     is_iterable_polars_error,
 )
+from narwhals._utils import qualified_type_name
 from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series
 from narwhals.exceptions import InvalidOperationError
@@ -22,8 +29,10 @@
     import polars as pl
     from typing_extensions import TypeAlias, TypeIs

-    from narwhals._plan.expressions import ExprIR
+    from narwhals._plan.expr import Expr
+    from narwhals._plan.expressions import ExprIR, SelectorIR
     from narwhals._plan.typing import (
+        ColumnNameOrSelector,
         IntoExpr,
         IntoExprColumn,
         OneOrIterable,
@@ -124,6 +133,23 @@
     return expr._ir


+# NOTE: Might need to add `require_all`, since selectors are created indirectly from `str`
+# here, but use set semantics
+def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR:
+    if is_selector(input):
+        selector = input
+    elif isinstance(input, str):
+        from narwhals._plan import selectors as cs
+
+        selector = cs.by_name(input)
+    elif is_expr(input):
+        selector = input.meta.as_selector()
+    else:
+        msg = f"cannot turn {qualified_type_name(input)!r} into selector"
+        raise TypeError(msg)
+    return selector._ir
+
+
 def parse_into_seq_of_expr_ir(
     first_input: OneOrIterable[IntoExpr] = (),
     *more_inputs: IntoExpr | _RaisesInvalidIntoExprError,
@@ -175,6 +201,34 @@
     yield e


+def parse_into_seq_of_selector_ir(
+    first_input: OneOrIterable[ColumnNameOrSelector], *more_inputs: ColumnNameOrSelector
+) -> Seq[SelectorIR]:
+    return tuple(_parse_into_iter_selector_ir(first_input, more_inputs))
+
+
+def _parse_into_iter_selector_ir(
+    first_input: OneOrIterable[ColumnNameOrSelector],
+    more_inputs: tuple[ColumnNameOrSelector, ...],
+    /,
+) -> Iterator[SelectorIR]:
+    if is_column_name_or_selector(first_input) and not more_inputs:
+        yield parse_into_selector_ir(first_input)
+        return
+
+    if not _is_empty_sequence(first_input):
+        if _is_iterable(first_input) and not isinstance(first_input, str):
+            if more_inputs:
+                raise invalid_into_expr_error(first_input, more_inputs, {})
+            else:
+                for into in first_input:  # type: ignore[var-annotated]
+                    yield parse_into_selector_ir(into)
+        else:
+            yield parse_into_selector_ir(first_input)
+    for into in more_inputs:
+        yield parse_into_selector_ir(into)
+
+
 def _parse_into_iter_expr_ir(
     first_input: OneOrIterable[IntoExpr],
     *more_inputs: IntoExpr | list[Any],
diff --git a/narwhals/_plan/arrow/acero.py b/narwhals/_plan/arrow/acero.py
index f99fad5289..52f4ef33b9 100644
--- a/narwhals/_plan/arrow/acero.py
+++ b/narwhals/_plan/arrow/acero.py
@@ -246,6 +246,15 @@
     return _add_column(native, 0, name, values)


+def _union(declarations: Iterable[Decl], /) -> Decl:
+    """[`union`] merges multiple data streams with the same schema into one, similar to a SQL `UNION ALL` clause.
+
+    [`union`]: https://arrow.apache.org/docs/cpp/acero/user_guide.html#union
+    """
+    decls: Incomplete = declarations
+    return Decl("union", pac.ExecNodeOptions(), decls)
+
+
 def _order_by(
     sort_keys: Iterable[tuple[str, Order]] = (),
     *,
diff --git a/narwhals/_plan/arrow/dataframe.py b/narwhals/_plan/arrow/dataframe.py
index 15b9dc80c0..73a3d4e8c1 100644
--- a/narwhals/_plan/arrow/dataframe.py
+++ b/narwhals/_plan/arrow/dataframe.py
@@ -11,7 +11,7 @@
 from narwhals._arrow.utils import native_to_narwhals_dtype
 from narwhals._plan.arrow import acero, functions as fn
 from narwhals._plan.arrow.expr import ArrowExpr as Expr, ArrowScalar as Scalar
-from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy
+from narwhals._plan.arrow.group_by import ArrowGroupBy as GroupBy, partition_by
 from narwhals._plan.arrow.series import ArrowSeries as Series
 from narwhals._plan.compliant.dataframe import EagerDataFrame
 from narwhals._plan.compliant.typing import namespace
@@ -23,7 +23,7 @@
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping, Sequence

-    from typing_extensions import Self
+    from typing_extensions import Self, TypeAlias

     from narwhals._arrow.typing import ChunkedArrayAny
     from narwhals._plan.arrow.namespace import ArrowNamespace
@@ -33,6 +33,8 @@
     from narwhals.dtypes import DType
     from narwhals.typing import IntoSchema

+Incomplete: TypeAlias = Any
+

 class ArrowDataFrame(EagerDataFrame[Series, "pa.Table", "ChunkedArrayAny"]):
     implementation = Implementation.PYARROW
@@ -152,7 +154,7 @@ def join(
         *,
         how: NonCrossJoinStrategy,
         left_on: Sequence[str],
-        right_on: Sequence[str],
+        right_on: Sequence[str] = (),
         suffix: str = "_right",
     ) -> Self:
         left, right = self.native, other.native
@@ -171,3 +173,8 @@ def filter(self, predicate: NamedIR) -> Self:
         else:
             mask = acero.lit(resolved.native)
         return self._with_native(self.native.filter(mask))
+
+    def partition_by(self, by: Sequence[str], *, include_key: bool = True) -> list[Self]:
+        from_native = self._with_native
+        partitions = partition_by(self.native, by, include_key=include_key)
+        return [from_native(df) for df in partitions]
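> Annotation: `ArrowDataFrame.partition_by` above is a thin wrapper over the `partition_by` generator added to `arrow/group_by.py` later in this diff. A sketch of what that looks like at the pyarrow level — partition order following first appearance of each key is assumed here, not guaranteed by the source:

```python
import pyarrow as pa

from narwhals._plan.arrow.group_by import partition_by

tbl = pa.table({"a": ["x", "x", "y"], "b": [1, 2, 3]})
parts = list(partition_by(tbl, ["a"]))                      # [rows a == "x", rows a == "y"]
no_key = list(partition_by(tbl, ["a"], include_key=False))  # same splits, "a" dropped
```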
diff --git a/narwhals/_plan/arrow/expr.py b/narwhals/_plan/arrow/expr.py
index fb2bad1479..7187152d25 100644
--- a/narwhals/_plan/arrow/expr.py
+++ b/narwhals/_plan/arrow/expr.py
@@ -55,7 +55,14 @@
     Not,
 )
 from narwhals._plan.expressions.expr import BinaryExpr, FunctionExpr
-from narwhals._plan.expressions.functions import Abs, FillNull, Pow
+from narwhals._plan.expressions.functions import (
+    Abs,
+    CumAgg,
+    Diff,
+    FillNull,
+    Pow,
+    Shift,
+)
 from narwhals.typing import Into1DArray, IntoDType, PythonLiteral

 Expr: TypeAlias = "ArrowExpr"
@@ -322,13 +329,31 @@ def min(self, node: Min, frame: Frame, name: str) -> Scalar:
         return self._with_native(result, name)

     # TODO @dangotbanned: top-level, complex-ish nodes
-    # - [ ] `over`/`_ordered` (with partitions) requires `group_by`, `join`
-    #   - [x] `over_ordered` alone should be possible w/ the current API
-    # - [x] `map_batches` is defined in `EagerExpr`, might be simpler here than on main
+    # - [ ] Over
+    #   - [x] `over_ordered`
+    #   - [x] `group_by`, `join`
+    #   - [!] `over`
+    #   - [ ] `over_ordered` (with partitions)
+    # - [ ] `map_batches`
+    #   - [x] elementwise
+    #   - [ ] scalar
     # - [ ] `rolling_expr` has 4 variants
     def over(self, node: ir.WindowExpr, frame: Frame, name: str) -> Self:
-        raise NotImplementedError
+        resolved = (
+            frame._grouper.by_irs(*node.partition_by)
+            # TODO @dangotbanned: Clean this up so the re-alias isn't needed
+            .agg_irs(node.expr.alias(name))
+            .resolve(frame)
+        )
+        by_names = resolved.key_names
+        result = (
+            frame.select_names(*by_names)
+            .join(resolved.evaluate(frame), how="left", left_on=by_names)
+            .get_column(name)
+            .native
+        )
+        return self._with_native(result, name)

     def over_ordered(
         self, node: ir.OrderedWindowExpr, frame: Frame, name: str
@@ -372,6 +397,24 @@ def map_batches(self, node: ir.AnonymousExpr, frame: Frame, name: str) -> Self:
     def rolling_expr(self, node: ir.RollingExpr, frame: Frame, name: str) -> Self:
         raise NotImplementedError

+    def shift(self, node: ir.FunctionExpr[Shift], frame: Frame, name: str) -> Self:
+        series = self._dispatch_expr(node.input[0], frame, name)
+        return self._with_native(fn.shift(series.native, node.function.n), name)
+
+    def diff(self, node: ir.FunctionExpr[Diff], frame: Frame, name: str) -> Self:
+        series = self._dispatch_expr(node.input[0], frame, name)
+        return self._with_native(fn.diff(series.native), name)
+
+    def _cumulative(self, node: ir.FunctionExpr[CumAgg], frame: Frame, name: str) -> Self:
+        series = self._dispatch_expr(node.input[0], frame, name)
+        return self._with_native(fn.cumulative(series.native, node.function), name)
+
+    cum_count = _cumulative
+    cum_min = _cumulative
+    cum_max = _cumulative
+    cum_prod = _cumulative
+    cum_sum = _cumulative
+
     def _is_first_last_distinct(
         self,
         node: FunctionExpr[IsFirstDistinct | IsLastDistinct],
@@ -479,4 +522,11 @@ def count(self, node: Count, frame: Frame, name: str) -> Scalar:
     over = not_implemented()
     over_ordered = not_implemented()
     map_batches = not_implemented()
+    # length_preserving
     rolling_expr = not_implemented()
+    diff = not_implemented()
+    cum_sum = not_implemented()  # TODO @dangotbanned: is this just self?
+    cum_count = not_implemented()
+    cum_min = not_implemented()
+    cum_max = not_implemented()
+    cum_prod = not_implemented()
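> Annotation: the new `over` body is the "aggregate per key, then join back" technique. Stripped of the IR plumbing, the same idea in plain pyarrow looks roughly like this — illustration only; the real code routes through `Grouper.by_irs(...).agg_irs(...)` and Acero:

```python
import pyarrow as pa

tbl = pa.table({"k": ["a", "a", "b"], "v": [1, 2, 5]})
agg = tbl.group_by("k").aggregate([("v", "max")])  # one row per partition key
# Left-join the per-key result back onto the key column to broadcast it row-wise.
joined = tbl.select(["k"]).join(agg, keys="k", join_type="left outer")
broadcast = joined.column("v_max")  # [2, 2, 5] -- row order assumed preserved
```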
diff --git a/narwhals/_plan/arrow/functions.py b/narwhals/_plan/arrow/functions.py
index 53b56d19b8..baee29c23a 100644
--- a/narwhals/_plan/arrow/functions.py
+++ b/narwhals/_plan/arrow/functions.py
@@ -4,7 +4,7 @@
 import typing as t
 from collections.abc import Callable
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, overload

 import pyarrow as pa  # ignore-banned-import
 import pyarrow.compute as pc  # ignore-banned-import
@@ -16,7 +16,7 @@
 )
 from narwhals._plan import expressions as ir
 from narwhals._plan.arrow import options
-from narwhals._plan.expressions import operators as ops
+from narwhals._plan.expressions import functions as F, operators as ops
 from narwhals._utils import Implementation

 if TYPE_CHECKING:
@@ -37,6 +37,7 @@
         ChunkedArray,
         ChunkedArrayAny,
         ChunkedOrArrayAny,
+        ChunkedOrArrayT,
         ChunkedOrScalar,
         ChunkedOrScalarAny,
         DataType,
@@ -53,7 +54,7 @@
         StringType,
         UnaryFunction,
     )
-    from narwhals.typing import ClosedInterval, IntoArrowSchema
+    from narwhals.typing import ClosedInterval, IntoArrowSchema, PythonLiteral

 BACKEND_VERSION = Implementation.PYARROW._backend_version()
@@ -208,6 +209,70 @@
     return count(native, mode="all")


+def _reverse(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    """Unlike other slicing ops, `[::-1]` creates a full copy.
+
+    https://github.com/apache/arrow/issues/19103#issuecomment-1377671886
+    """
+    return native[::-1]
+
+
+def cumulative(native: ChunkedArrayAny, cum_agg: F.CumAgg, /) -> ChunkedArrayAny:
+    func = _CUMULATIVE[type(cum_agg)]
+    if not cum_agg.reverse:
+        return func(native)
+    return _reverse(func(_reverse(native)))
+
+
+def cum_sum(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_sum(native, skip_nulls=True)
+
+
+def cum_min(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_min(native, skip_nulls=True)
+
+
+def cum_max(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_max(native, skip_nulls=True)
+
+
+def cum_prod(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    return pc.cumulative_prod(native, skip_nulls=True)
+
+
+def cum_count(native: ChunkedArrayAny) -> ChunkedArrayAny:
+    return cum_sum(is_not_null(native).cast(pa.uint32()))
+
+
+_CUMULATIVE: Mapping[type[F.CumAgg], Callable[[ChunkedArrayAny], ChunkedArrayAny]] = {
+    F.CumSum: cum_sum,
+    F.CumCount: cum_count,
+    F.CumMin: cum_min,
+    F.CumMax: cum_max,
+    F.CumProd: cum_prod,
+}
+
+
+def diff(native: ChunkedOrArrayT) -> ChunkedOrArrayT:
+    # pyarrow.lib.ArrowInvalid: Vector kernel cannot execute chunkwise and no chunked exec function was defined
+    return (
+        pc.pairwise_diff(native)
+        if isinstance(native, pa.Array)
+        else chunked_array(pc.pairwise_diff(native.combine_chunks()))
+    )
+
+
+def shift(native: ChunkedArrayAny, n: int) -> ChunkedArrayAny:
+    if n == 0:
+        return native
+    arr = native
+    if n > 0:
+        arrays = [nulls_like(n, arr), *arr.slice(length=arr.length() - n).chunks]
+    else:
+        arrays = [*arr.slice(offset=-n).chunks, nulls_like(-n, arr)]
+    return pa.chunked_array(arrays)
+
+
 def is_between(
     native: ChunkedOrScalar[ScalarT],
     lower: ChunkedOrScalar[ScalarT],
@@ -271,24 +336,45 @@
     return pa.chunked_array([pa.array(np.arange(start, end, step), dtype)])


+def nulls_like(n: int, native: ArrowAny) -> ArrayAny:
+    """Create a strongly-typed Array instance with all elements null.
+
+    Uses the type of `native`.
+    """
+    return pa.nulls(n, native.type)  # type: ignore[no-any-return]
+
+
 def lit(value: Any, dtype: DataType | None = None) -> NativeScalar:
     return pa.scalar(value) if dtype is None else pa.scalar(value, dtype)


+@overload
+def array(data: ArrowAny, /) -> ArrayAny: ...
+@overload
+def array(
+    data: Iterable[PythonLiteral], dtype: DataType | None = None, /
+) -> ArrayAny: ...
 def array(
-    value: NativeScalar | Iterable[Any], dtype: DataType | None = None, /
+    data: ArrowAny | Iterable[PythonLiteral], dtype: DataType | None = None, /
 ) -> ArrayAny:
-    return (
-        pa.array([value], value.type)
-        if isinstance(value, pa.Scalar)
-        else pa.array(value, dtype)
-    )
+    """Convert `data` into an Array instance.
+
+    Note:
+        `dtype` is not used for existing `pyarrow` data; use `cast` instead.
+    """
+    if isinstance(data, pa.ChunkedArray):
+        return data.combine_chunks()
+    if isinstance(data, pa.Array):
+        return data
+    if isinstance(data, pa.Scalar):
+        return pa.array([data], data.type)
+    return pa.array(data, dtype)


 def chunked_array(
-    arr: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, /
+    data: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, /
 ) -> ChunkedArrayAny:
-    return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype)
+    return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype)


 def concat_vertical_chunked(
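> Annotation: `shift` builds its result from a typed null-pad plus a truncated slice (there is no dedicated shift kernel in `pyarrow.compute`), and `cumulative` gets `reverse=True` by flipping the input before and after the forward kernel. A sketch of the expected semantics — the outputs shown are assumed, not taken from tests:

```python
import pyarrow as pa

from narwhals._plan.arrow import functions as fn

ca = pa.chunked_array([[1, 2, 3]])
fn.shift(ca, 1).to_pylist()    # [None, 1, 2]
fn.shift(ca, -2).to_pylist()   # [3, None, None]
fn.cum_sum(ca).to_pylist()     # [1, 3, 6]
# reverse cumulative == reverse -> kernel -> reverse (what `cumulative` does)
fn._reverse(fn.cum_sum(fn._reverse(ca))).to_pylist()  # [6, 5, 3]
```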
+ """ + return pa.nulls(n, native.type) # type: ignore[no-any-return] + + def lit(value: Any, dtype: DataType | None = None) -> NativeScalar: return pa.scalar(value) if dtype is None else pa.scalar(value, dtype) +@overload +def array(data: ArrowAny, /) -> ArrayAny: ... +@overload +def array( + data: Iterable[PythonLiteral], dtype: DataType | None = None, / +) -> ArrayAny: ... def array( - value: NativeScalar | Iterable[Any], dtype: DataType | None = None, / + data: ArrowAny | Iterable[PythonLiteral], dtype: DataType | None = None, / ) -> ArrayAny: - return ( - pa.array([value], value.type) - if isinstance(value, pa.Scalar) - else pa.array(value, dtype) - ) + """Convert `data` into an Array instance. + + Note: + `dtype` is not used for existing `pyarrow` data, use `cast` instead. + """ + if isinstance(data, pa.ChunkedArray): + return data.combine_chunks() + if isinstance(data, pa.Array): + return data + if isinstance(data, pa.Scalar): + return pa.array([data], data.type) + return pa.array(data, dtype) def chunked_array( - arr: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, / + data: ArrowAny | list[Iterable[Any]], dtype: DataType | None = None, / ) -> ChunkedArrayAny: - return _chunked_array(array(arr) if isinstance(arr, pa.Scalar) else arr, dtype) + return _chunked_array(array(data) if isinstance(data, pa.Scalar) else data, dtype) def concat_vertical_chunked( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index df57f781c1..7b57d775a8 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -16,7 +16,7 @@ from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: - from collections.abc import Iterator, Mapping + from collections.abc import Iterator, Mapping, Sequence from typing_extensions import Self, TypeAlias @@ -138,14 +138,6 @@ def group_by_error( return InvalidOperationError(msg) -def concat_str(native: pa.Table, *, separator: str = "") -> ChunkedArray: - dtype = fn.string_type(native.schema.types) - it = fn.cast_table(native, dtype).itercolumns() - concat: Incomplete = pc.binary_join_element_wise - join = options.join_replace_nulls() - return concat(*it, fn.lit(separator, dtype), options=join) # type: ignore[no-any-return] - - class ArrowGroupBy(EagerDataFrameGroupBy["Frame"]): _df: Frame _keys: Seq[NamedIR] @@ -157,15 +149,12 @@ def compliant(self) -> Frame: return self._df def __iter__(self) -> Iterator[tuple[Any, Frame]]: - temp_name = temp.column_name(self.compliant) - native = self.compliant.native - composite_values = concat_str(acero.select_names_table(native, self.key_names)) - re_keyed = native.add_column(0, temp_name, composite_values) + by = self.key_names from_native = self.compliant._with_native - for v in composite_values.unique(): - t = from_native(acero.filter_table(re_keyed, pc.field(temp_name) == v)) + for partition in partition_by(self.compliant.native, by): + t = from_native(partition) yield ( - t.select_names(*self.key_names).row(0), + t.select_names(*by).row(0), t.select_names(*self._column_names_original), ) @@ -178,3 +167,59 @@ def agg(self, irs: Seq[NamedIR]) -> Frame: if original := self._key_names_original: return result.rename(dict(zip(key_names, original))) return result + + +def _composite_key(native: pa.Table, *, separator: str = "") -> ChunkedArray: + """Horizontally join columns to *seed* a unique key per row combination.""" + dtype = fn.string_type(native.schema.types) + it = fn.cast_table(native, dtype).itercolumns() + concat: Incomplete = 
+
+
+def partition_by(
+    native: pa.Table, by: Sequence[str], *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    if len(by) == 1:
+        yield from _partition_by_one(native, by[0], include_key=include_key)
+    else:
+        yield from _partition_by_many(native, by, include_key=include_key)
+
+
+def _partition_by_one(
+    native: pa.Table, by: str, *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    """Optimized path for single-column partition."""
+    arr_dict: Incomplete = fn.array(native.column(by).dictionary_encode("encode"))
+    indices: pa.Int32Array = arr_dict.indices
+    if not include_key:
+        native = native.remove_column(native.schema.get_field_index(by))
+    for idx in range(len(arr_dict.dictionary)):
+        # NOTE: Acero filter doesn't support `null_selection_behavior="emit_null"`
+        # Is there any reasonable way to do this in Acero?
+        yield native.filter(pc.equal(pa.scalar(idx), indices))
+
+
+def _partition_by_many(
+    native: pa.Table, by: Sequence[str], *, include_key: bool = True
+) -> Iterator[pa.Table]:
+    original_names = native.column_names
+    temp_name = temp.column_name(original_names)
+    key = acero.col(temp_name)
+    composite_values = _composite_key(acero.select_names_table(native, by))
+    # We iterate over every unique value, so materializing a Python list up front should be faster
+    unique_py = composite_values.unique().to_pylist()
+    re_keyed = native.add_column(0, temp_name, composite_values)
+    source = acero.table_source(re_keyed)
+    if include_key:
+        keep = original_names
+    else:
+        ignore = {*by, temp_name}
+        keep = [name for name in original_names if name not in ignore]
+    select = acero.select_names(keep)
+    for v in unique_py:
+        # NOTE: May want to split the `Declaration` production iterator into its own function
+        # E.g. to push down column selection to *before* collection
+        # Not needed for this task though
+        yield acero.collect(source, acero.filter(key == v), select)
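> Annotation: for multi-column keys, `_partition_by_many` leans on `_composite_key`: cast every key column to a common string type and join them element-wise, so each distinct row combination collapses to one filterable value. The core of that idea in plain `pyarrow.compute` — the real code goes through `fn.cast_table`/`options.join_replace_nulls` and Acero, and `null_handling="replace"` is my assumption of what that option sets:

```python
import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"a": ["x", "x", None], "b": [1, 2, 1]})
keys = pc.binary_join_element_wise(
    pc.cast(tbl["a"], pa.string()),
    pc.cast(tbl["b"], pa.string()),
    "",                       # separator (the `separator=""` default above)
    null_handling="replace",  # nulls become "" instead of poisoning the key
)
for v in keys.unique():
    part = tbl.filter(pc.equal(keys, v))  # one sub-table per row combination
```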
diff --git a/narwhals/_plan/arrow/typing.py b/narwhals/_plan/arrow/typing.py
index 63333a49d4..10d2a82144 100644
--- a/narwhals/_plan/arrow/typing.py
+++ b/narwhals/_plan/arrow/typing.py
@@ -70,10 +70,18 @@ def __call__(
         self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any
     ) -> ChunkedOrScalar[ScalarRT_co]: ...
+    @overload
+    def __call__(
+        self, data: Array[ScalarPT_contra], *args: Any, **kwds: Any
+    ) -> Array[ScalarRT_co]: ...
+    @overload
+    def __call__(
+        self, data: ChunkedOrArray[ScalarPT_contra], *args: Any, **kwds: Any
+    ) -> ChunkedOrArray[ScalarRT_co]: ...
     def __call__(
-        self, data: ChunkedOrScalar[ScalarPT_contra], *args: Any, **kwds: Any
-    ) -> ChunkedOrScalar[ScalarRT_co]: ...
+        self, data: Arrow[ScalarPT_contra], *args: Any, **kwds: Any
+    ) -> Arrow[ScalarRT_co]: ...


 class BinaryFunction(Protocol[ScalarPT_contra, ScalarRT_co]):
@@ -130,6 +138,8 @@ class BinaryLogical(BinaryFunction["pa.BooleanScalar", "pa.BooleanScalar"], Prot
 ChunkedArrayAny: TypeAlias = "ChunkedArray[Any]"
 ChunkedOrScalarAny: TypeAlias = "ChunkedOrScalar[ScalarAny]"
 ChunkedOrArrayAny: TypeAlias = "ChunkedOrArray[ScalarAny]"
+ChunkedOrArrayT = TypeVar("ChunkedOrArrayT", ChunkedArrayAny, ArrayAny)
+Arrow: TypeAlias = "ChunkedOrScalar[ScalarT_co] | Array[ScalarT_co]"
 ArrowAny: TypeAlias = "ChunkedOrScalarAny | ArrayAny"
 NativeScalar: TypeAlias = ScalarAny
 BinOp: TypeAlias = Callable[..., ChunkedOrScalarAny]
diff --git a/narwhals/_plan/compliant/dataframe.py b/narwhals/_plan/compliant/dataframe.py
index 45728fce47..7334eef8af 100644
--- a/narwhals/_plan/compliant/dataframe.py
+++ b/narwhals/_plan/compliant/dataframe.py
@@ -129,6 +129,9 @@
         suffix: str = "_right",
     ) -> Self: ...
     def join_cross(self, other: Self, *, suffix: str = "_right") -> Self: ...
+    def partition_by(
+        self, by: Sequence[str], *, include_key: bool = True
+    ) -> list[Self]: ...
     def row(self, index: int) -> tuple[Any, ...]: ...
     @overload
     def to_dict(self, *, as_series: Literal[True]) -> dict[str, SeriesT]: ...
diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py
index e8d0dffbed..f60ae2d729 100644
--- a/narwhals/_plan/compliant/expr.py
+++ b/narwhals/_plan/compliant/expr.py
@@ -92,6 +92,9 @@ def pow(self, node: FunctionExpr[F.Pow], frame: FrameT_contra, name: str) -> Sel
     def rolling_expr(
         self, node: ir.RollingExpr, frame: FrameT_contra, name: str
     ) -> Self: ...
+    def shift(
+        self, node: FunctionExpr[F.Shift], frame: FrameT_contra, name: str
+    ) -> Self: ...
     def ternary_expr(
         self, node: ir.TernaryExpr, frame: FrameT_contra, name: str
     ) -> Self: ...
@@ -99,6 +102,24 @@
     def filter(self, node: ir.Filter, frame: FrameT_contra, name: str) -> Self: ...
     def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self: ...
     def sort_by(self, node: ir.SortBy, frame: FrameT_contra, name: str) -> Self: ...
+    def diff(
+        self, node: FunctionExpr[F.Diff], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_count(
+        self, node: FunctionExpr[F.CumCount], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_min(
+        self, node: FunctionExpr[F.CumMin], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_max(
+        self, node: FunctionExpr[F.CumMax], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_prod(
+        self, node: FunctionExpr[F.CumProd], frame: FrameT_contra, name: str
+    ) -> Self: ...
+    def cum_sum(
+        self, node: FunctionExpr[F.CumSum], frame: FrameT_contra, name: str
+    ) -> Self: ...
     # series -> scalar
     def all(
         self, node: FunctionExpr[boolean.All], frame: FrameT_contra, name: str
diff --git a/narwhals/_plan/compliant/group_by.py b/narwhals/_plan/compliant/group_by.py
index 8e05144393..adac2bb402 100644
--- a/narwhals/_plan/compliant/group_by.py
+++ b/narwhals/_plan/compliant/group_by.py
@@ -12,7 +12,7 @@
     FrameT_co,
     ResolverT_co,
 )
-from narwhals.exceptions import ComputeError
+from narwhals._plan.exceptions import group_by_no_keys_error

 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -51,8 +51,7 @@
     def key_names(self) -> Seq[str]:
         if names := self._key_names:
             return names
-        msg = "at least one key is required in a group_by operation"
-        raise ComputeError(msg)
+        raise group_by_no_keys_error()


 class EagerDataFrameGroupBy(DataFrameGroupBy[EagerDataFrameT], Protocol[EagerDataFrameT]):
@@ -163,8 +162,7 @@
             return names
         if keys := self.keys:
             return tuple(e.name for e in keys)
-        msg = "at least one key is required in a group_by operation"
-        raise ComputeError(msg)
+        raise group_by_no_keys_error()

     def requires_projection(self, *, allow_aliasing: bool = False) -> bool:
         """Return True is group keys contain anything that is not a column selection.
@@ -203,3 +201,13 @@ class Grouped(Grouper[Resolved]):
     @property
     def _resolver(self) -> type[Resolved]:
         return Resolved
+
+    @classmethod
+    def by_irs(cls, *by: ExprIR) -> Self:
+        obj = cls.__new__(cls)
+        obj._keys = by
+        return obj
+
+    def agg_irs(self, *aggs: ExprIR) -> Self:
+        self._aggs = aggs
+        return self
diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py
index 25c07d7de7..65bd920f8a 100644
--- a/narwhals/_plan/compliant/scalar.py
+++ b/narwhals/_plan/compliant/scalar.py
@@ -11,7 +11,7 @@
     from narwhals._plan import expressions as ir
     from narwhals._plan.expressions import FunctionExpr, aggregation as agg
     from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct
-    from narwhals._plan.expressions.functions import EwmMean
+    from narwhals._plan.expressions.functions import EwmMean, Shift
     from narwhals._utils import Version
     from narwhals.typing import IntoDType, PythonLiteral

@@ -101,6 +101,11 @@ def n_unique(self, node: agg.NUnique, frame: FrameT_contra, name: str) -> Self:
     def quantile(self, node: agg.Quantile, frame: FrameT_contra, name: str) -> Self:
         return self._cast_float(node.expr, frame, name)

+    def shift(self, node: FunctionExpr[Shift], frame: FrameT_contra, name: str) -> Self:
+        if node.function.n == 0:
+            return self._with_evaluated(self._evaluated, name)
+        return self.from_python(None, name, dtype=None, version=self.version)
+
     def sort(self, node: ir.Sort, frame: FrameT_contra, name: str) -> Self:
         return self._with_evaluated(self._evaluated, name)
diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py
index 625f1990e0..db2d43b5c1 100644
--- a/narwhals/_plan/dataframe.py
+++ b/narwhals/_plan/dataframe.py
@@ -3,8 +3,9 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, get_args, overload

 from narwhals._plan import _parse
-from narwhals._plan._expansion import prepare_projection
+from narwhals._plan._expansion import expand_selector_irs_names, prepare_projection
 from narwhals._plan.common import ensure_seq_str, temp
+from narwhals._plan.exceptions import group_by_no_keys_error
 from narwhals._plan.group_by import GroupBy, Grouped
 from narwhals._plan.options import SortMultipleOptions
 from narwhals._plan.series import Series
@@ -258,6 +259,19 @@ def filter(
             raise ValueError(msg)
         return self._with_compliant(self._compliant.filter(named_irs[0]))

+    def partition_by(
+        self,
+        by: OneOrIterable[ColumnNameOrSelector],
+        *more_by: ColumnNameOrSelector,
+        include_key: bool = True,
+    ) -> list[Self]:
+        by_selectors = _parse.parse_into_seq_of_selector_ir(by, *more_by)
+        names = expand_selector_irs_names(by_selectors, schema=self)
+        if not names:
+            raise group_by_no_keys_error()
+        partitions = self._compliant.partition_by(names, include_key=include_key)
+        return [self._with_compliant(p) for p in partitions]
+

 def _is_join_strategy(obj: Any) -> TypeIs[JoinStrategy]:
     return obj in {"inner", "left", "full", "cross", "anti", "semi"}
diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py
index cfeb87644b..a1c20b995c 100644
--- a/narwhals/_plan/exceptions.py
+++ b/narwhals/_plan/exceptions.py
@@ -210,3 +210,8 @@ def column_index_error(
     max_nth = f"`nth({n_names - 1})`" if index >= 0 else f"`nth(-{n_names})`"
     msg = f"Invalid column index {index!r}\nHint: The schema's last column is {max_nth}"
     return ComputeError(msg)
+
+
+def group_by_no_keys_error() -> ComputeError:
+    msg = "at least one key is required in a group_by operation"
+    return ComputeError(msg)
diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py
index 4fa2f6cf6e..073c115305 100644
--- a/narwhals/_plan/expressions/expr.py
+++ b/narwhals/_plan/expressions/expr.py
@@ -9,6 +9,7 @@
 from narwhals._plan._expr_ir import ExprIR, SelectorIR
 from narwhals._plan.common import flatten_hash_safe
 from narwhals._plan.exceptions import function_expr_invalid_operation_error
+from narwhals._plan.expressions import selectors as cs
 from narwhals._plan.options import ExprIROptions
 from narwhals._plan.typing import (
     FunctionT_co,
@@ -32,7 +33,6 @@
     from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
     from narwhals._plan.expressions.functions import MapBatches  # noqa: F401
     from narwhals._plan.expressions.literal import LiteralValue
-    from narwhals._plan.expressions.selectors import Selector
     from narwhals._plan.expressions.window import Window
     from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions
     from narwhals.dtypes import DType
@@ -100,6 +100,9 @@ class Column(ExprIR, config=ExprIROptions.namespaced("col")):
     def __repr__(self) -> str:
         return f"col({self.name!r})"

+    def to_selector_ir(self) -> RootSelector:
+        return cs.ByName.from_name(self.name).to_selector_ir()
+

 class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()):
     """Nodes which can resolve to `Column`(s) with a `Schema`."""
@@ -112,7 +115,11 @@ class Columns(_ColumnSelection):
     def __repr__(self) -> str:
         return f"cols({list(self.names)!r})"

+    def to_selector_ir(self) -> RootSelector:
+        return cs.ByName.from_names(*self.names).to_selector_ir()
+

+# TODO @dangotbanned: Add `selectors.by_index`
 class Nth(_ColumnSelection):
     __slots__ = ("index",)
     index: int
@@ -121,6 +128,7 @@
     def __repr__(self) -> str:
         return f"nth({self.index})"


+# TODO @dangotbanned: Add `selectors.by_index`
 class IndexColumns(_ColumnSelection):
     __slots__ = ("indices",)
     indices: Seq[int]
@@ -133,7 +141,11 @@ class All(_ColumnSelection):
     def __repr__(self) -> str:
         return "all()"

+    def to_selector_ir(self) -> RootSelector:
+        return cs.All().to_selector_ir()
+

+# TODO @dangotbanned: Add `selectors.exclude`
 class Exclude(_ColumnSelection, child=("expr",)):
     __slots__ = ("expr", "names")
     expr: ExprIR
@@ -450,7 +462,7 @@ class RootSelector(SelectorIR):
     """A single selector expression."""

     __slots__ = ("selector",)
-    selector: Selector
+    selector: cs.Selector

     def __repr__(self) -> str:
         return f"{self.selector!r}"
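> Annotation: with the `to_selector_ir` overrides above plus `parse_into_selector_ir` from `_parse.py`, the conversion table for `partition_by`-style input is complete. A hedged sketch of the three accepted shapes — names come from this diff, behavior is assumed:

```python
import narwhals._plan as nwp
from narwhals._plan import selectors as ncs
from narwhals._plan._parse import parse_into_selector_ir

parse_into_selector_ir("a")           # str          -> cs.by_name("a")._ir
parse_into_selector_ir(ncs.string())  # Selector     -> its own ._ir
parse_into_selector_ir(nwp.col("a"))  # column Expr  -> .meta.as_selector()._ir
# Anything else (e.g. nwp.col("a").abs()) raises InvalidOperationError/TypeError.
```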
("selector",) - selector: Selector + selector: cs.Selector def __repr__(self) -> str: return f"{self.selector!r}" diff --git a/narwhals/_plan/expressions/selectors.py b/narwhals/_plan/expressions/selectors.py index d09d7af761..59aa66a96c 100644 --- a/narwhals/_plan/expressions/selectors.py +++ b/narwhals/_plan/expressions/selectors.py @@ -6,30 +6,39 @@ from __future__ import annotations +import builtins +import operator import re +from collections import deque +from functools import reduce from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable from narwhals._plan.common import flatten_hash_safe from narwhals._utils import Version, _parse_time_unit_and_time_zone +from narwhals.typing import TimeUnit if TYPE_CHECKING: + from collections.abc import Callable, Mapping from datetime import timezone from typing import TypeVar from narwhals._plan import expr + from narwhals._plan.expressions import SelectorIR from narwhals._plan.expressions.expr import RootSelector from narwhals._plan.typing import OneOrIterable from narwhals.dtypes import DType - from narwhals.typing import TimeUnit T = TypeVar("T") -dtypes = Version.MAIN.dtypes +_dtypes = Version.MAIN.dtypes +_dtypes_v1 = Version.V1.dtypes + +_ALL_TIME_UNITS = frozenset[TimeUnit](("ms", "us", "ns", "s")) class Selector(Immutable): - def to_selector(self) -> RootSelector: + def to_selector_ir(self) -> RootSelector: from narwhals._plan.expressions.expr import RootSelector return RootSelector(selector=self) @@ -46,14 +55,37 @@ def matches_column(self, name: str, dtype: DType) -> bool: return True +class Array(Selector): + __slots__ = ("inner", "size") + inner: SelectorIR | None + size: int | None + """Not sure why polars is using the (`0.20.31`) deprecated name `width`.""" + + def __repr__(self) -> str: + inner = "" if not self.inner else repr(self.inner) + size = self.size or "*" + return f"ncs.array({inner}, size={size})" + + def matches_column(self, name: str, dtype: DType) -> bool: + return ( + isinstance(dtype, _dtypes.Array) + and (not (self.inner) or self.inner.matches_column(name, dtype)) + and (self.size is None or dtype.size == self.size) + ) + + +class Boolean(Selector): + def __repr__(self) -> str: + return "ncs.boolean()" + + def matches_column(self, name: str, dtype: DType) -> bool: + return isinstance(dtype, _dtypes.Boolean) + + class ByDType(Selector): __slots__ = ("dtypes",) dtypes: frozenset[DType | type[DType]] - @staticmethod - def from_dtypes(*dtypes: OneOrIterable[DType | type[DType]]) -> ByDType: - return ByDType(dtypes=frozenset(flatten_hash_safe(dtypes))) - def __repr__(self) -> str: els = ", ".join( tp.__name__ if isinstance(tp, type) else repr(tp) for tp in self.dtypes @@ -64,12 +96,28 @@ def matches_column(self, name: str, dtype: DType) -> bool: return dtype in self.dtypes -class Boolean(Selector): +class ByName(Selector): + # NOTE: `polars` allows this and `by_index` to redefine schema order in a `select` + # > Matching columns are returned in the order in which they are declared in + # > the selector, not the underlying schema order. 
+    # If you wanna support that (later), then a `frozenset` won't work
+    __slots__ = ("names",)
+    names: frozenset[str]
+
     def __repr__(self) -> str:
-        return "ncs.boolean()"
+        els = ", ".join(f"{nm!r}" for nm in sorted(self.names))
+        return f"ncs.by_name({els})"
+
+    @staticmethod
+    def from_names(*names: OneOrIterable[str]) -> ByName:
+        return ByName(names=frozenset(flatten_hash_safe(names)))
+
+    @staticmethod
+    def from_name(name: str, /) -> ByName:
+        return ByName(names=frozenset((name,)))

     def matches_column(self, name: str, dtype: DType) -> bool:
-        return isinstance(dtype, dtypes.Boolean)
+        return name in self.names


 class Categorical(Selector):
     def __repr__(self) -> str:
         return "ncs.categorical()"

     def matches_column(self, name: str, dtype: DType) -> bool:
-        return isinstance(dtype, dtypes.Categorical)
+        return isinstance(dtype, _dtypes.Categorical)


 class Datetime(Selector):
@@ -102,12 +150,12 @@
         return Datetime(time_units=frozenset(units), time_zones=frozenset(zones))

     def __repr__(self) -> str:
-        return f"ncs.datetime(time_unit={list(self.time_units)}, time_zone={list(self.time_zones)})"
+        return f"ncs.datetime(time_unit={builtins.list(self.time_units)}, time_zone={builtins.list(self.time_zones)})"

     def matches_column(self, name: str, dtype: DType) -> bool:
         units, zones = self.time_units, self.time_zones
         return (
-            isinstance(dtype, dtypes.Datetime)
+            isinstance(dtype, _dtypes.Datetime)
             and (dtype.time_unit in units)
             and (
                 dtype.time_zone in zones or ("*" in zones and dtype.time_zone is not None)
@@ -115,6 +163,51 @@
             )
         )


+class Duration(Selector):
+    __slots__ = ("time_units",)
+    time_units: frozenset[TimeUnit]
+
+    @staticmethod
+    def from_time_unit(time_unit: OneOrIterable[TimeUnit] | None, /) -> Duration:
+        if time_unit is None:
+            units = _ALL_TIME_UNITS
+        elif not isinstance(time_unit, str):
+            units = frozenset(time_unit)
+        else:
+            units = frozenset((time_unit,))
+        return Duration(time_units=units)
+
+    def __repr__(self) -> str:
+        return f"ncs.duration(time_unit={builtins.list(self.time_units)})"
+
+    def matches_column(self, name: str, dtype: DType) -> bool:
+        return isinstance(dtype, _dtypes.Duration) and (
+            dtype.time_unit in self.time_units
+        )
+
+
+class Enum(Selector):
+    def __repr__(self) -> str:
+        return "ncs.enum()"
+
+    def matches_column(self, name: str, dtype: DType) -> bool:
+        return isinstance(dtype, _dtypes.Enum)
+
+
+class List(Selector):
+    __slots__ = ("inner",)
+    inner: SelectorIR | None
+
+    def __repr__(self) -> str:
+        inner = "" if not self.inner else repr(self.inner)
+        return f"ncs.list({inner})"
+
+    def matches_column(self, name: str, dtype: DType) -> bool:
+        return isinstance(dtype, _dtypes.List) and (
+            not (self.inner) or self.inner.matches_column(name, dtype)
+        )
+
+
 class Matches(Selector):
     __slots__ = ("pattern",)
     pattern: re.Pattern[str]
@@ -123,12 +216,6 @@
     def from_string(pattern: str, /) -> Matches:
         return Matches(pattern=re.compile(pattern))

-    @staticmethod
-    def from_names(*names: OneOrIterable[str]) -> Matches:
-        """Implements `cs.by_name` to support `__r__` with column selections."""
-        it = flatten_hash_safe(names)
-        return Matches.from_string(f"^({'|'.join(re.escape(name) for name in it)})$")
-
     def __repr__(self) -> str:
         return f"ncs.matches(pattern={self.pattern.pattern!r})"

@@ -149,27 +236,46 @@
     def __repr__(self) -> str:
         return "ncs.string()"

     def matches_column(self, name: str, dtype: DType) -> bool:
-        return isinstance(dtype, dtypes.String)
+        return isinstance(dtype, _dtypes.String)
+
+
+class Struct(Selector):
+    def __repr__(self) -> str:
+        return "ncs.struct()"
+
+    def matches_column(self, name: str, dtype: DType) -> bool:
+        return isinstance(dtype, _dtypes.Struct)


 def all() -> expr.Selector:
-    return All().to_selector().to_narwhals()
+    return All().to_selector_ir().to_narwhals()
+
+
+def array(
+    inner: expr.Selector | None = None, *, size: int | None = None
+) -> expr.Selector:
+    s_ir = inner._ir if inner is not None else None
+    return Array(inner=s_ir, size=size).to_selector_ir().to_narwhals()


 def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selector:
-    return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals()
+    return _from_dtypes(*dtypes)


 def by_name(*names: OneOrIterable[str]) -> expr.Selector:
-    return Matches.from_names(*names).to_selector().to_narwhals()
+    if len(names) == 1 and isinstance(names[0], str):
+        sel = ByName.from_name(names[0])
+    else:
+        sel = ByName.from_names(*names)
+    return sel.to_selector_ir().to_narwhals()


 def boolean() -> expr.Selector:
-    return Boolean().to_selector().to_narwhals()
+    return Boolean().to_selector_ir().to_narwhals()


 def categorical() -> expr.Selector:
-    return Categorical().to_selector().to_narwhals()
+    return Categorical().to_selector_ir().to_narwhals()


 def datetime(
@@ -178,18 +284,66 @@
 ) -> expr.Selector:
     return (
         Datetime.from_time_unit_and_time_zone(time_unit, time_zone)
-        .to_selector()
+        .to_selector_ir()
         .to_narwhals()
     )


+def list(inner: expr.Selector | None = None) -> expr.Selector:
+    s_ir = inner._ir if inner is not None else None
+    return List(inner=s_ir).to_selector_ir().to_narwhals()
+
+
+def duration(time_unit: OneOrIterable[TimeUnit] | None = None) -> expr.Selector:
+    return Duration.from_time_unit(time_unit).to_selector_ir().to_narwhals()
+
+
+def enum() -> expr.Selector:
+    return Enum().to_selector_ir().to_narwhals()
+
+
 def matches(pattern: str) -> expr.Selector:
-    return Matches.from_string(pattern).to_selector().to_narwhals()
+    return Matches.from_string(pattern).to_selector_ir().to_narwhals()


 def numeric() -> expr.Selector:
-    return Numeric().to_selector().to_narwhals()
+    return Numeric().to_selector_ir().to_narwhals()


 def string() -> expr.Selector:
-    return String().to_selector().to_narwhals()
+    return String().to_selector_ir().to_narwhals()
+
+
+def struct() -> expr.Selector:
+    return Struct().to_selector_ir().to_narwhals()
+
+
+_HASH_SENSITIVE_TO_SELECTOR: Mapping[type[DType], Callable[[], expr.Selector]] = {
+    _dtypes.Datetime: datetime,
+    _dtypes_v1.Datetime: datetime,
+    _dtypes.Duration: duration,
+    _dtypes_v1.Duration: duration,
+    _dtypes.Enum: enum,
+    _dtypes_v1.Enum: enum,
+    _dtypes.Array: array,
+    _dtypes.List: list,
+    _dtypes.Struct: struct,
+}
+
+
+def _from_dtypes(*by_dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selector:
+    selectors: deque[expr.Selector] = deque()
+    dtypes: deque[DType | type[DType]] = deque()
+    for dtype in flatten_hash_safe(by_dtypes):
+        if isinstance(dtype, type):
+            if constructor := _HASH_SENSITIVE_TO_SELECTOR.get(dtype):
+                selectors.append(constructor())
+            else:
+                dtypes.append(dtype)
+        else:
+            dtypes.append(dtype)  # type: ignore[arg-type]
+    if dtypes:
+        dtype_selector = ByDType(dtypes=frozenset(dtypes)).to_selector_ir().to_narwhals()
+        selectors.appendleft(dtype_selector)
+    it = iter(selectors)
+    return reduce(operator.or_, it, next(it))
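> Annotation: `_from_dtypes` splits `by_dtype` arguments into plain dtypes (folded into a single `ByDType`) and the parametrized, hash-sensitive ones (`Datetime`, `Duration`, `Enum`, `Array`, `List`, `Struct`), which map to their dedicated selectors before everything is OR-ed back together. Roughly — public API from this diff; the equivalence is assumed, not verified:

```python
import narwhals as nw
from narwhals._plan import selectors as ncs

# A mix of plain and parametric dtypes...
mixed = ncs.by_dtype(nw.Int64, nw.Datetime, nw.List)
# ...is assembled as a union, approximately:
by_hand = ncs.by_dtype(nw.Int64) | ncs.datetime() | ncs.list()
```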
diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py
index bb7a4315b3..6e67dad5cc 100644
--- a/narwhals/_plan/meta.py
+++ b/narwhals/_plan/meta.py
@@ -14,12 +14,14 @@
 from narwhals._plan._guards import is_literal
 from narwhals._plan.expressions.literal import is_literal_scalar
 from narwhals._plan.expressions.namespace import IRNamespace
-from narwhals.exceptions import ComputeError
+from narwhals.exceptions import ComputeError, InvalidOperationError
 from narwhals.utils import Version

 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator

+    from narwhals._plan import expr
+

 class MetaNamespace(IRNamespace):
     """Methods to modify and traverse existing expressions."""
@@ -75,6 +77,16 @@
         """Get the root column names."""
         return list(_expr_to_leaf_column_names_iter(self._ir))

+    def as_selector(self) -> expr.Selector:
+        """Try to turn this expression into a selector.
+
+        Raises if the underlying expression is not a column or selector.
+        """
+        if not self.is_column_selection():
+            msg = f"cannot turn `{self._ir!r}` into a selector"
+            raise InvalidOperationError(msg)
+        return self._ir.to_selector_ir().to_narwhals()
+

 def _expr_to_leaf_column_names_iter(expr: ir.ExprIR, /) -> Iterator[str]:
     for e in _expr_to_leaf_column_exprs_iter(expr):
diff --git a/tests/plan/expr_expansion_test.py b/tests/plan/expr_expansion_test.py
index 203c39911b..2a993315e8 100644
--- a/tests/plan/expr_expansion_test.py
+++ b/tests/plan/expr_expansion_test.py
@@ -243,6 +243,13 @@ def test_map_ir_recursive(expr: nwp.Expr, function: MapIR, expected: nwp.Expr) -
             .name.to_uppercase()
         ),
     ),
+    pytest.param(
+        ndcs.by_dtype(
+            nw.Datetime, nw.Enum, nw.Duration, nw.Struct, nw.List, nw.Array
+        ),
+        nwp.col("l", "o", "q", "r", "s", "u"),
+        id="ByDType-isinstance",
+    ),
 ],
 )
 def test_replace_selector(
@@ -440,6 +447,35 @@
             ],
             id="Selector-BinaryExpr-Over-Prefix",
         ),
+        pytest.param(
+            [
+                nwp.col("c").sort_by(nwp.col("c", "i")).first().alias("Columns"),
+                nwp.col("c").sort_by("c", "i").first().alias("Column_x2"),
+            ],
+            [
+                named_ir(
+                    "Columns", nwp.col("c").sort_by(nwp.col("c"), nwp.col("i")).first()
+                ),
+                named_ir(
+                    "Column_x2", nwp.col("c").sort_by(nwp.col("c"), nwp.col("i")).first()
+                ),
+            ],
+            id="SortBy-Columns",
+        ),
+        pytest.param(
+            nwp.nth(1).mean().over("k", order_by=nwp.nth(4, 5)),
+            [
+                nwp.col("b")
+                .mean()
+                .over(nwp.col("k"), order_by=(nwp.col("e"), nwp.col("f")))
+            ],
+            id="Over-OrderBy-IndexColumns",
+        ),
+        pytest.param(
+            nwp.col("f").max().over(ndcs.by_dtype(nw.Date, nw.Datetime)),
+            [nwp.col("f").max().over(nwp.col("l"), nwp.col("n"), nwp.col("o"))],
+            id="Over-Partitioned-Selector",
+        ),
     ],
 )
 def test_prepare_projection(
diff --git a/tests/plan/frame_partition_by_test.py b/tests/plan/frame_partition_by_test.py
new file mode 100644
index 0000000000..b17bf9ab42
--- /dev/null
+++ b/tests/plan/frame_partition_by_test.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import pytest
+
+import narwhals as nw
+from narwhals._plan import Selector, selectors as ncs
+from narwhals._utils import zip_strict
+from narwhals.exceptions import ColumnNotFoundError, ComputeError
+from tests.plan.utils import assert_equal_data, dataframe
+
+if TYPE_CHECKING:
+    from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable
+    from tests.conftest import Data
+
+
+@pytest.fixture
+def data() -> Data:
+    return {
+        "a": ["a", "b", "a", None, "b", "c"],
+        "b": [1, 2, 1, 5, 3, 3],
+        "c": [5, 4, 3, 6, 2, 1],
+    }
+
+
+@pytest.mark.parametrize(
+    ("include_key", "expected"),
+    [
+        (
+            True,
+            [
+                {"a": ["a", "a"], "b": [1, 1], "c": [5, 3]},
"a"], "b": [1, 1], "c": [5, 3]}, + {"a": ["b", "b"], "b": [2, 3], "c": [4, 2]}, + {"a": [None], "b": [5], "c": [6]}, + {"a": ["c"], "b": [3], "c": [1]}, + ], + ), + ( + False, + [ + {"b": [1, 1], "c": [5, 3]}, + {"b": [2, 3], "c": [4, 2]}, + {"b": [5], "c": [6]}, + {"b": [3], "c": [1]}, + ], + ), + ], + ids=["include_key", "exclude_key"], +) +@pytest.mark.parametrize( + "by", + ["a", ncs.string(), ncs.matches("a"), ncs.by_name("a"), ncs.by_dtype(nw.String)], + ids=["str", "ncs.string", "ncs.matches", "ncs.by_name", "ncs.by_dtype"], +) +def test_partition_by_single( + data: Data, by: ColumnNameOrSelector, *, include_key: bool, expected: Any +) -> None: + df = dataframe(data) + results = df.partition_by(by, include_key=include_key) + for df, expect in zip_strict(results, expected): + assert_equal_data(df, expect) + + +@pytest.mark.parametrize( + ("include_key", "expected"), + [ + ( + True, + [ + {"a": ["a", "a"], "b": [1, 1], "c": [5, 3]}, + {"a": ["b"], "b": [2], "c": [4]}, + {"a": [None], "b": [5], "c": [6]}, + {"a": ["b"], "b": [3], "c": [2]}, + {"a": ["c"], "b": [3], "c": [1]}, + ], + ), + (False, [{"c": [5, 3]}, {"c": [4]}, {"c": [6]}, {"c": [2]}, {"c": [1]}]), + ], + ids=["include_key", "exclude_key"], +) +@pytest.mark.parametrize( + ("by", "more_by"), + [ + ("a", "b"), + (["a", "b"], ()), + (ncs.matches("a|b"), ()), + (ncs.string(), "b"), + (ncs.by_name("a", "b"), ()), + (ncs.by_name("b"), ncs.by_name("a")), + (ncs.by_dtype(nw.String) | (ncs.numeric() - ncs.by_name("c")), []), + ], + ids=[ + "str-variadic", + "str-list", + "ncs.matches", + "ncs.string-str", + "ncs.by_name", + "2x-selector", + "BinarySelector", + ], +) +def test_partition_by_multiple( + data: Data, + by: ColumnNameOrSelector, + more_by: OneOrIterable[ColumnNameOrSelector], + *, + include_key: bool, + expected: Any, +) -> None: + df = dataframe(data) + if isinstance(more_by, (str, Selector)): + results = df.partition_by(by, more_by, include_key=include_key) + else: + results = df.partition_by(by, *more_by, include_key=include_key) + for df, expect in zip_strict(results, expected): + assert_equal_data(df, expect) + + +# TODO @dangotbanned: Stricter selectors +@pytest.mark.xfail( + reason="TODO: Handle missing columns in `strict`/`require_all` selectors." 
+)
+def test_partition_by_missing_names(data: Data) -> None:  # pragma: no cover
+    df = dataframe(data)
+    with pytest.raises(ColumnNotFoundError, match=r"\"d\""):
+        df.partition_by("d")
+    with pytest.raises(ColumnNotFoundError, match=r"\"e\""):
+        df.partition_by("c", "e")
+
+
+def test_partition_by_fully_empty_selector(data: Data) -> None:
+    df = dataframe(data)
+    with pytest.raises(
+        ComputeError, match=r"at least one key is required in a group_by operation"
+    ):
+        df.partition_by(ncs.array(ncs.numeric()), ncs.struct(), ncs.duration())
+
+
+# NOTE: Matching polars behavior
+def test_partition_by_partially_missing_selector(data: Data) -> None:
+    df = dataframe(data)
+    results = df.partition_by(ncs.string() | ncs.list() | ncs.enum())
+    expected = nw.Schema({"a": nw.String(), "b": nw.Int64(), "c": nw.Int64()})
+    for df in results:
+        assert df.schema == expected
diff --git a/tests/plan/over_test.py b/tests/plan/over_test.py
new file mode 100644
index 0000000000..6f425ebabe
--- /dev/null
+++ b/tests/plan/over_test.py
@@ -0,0 +1,226 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+pytest.importorskip("pyarrow")
+
+import narwhals as nw
+import narwhals._plan as nwp
+from narwhals._plan import selectors as ncs
+from narwhals.exceptions import InvalidOperationError
+from tests.plan.utils import assert_equal_data, dataframe
+
+if TYPE_CHECKING:
+    from narwhals._plan.typing import IntoExprColumn, OneOrIterable
+    from tests.conftest import Data
+
+
+@pytest.fixture
+def data() -> Data:
+    return {
+        "a": ["a", "a", "b", "b", "b"],
+        "b": [1, 2, 3, 5, 3],
+        "c": [5, 4, 3, 2, 1],
+        "i": [0, 1, 2, 3, 4],
+    }
+
+
+@pytest.fixture
+def data_with_null(data: Data) -> Data:
+    return data | {"b": [1, 2, None, 5, 3]}
+
+
+@pytest.fixture
+def data_alt() -> Data:
+    return {"a": [3, 5, 1, 2, None], "b": [0, 1, 3, 2, 1], "c": [9, 1, 2, 1, 1]}
+
+
+XFAIL_REQUIRES_PARTITION_BY = pytest.mark.xfail(
+    reason="Native group_by isn't enough", raises=InvalidOperationError
+)
+
+
+@pytest.mark.parametrize(
+    "partition_by",
+    [
+        "a",
+        ["a"],
+        nwp.nth(0),
+        ncs.string(),
+        ncs.by_dtype(nw.String),
+        ncs.by_name("a"),
+        ncs.matches(r"a"),
+        ncs.all() - ncs.numeric(),
+    ],
+)
+def test_over_single(data: Data, partition_by: OneOrIterable[IntoExprColumn]) -> None:
+    expected = {
+        "a": ["a", "a", "b", "b", "b"],
+        "b": [1, 2, 3, 5, 3],
+        "c": [5, 4, 3, 2, 1],
+        "i": [0, 1, 2, 3, 4],
+        "c_max": [5, 5, 3, 3, 3],
+    }
+    result = (
+        dataframe(data)
+        .with_columns(c_max=nwp.col("c").max().over(partition_by))
+        .sort("i")
+    )
+    assert_equal_data(result, expected)
+
+
+@pytest.mark.parametrize(
+    "partition_by",
+    [
+        ("a", "b"),
+        [nwp.col("a"), nwp.col("b")],
+        [nwp.nth(0), nwp.nth(1)],
+        nwp.col("a", "b"),
+        nwp.nth(0, 1),
+        ncs.by_name("a", "b"),
+        ncs.matches(r"a|b"),
+        ncs.all() - ncs.by_name(["c", "i"]),
+    ],
+    ids=[
+        "tuple[str]",
+        "col-col",
+        "nth-nth",
+        "cols",
+        "index_columns",
+        "by_name",
+        "matches",
+        "binary_selector",
+    ],
+)
+def test_over_multiple(data: Data, partition_by: OneOrIterable[IntoExprColumn]) -> None:
+    expected = {
+        "a": ["a", "a", "b", "b", "b"],
+        "b": [1, 2, 3, 5, 3],
+        "c": [5, 4, 3, 2, 1],
+        "i": [0, 1, 2, 3, 4],
+        "c_min": [5, 4, 1, 2, 1],
+    }
+    result = (
+        dataframe(data)
+        .with_columns(c_min=nwp.col("c").min().over(partition_by))
+        .sort("i")
+    )
+    assert_equal_data(result, expected)
+
+
+@XFAIL_REQUIRES_PARTITION_BY
+def test_over_cum_sum(data_with_null: Data) -> None:  # pragma: no cover
+    df = dataframe(data_with_null)
+    expected = {
"a": ["a", "a", "b", "b", "b"], + "b": [1, 2, None, 5, 3], + "c": [5, 4, 3, 2, 1], + "b_cum_sum": [1, 3, None, 5, 8], + "c_cum_sum": [5, 9, 3, 5, 6], + } + + result = ( + df.with_columns(nwp.col("b", "c").cum_sum().over("a").name.suffix("_cum_sum")) + .sort("i") + .drop("i") + ) + assert_equal_data(result, expected) + + +def test_over_std_var(data: Data) -> None: + expected = { + "a": ["a", "a", "b", "b", "b"], + "b": [1, 2, 3, 5, 3], + "c": [5, 4, 3, 2, 1], + "i": [0, 1, 2, 3, 4], + "c_std0": [0.5, 0.5, 0.816496580927726, 0.816496580927726, 0.816496580927726], + "c_std1": [0.7071067811865476, 0.7071067811865476, 1.0, 1.0, 1.0], + "c_var0": [ + 0.25, + 0.25, + 0.6666666666666666, + 0.6666666666666666, + 0.6666666666666666, + ], + "c_var1": [0.5, 0.5, 1.0, 1.0, 1.0], + } + + result = ( + dataframe(data) + .with_columns( + c_std0=nwp.col("c").std(ddof=0).over("a"), + c_std1=nwp.col("c").std(ddof=1).over("a"), + c_var0=nwp.col("c").var(ddof=0).over("a"), + c_var1=nwp.col("c").var(ddof=1).over("a"), + ) + .sort("i") + ) + assert_equal_data(result, expected) + + +# NOTE: Supporting this for pyarrow is new 🥳 +def test_over_anonymous_reduction() -> None: + df = dataframe({"a": [1, 1, 2], "b": [4, 5, 6]}) + result = df.with_columns(nwp.all().sum().over("a").name.suffix("_sum")).sort("a", "b") + expected = {"a": [1, 1, 2], "b": [4, 5, 6], "a_sum": [2, 2, 2], "b_sum": [9, 9, 6]} + assert_equal_data(result, expected) + + +def test_over_raise_len_change(data: Data) -> None: + df = dataframe(data) + with pytest.raises(InvalidOperationError): + df.select(nwp.col("b").drop_nulls().over("a")) + + +# NOTE: Slightly different error, but same reason for raising +# (expr-ir): InvalidOperationError: `cum_sum()` is not supported in a `group_by` context +# (main): NotImplementedError: Only aggregation or literal operations are supported in grouped `over` context for PyArrow. +# https://github.com/narwhals-dev/narwhals/blob/ecde261d799a711c2e0a7acf11b108bc45035dc9/narwhals/_arrow/expr.py#L116-L118 +def test_unsupported_over(data: Data) -> None: + df = dataframe(data) + with pytest.raises(InvalidOperationError): + df.select(nwp.col("a").shift(1).cum_sum().over("b")) + + +def test_over_without_partition_by() -> None: + df = dataframe({"a": [1, -1, 2], "i": [0, 2, 1]}) + result = ( + df.with_columns(b=nwp.col("a").abs().cum_sum().over(order_by="i")) + .sort("i") + .select("a", "b", "i") + ) + expected = {"a": [1, 2, -1], "b": [1, 3, 4], "i": [0, 1, 2]} + assert_equal_data(result, expected) + + +def test_aggregation_over_without_partition_by() -> None: + df = dataframe({"a": [1, -1, 2], "i": [0, 2, 1]}) + result = ( + df.with_columns(b=nwp.col("a").diff().sum().over(order_by="i")) + .sort("i") + .select("a", "b", "i") + ) + expected = {"a": [1, 2, -1], "b": [-2, -2, -2], "i": [0, 1, 2]} + assert_equal_data(result, expected) + + +def test_len_over_2369() -> None: + df = dataframe({"a": [1, 2, 4], "b": ["x", "x", "y"]}) + result = df.with_columns(a_len_per_group=nwp.len().over("b")).sort("a") + expected = {"a": [1, 2, 4], "b": ["x", "x", "y"], "a_len_per_group": [2, 2, 1]} + assert_equal_data(result, expected) + + +def test_shift_kitchen_sink(data_alt: Data) -> None: + result = dataframe(data_alt).select( + nwp.nth(1, 2) + .shift(-1) + .over(order_by=nwp.nth(0)) + .sort(nulls_last=True) + .fill_null(100) + * 5 + ) + expected = {"b": [0, 5, 10, 15, 500], "c": [5, 5, 10, 45, 500]} + assert_equal_data(result, expected)