diff --git a/narwhals/_plan/_expansion.py b/narwhals/_plan/_expansion.py index dcb6bd25dc..480bd5d7d2 100644 --- a/narwhals/_plan/_expansion.py +++ b/narwhals/_plan/_expansion.py @@ -118,9 +118,11 @@ def expand_selector_irs_names( ignored: Names of `group_by` columns. schema: Scope to expand selectors in. """ - expander = Expander(schema, ignored) - names = expander.iter_expand_selector_names(selectors) - return _ensure_valid_output_names(tuple(names), expander.schema) + names = tuple(Expander(schema, ignored).iter_expand_selector_names(selectors)) + if len(names) != len(set(names)): + # NOTE: Can't easily reuse `duplicate_error`, falling back to main for now + check_column_names_are_unique(names) + return names def remove_alias(origin: ExprIR, /) -> ExprIR: @@ -139,14 +141,6 @@ def fn(child: ExprIR, /) -> ExprIR: return origin.map_ir(fn) -def _ensure_valid_output_names(names: Seq[str], schema: FrozenSchema) -> OutputNames: - check_column_names_are_unique(names) - output_names = names - if not (set(schema.names).issuperset(output_names)): - raise column_not_found_error(output_names, schema) - return output_names - - class Expander: __slots__ = ("ignored", "schema") schema: FrozenSchema @@ -223,7 +217,7 @@ def _expand_recursive(self, origin: ExprIR, /) -> Iterator[ExprIR]: yield from self._expand_function_expr(origin) else: msg = f"Didn't expect to see {type(origin).__name__}" - raise TypeError(msg) + raise NotImplementedError(msg) def _expand_inner(self, children: Seq[ExprIR], /) -> Iterator[ExprIR]: """Use when we want to expand non-root nodes, *without* duplicating the root. @@ -265,8 +259,8 @@ def _expand_only(self, child: ExprIR, /) -> ExprIR: iterable = self._expand_recursive(child) first = next(iterable) if second := next(iterable, None): - msg = f"Multi-output expressions are not supported in this context, got: `{second!r}`" - raise MultiOutputExpressionError(msg) + msg = f"Multi-output expressions are not supported in this context, got: `{second!r}`" # pragma: no cover + raise MultiOutputExpressionError(msg) # pragma: no cover return first # TODO @dangotbanned: It works, but all this class-specific branching belongs in the classes themselves diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 4167edca62..18f670cb6f 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -152,7 +152,7 @@ def iter_right(self) -> Iterator[ExprIR]: if isinstance(child, ExprIR): yield from child.iter_right() else: - for node in reversed(child): + for node in reversed(child): # pragma: no cover yield from node.iter_right() def iter_root_names(self) -> Iterator[ExprIR]: @@ -199,9 +199,8 @@ class SelectorIR(ExprIR, config=ExprIROptions.no_dispatch()): def to_narwhals(self, version: Version = Version.MAIN) -> Selector: from narwhals._plan.selectors import Selector, SelectorV1 - if version is Version.MAIN: - return Selector._from_ir(self) - return SelectorV1._from_ir(self) + tp = Selector if version is Version.MAIN else SelectorV1 + return tp._from_ir(self) def into_columns( self, schema: FrozenSchema, ignored_columns: Container[str] @@ -267,10 +266,10 @@ def map_ir(self, function: MapIR, /) -> Self: def __repr__(self) -> str: return f"{self.name}={self.expr!r}" - def _repr_html_(self) -> str: + def _repr_html_(self) -> str: # pragma: no cover return f"{self.name}={self.expr._repr_html_()}" - def is_elementwise_top_level(self) -> bool: + def is_elementwise_top_level(self) -> bool: # pragma: no cover """Return True if the outermost node is elementwise. Based on [`polars_plan::plans::aexpr::properties::AExpr.is_elementwise_top_level`] diff --git a/narwhals/_plan/_guards.py b/narwhals/_plan/_guards.py index f0300867f3..6be5ed35d3 100644 --- a/narwhals/_plan/_guards.py +++ b/narwhals/_plan/_guards.py @@ -132,9 +132,7 @@ def is_literal(obj: Any) -> TypeIs[ir.Literal[Any]]: return isinstance(obj, _ir().Literal) -def is_horizontal_reduction(obj: Any) -> TypeIs[ir.FunctionExpr[Any]]: - return is_function_expr(obj) and obj.options.is_input_wildcard_expansion() - - -def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: +# TODO @dangotbanned: Coverage +# Used in `ArrowNamespace._vertical`, but only horizontal is covered +def is_tuple_of(obj: Any, tp: type[T]) -> TypeIs[Seq[T]]: # pragma: no cover return bool(isinstance(obj, tuple) and obj and isinstance(obj[0], tp)) diff --git a/narwhals/_plan/_immutable.py b/narwhals/_plan/_immutable.py index 09cc7fc90a..b64090f48b 100644 --- a/narwhals/_plan/_immutable.py +++ b/narwhals/_plan/_immutable.py @@ -31,6 +31,8 @@ class Immutable(metaclass=ImmutableMeta): # NOTE: Trying to avoid this being added to synthesized `__init__` # Seems to be the only difference when decorating the metaclass __immutable_hash_value__: int + else: # pragma: no cover + ... __immutable_keys__: ClassVar[tuple[str, ...]] @@ -108,7 +110,9 @@ def __init__(self, **kwds: Any) -> None: def _field_str(name: str, value: Any) -> str: if isinstance(value, tuple): - inner = ", ".join(f"{v}" for v in value) + inner = ", ".join( + (f"{v!s}" if not isinstance(v, str) else f"{v!r}") for v in value + ) return f"{name}=[{inner}]" if isinstance(value, str): return f"{name}={value!r}" diff --git a/narwhals/_plan/_parse.py b/narwhals/_plan/_parse.py index b0f4f678a1..736b5e0abb 100644 --- a/narwhals/_plan/_parse.py +++ b/narwhals/_plan/_parse.py @@ -6,6 +6,7 @@ from itertools import chain from typing import TYPE_CHECKING +from narwhals._native import is_native_pandas from narwhals._plan._guards import ( is_column_name_or_selector, is_expr, @@ -13,20 +14,15 @@ is_iterable_reject, is_selector, ) -from narwhals._plan.exceptions import ( - invalid_into_expr_error, - is_iterable_pandas_error, - is_iterable_polars_error, -) +from narwhals._plan.exceptions import invalid_into_expr_error, is_iterable_error from narwhals._utils import qualified_type_name -from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series +from narwhals.dependencies import get_polars from narwhals.exceptions import InvalidOperationError if TYPE_CHECKING: from collections.abc import Iterator from typing import Any, TypeVar - import polars as pl from typing_extensions import TypeAlias, TypeIs from narwhals._plan.expr import Expr @@ -126,7 +122,7 @@ def parse_into_expr_ir( expr = col(input) elif isinstance(input, list): if list_as_series is None: - raise TypeError(input) + raise TypeError(input) # pragma: no cover expr = lit(list_as_series(input)) else: expr = lit(input, dtype=dtype) @@ -140,9 +136,9 @@ def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR: from narwhals._plan import selectors as cs selector = cs.by_name(input) - elif is_expr(input): + elif is_expr(input): # pragma: no cover selector = input.meta.as_selector() - else: + else: # pragma: no cover msg = f"cannot turn {qualified_type_name(input)!r} into selector" raise TypeError(msg) return selector._ir @@ -194,8 +190,8 @@ def _parse_sort_by_into_iter_expr_ir( ) -> Iterator[ExprIR]: for e in _parse_into_iter_expr_ir(by, *more_by): if e.is_scalar: - msg = f"All expressions sort keys must preserve length, but got:\n{e!r}" - raise InvalidOperationError(msg) + msg = f"All expressions sort keys must preserve length, but got:\n{e!r}" # pragma: no cover + raise InvalidOperationError(msg) # pragma: no cover yield e @@ -216,14 +212,14 @@ def _parse_into_iter_selector_ir( if not _is_empty_sequence(first_input): if _is_iterable(first_input) and not isinstance(first_input, str): - if more_inputs: + if more_inputs: # pragma: no cover raise invalid_into_expr_error(first_input, more_inputs, {}) else: for into in first_input: # type: ignore[var-annotated] yield parse_into_selector_ir(into) else: yield parse_into_selector_ir(first_input) - for into in more_inputs: + for into in more_inputs: # pragma: no cover yield parse_into_selector_ir(into) @@ -298,18 +294,13 @@ def _combine_predicates(predicates: Iterator[ExprIR], /) -> ExprIR: def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]: - if is_pandas_dataframe(obj) or is_pandas_series(obj): - raise is_iterable_pandas_error(obj) - if _is_polars(obj): - raise is_iterable_polars_error(obj) + if is_native_pandas(obj) or ( + (pl := get_polars()) + and isinstance(obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame)) + ): + raise is_iterable_error(obj) return isinstance(obj, Iterable) def _is_empty_sequence(obj: Any) -> bool: return isinstance(obj, Sequence) and not obj - - -def _is_polars(obj: Any) -> TypeIs[pl.Series | pl.Expr | pl.DataFrame | pl.LazyFrame]: - return (pl := get_polars()) is not None and isinstance( - obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame) - ) diff --git a/narwhals/_plan/_rewrites.py b/narwhals/_plan/_rewrites.py index fd26364e66..1a4d40fb7d 100644 --- a/narwhals/_plan/_rewrites.py +++ b/narwhals/_plan/_rewrites.py @@ -88,7 +88,7 @@ def map_ir( origin: NamedOrExprIRT, function: MapIR, *more_functions: MapIR ) -> NamedOrExprIRT: """Apply one or more functions, sequentially, to all of `origin`'s children.""" - if more_functions: + if more_functions: # pragma: no cover result = origin for fn in (function, *more_functions): result = result.map_ir(fn) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index 8ac2084034..16d32938e5 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -31,7 +31,7 @@ if sys.version_info >= (3, 13): from copy import replace as replace # noqa: PLC0414 -else: +else: # pragma: no cover def replace(obj: T, /, **changes: Any) -> T: cls = obj.__class__ @@ -98,20 +98,20 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: yield element # type: ignore[misc] -def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: +def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: # pragma: no cover msg = f"Expected one or an iterable of strings, but got: {qualified_type_name(obj)!r}\n{obj!r}" return TypeError(msg) def ensure_seq_str(obj: OneOrIterable[str], /) -> Seq[str]: if not isinstance(obj, Iterable): - raise _not_one_or_iterable_str_error(obj) + raise _not_one_or_iterable_str_error(obj) # pragma: no cover return (obj,) if isinstance(obj, str) else tuple(obj) def ensure_list_str(obj: OneOrIterable[str], /) -> list[str]: if not isinstance(obj, Iterable): - raise _not_one_or_iterable_str_error(obj) + raise _not_one_or_iterable_str_error(obj) # pragma: no cover return [obj] if isinstance(obj, str) else list(obj) @@ -246,7 +246,7 @@ def _not_enough_room_error(cls, prefix: str, n_chars: int, /) -> NarwhalsError: available_chars = n_chars - len_prefix if available_chars < 0: visualize = "" - else: + else: # pragma: no cover (has coverage, but there's randomness in the test) okay = "✔" * available_chars bad = "✖" * (cls._MIN_RANDOM_CHARS - available_chars) visualize = f"\n Preview: '{prefix}{okay}{bad}'" diff --git a/narwhals/_plan/dataframe.py b/narwhals/_plan/dataframe.py index db2d43b5c1..a2c8bc5a02 100644 --- a/narwhals/_plan/dataframe.py +++ b/narwhals/_plan/dataframe.py @@ -50,7 +50,7 @@ def version(self) -> Version: return self._version @property - def implementation(self) -> Implementation: + def implementation(self) -> Implementation: # pragma: no cover return self._compliant.implementation @property @@ -61,7 +61,7 @@ def schema(self) -> Schema: def columns(self) -> list[str]: return self._compliant.columns - def __repr__(self) -> str: # pragma: no cover + def __repr__(self) -> str: return generate_repr(f"nw.{type(self).__name__}", self.to_native().__repr__()) def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: @@ -70,12 +70,12 @@ def __init__(self, compliant: CompliantFrame[Any, NativeFrameT_co], /) -> None: def _with_compliant(self, compliant: CompliantFrame[Any, Incomplete], /) -> Self: return type(self)(compliant) - def to_native(self) -> NativeFrameT_co: + def to_native(self) -> NativeFrameT_co: # pragma: no cover return self._compliant.native def filter( self, *predicates: OneOrIterable[IntoExprColumn], **constraints: Any - ) -> Self: + ) -> Self: # pragma: no cover e = _parse.parse_predicates_constraints_into_expr_ir(*predicates, **constraints) named_irs, _ = prepare_projection((e,), schema=self) if len(named_irs) != 1: @@ -113,7 +113,9 @@ def sort( def drop(self, *columns: str, strict: bool = True) -> Self: return self._with_compliant(self._compliant.drop(columns, strict=strict)) - def drop_nulls(self, subset: str | Sequence[str] | None = None) -> Self: + def drop_nulls( + self, subset: str | Sequence[str] | None = None + ) -> Self: # pragma: no cover subset = [subset] if isinstance(subset, str) else subset return self._with_compliant(self._compliant.drop_nulls(subset)) @@ -130,7 +132,7 @@ class DataFrame( def implementation(self) -> _EagerAllowedImpl: return self._compliant.implementation - def __len__(self) -> int: + def __len__(self) -> int: # pragma: no cover return len(self._compliant) @property @@ -183,17 +185,17 @@ def to_dict( def to_dict( self, *, as_series: bool = True ) -> dict[str, Series[NativeSeriesT]] | dict[str, list[Any]]: - if as_series: + if as_series: # pragma: no cover return { key: self._series(value) for key, value in self._compliant.to_dict(as_series=as_series).items() } return self._compliant.to_dict(as_series=as_series) - def to_series(self, index: int = 0) -> Series[NativeSeriesT]: + def to_series(self, index: int = 0) -> Series[NativeSeriesT]: # pragma: no cover return self._series(self._compliant.to_series(index)) - def get_column(self, name: str) -> Series[NativeSeriesT]: + def get_column(self, name: str) -> Series[NativeSeriesT]: # pragma: no cover return self._series(self._compliant.get_column(name)) @overload @@ -253,7 +255,7 @@ def filter( **constraints, ) named_irs, _ = prepare_projection((e,), schema=self) - if len(named_irs) != 1: + if len(named_irs) != 1: # pragma: no cover # Should be unreachable, but I guess we will see msg = f"Expected a single predicate after expansion, but got {len(named_irs)!r}\n\n{named_irs!r}" raise ValueError(msg) diff --git a/narwhals/_plan/exceptions.py b/narwhals/_plan/exceptions.py index 9bd5eab854..42e392113b 100644 --- a/narwhals/_plan/exceptions.py +++ b/narwhals/_plan/exceptions.py @@ -6,6 +6,7 @@ from itertools import groupby from typing import TYPE_CHECKING +from narwhals._utils import qualified_type_name from narwhals.exceptions import ( ColumnNotFoundError, ComputeError, @@ -21,9 +22,6 @@ from collections.abc import Collection, Iterable from typing import Any - import pandas as pd - import polars as pl - from narwhals._plan import expressions as ir from narwhals._plan._function import Function from narwhals._plan.expressions.operators import Operator @@ -156,19 +154,9 @@ def invalid_into_expr_error( return InvalidIntoExprError(msg) -def is_iterable_pandas_error(obj: pd.DataFrame | pd.Series[Any], /) -> TypeError: - msg = ( - f"Expected Narwhals class or scalar, got: {type(obj)}. " - "Perhaps you forgot a `nw.from_native` somewhere?" - ) - return TypeError(msg) - - -def is_iterable_polars_error( - obj: pl.Series | pl.Expr | pl.DataFrame | pl.LazyFrame, / -) -> TypeError: +def is_iterable_error(obj: object, /) -> TypeError: msg = ( - f"Expected Narwhals class or scalar, got: {type(obj)}.\n\n" + f"Expected Narwhals class or scalar, got: {qualified_type_name(obj)!r}.\n\n" "Hint: Perhaps you\n" "- forgot a `nw.from_native` somewhere?\n" "- used `pl.col` instead of `nw.col`?" diff --git a/narwhals/_plan/expr.py b/narwhals/_plan/expr.py index 85aacc66e2..5dfa9e3770 100644 --- a/narwhals/_plan/expr.py +++ b/narwhals/_plan/expr.py @@ -250,16 +250,16 @@ def clip( it = parse_into_seq_of_expr_ir(lower_bound, upper_bound) return self._from_ir(F.Clip().to_function_expr(self._ir, *it)) - def cum_count(self, *, reverse: bool = False) -> Self: + def cum_count(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumCount(reverse=reverse)) - def cum_min(self, *, reverse: bool = False) -> Self: + def cum_min(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumMin(reverse=reverse)) - def cum_max(self, *, reverse: bool = False) -> Self: + def cum_max(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumMax(reverse=reverse)) - def cum_prod(self, *, reverse: bool = False) -> Self: + def cum_prod(self, *, reverse: bool = False) -> Self: # pragma: no cover return self._with_unary(F.CumProd(reverse=reverse)) def cum_sum(self, *, reverse: bool = False) -> Self: @@ -267,7 +267,7 @@ def cum_sum(self, *, reverse: bool = False) -> Self: def rolling_sum( self, window_size: int, *, min_samples: int | None = None, center: bool = False - ) -> Self: + ) -> Self: # pragma: no cover options = rolling_options(window_size, min_samples, center=center) return self._with_unary(F.RollingSum(options=options)) @@ -284,7 +284,7 @@ def rolling_var( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: + ) -> Self: # pragma: no cover options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingVar(options=options)) @@ -295,7 +295,7 @@ def rolling_std( min_samples: int | None = None, center: bool = False, ddof: int = 1, - ) -> Self: + ) -> Self: # pragma: no cover options = rolling_options(window_size, min_samples, center=center, ddof=ddof) return self._with_unary(F.RollingStd(options=options)) @@ -542,7 +542,7 @@ def name(self) -> ExprNameNamespace: return ExprNameNamespace(_expr=self) @property - def cat(self) -> ExprCatNamespace: + def cat(self) -> ExprCatNamespace: # pragma: no cover from narwhals._plan.expressions.categorical import ExprCatNamespace return ExprCatNamespace(_expr=self) @@ -560,7 +560,7 @@ def dt(self) -> ExprDateTimeNamespace: return ExprDateTimeNamespace(_expr=self) @property - def list(self) -> ExprListNamespace: + def list(self) -> ExprListNamespace: # pragma: no cover from narwhals._plan.expressions.lists import ExprListNamespace return ExprListNamespace(_expr=self) diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 3825a6ac69..ed51f4cb90 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -32,6 +32,8 @@ def __init__(self, *, expr: ExprIR, **kwds: Any) -> None: if expr.is_scalar: raise agg_scalar_error(self, expr) super().__init__(expr=expr, **kwds) # pyright: ignore[reportCallIssue] + else: # pragma: no cover + ... # fmt: off diff --git a/narwhals/_plan/expressions/boolean.py b/narwhals/_plan/expressions/boolean.py index ebc2a8643b..f6042a391d 100644 --- a/narwhals/_plan/expressions/boolean.py +++ b/narwhals/_plan/expressions/boolean.py @@ -3,12 +3,13 @@ # NOTE: Needed to avoid naming collisions # - Any import typing as t +from typing import TYPE_CHECKING from narwhals._plan._function import Function, HorizontalFunction from narwhals._plan.options import FEOptions, FunctionOptions from narwhals._typing_compat import TypeVar -if t.TYPE_CHECKING: +if TYPE_CHECKING: from typing_extensions import Self from narwhals._plan._expr_ir import ExprIR diff --git a/narwhals/_plan/expressions/categorical.py b/narwhals/_plan/expressions/categorical.py index 7c59fd4443..5bb7157f5d 100644 --- a/narwhals/_plan/expressions/categorical.py +++ b/narwhals/_plan/expressions/categorical.py @@ -20,7 +20,7 @@ class IRCatNamespace(IRNamespace): class ExprCatNamespace(ExprNamespace[IRCatNamespace]): @property def _ir_namespace(self) -> type[IRCatNamespace]: - return IRCatNamespace + return IRCatNamespace # pragma: no cover def get_categories(self) -> Expr: - return self._with_unary(self._ir.get_categories()) + return self._with_unary(self._ir.get_categories()) # pragma: no cover diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 01451cb39f..4007f633b9 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -3,6 +3,7 @@ from __future__ import annotations import typing as t +from typing import TYPE_CHECKING from narwhals._plan._expr_ir import ExprIR, SelectorIR from narwhals._plan.common import replace @@ -25,7 +26,7 @@ ) from narwhals.exceptions import InvalidOperationError -if t.TYPE_CHECKING: +if TYPE_CHECKING: from collections.abc import Container, Iterable, Iterator from typing_extensions import Self @@ -239,9 +240,12 @@ def iter_output_name(self) -> t.Iterator[ExprIR]: """ for e in self.input[:1]: yield from e.iter_output_name() + # NOTE: Covering the empty case doesn't make sense without implementing `FunctionFlags.ALLOW_EMPTY_INPUTS` + # https://github.com/pola-rs/polars/blob/df69276daf5d195c8feb71eef82cbe9804e0f47f/crates/polars-plan/src/plans/options.rs#L106-L107 + return # pragma: no cover # NOTE: Interacting badly with `pyright` synthesizing the `__replace__` signature - if not t.TYPE_CHECKING: + if not TYPE_CHECKING: def __init__( self, @@ -256,6 +260,8 @@ def __init__( raise function_expr_invalid_operation_error(function, parent) kwargs = dict(input=input, function=function, options=options, **kwds) super().__init__(**kwargs) + else: # pragma: no cover + ... def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str diff --git a/narwhals/_plan/expressions/lists.py b/narwhals/_plan/expressions/lists.py index 604e054a5e..b14090b985 100644 --- a/narwhals/_plan/expressions/lists.py +++ b/narwhals/_plan/expressions/lists.py @@ -21,7 +21,7 @@ class IRListNamespace(IRNamespace): class ExprListNamespace(ExprNamespace[IRListNamespace]): @property def _ir_namespace(self) -> type[IRListNamespace]: - return IRListNamespace + return IRListNamespace # pragma: no cover def len(self) -> Expr: - return self._with_unary(self._ir.len()) + return self._with_unary(self._ir.len()) # pragma: no cover diff --git a/narwhals/_plan/expressions/selectors.py b/narwhals/_plan/expressions/selectors.py index 1400e92fe5..bdf44cc255 100644 --- a/narwhals/_plan/expressions/selectors.py +++ b/narwhals/_plan/expressions/selectors.py @@ -9,7 +9,6 @@ import functools import re -from contextlib import suppress from typing import TYPE_CHECKING, Any, ClassVar, final from narwhals._plan._immutable import Immutable @@ -65,7 +64,7 @@ def into_columns( # - Column names in `ignored_columns` are only used if they are explicitly mentioned by a `ByName` or `ByIndex`. # - `ignored_columns` are only evaluated against `All` and `Matches` # https://github.com/pola-rs/polars/blob/2b241543851800595efd343be016b65cdbdd3c9f/crates/polars-plan/src/dsl/selector.rs#L192-L193 - msg = f"{type(self).__name__}.into_columns" + msg = f"{type(self).__name__}.into_columns" # pragma: no cover[abstract] raise NotImplementedError(msg) @@ -88,8 +87,7 @@ def matches(self, dtype: IntoDType) -> bool: The result will *only* be cached if this method is **not overridden**. Instead, use `DTypeSelector._matches` to customize the check. """ - # See https://github.com/python/typeshed/issues/6347 - return _selector_matches(self, dtype) # type: ignore[arg-type] + return _selector_matches(self, dtype) def _matches(self, dtype: IntoDType) -> bool: """Implementation of `DTypeSelector.matches`.""" @@ -108,15 +106,6 @@ def __repr__(self) -> str: def _matches(self, dtype: IntoDType) -> bool: return True - # Special case, needs to behave the same whether it is treated like a `DTypeSelector` or regular - def into_columns( - self, schema: FrozenSchema, ignored_columns: Container[str] - ) -> Iterator[str]: - if ignored_columns: - yield from (name for name in schema if name not in ignored_columns) - else: - yield from schema - class All(Selector): def to_dtype_selector(self) -> DTypeSelector: @@ -162,17 +151,18 @@ def into_columns( self, schema: FrozenSchema, ignored_columns: Container[str] ) -> Iterator[str]: names = schema.names + n_fields = len(names) if not self.require_all: - with suppress(IndexError): - for index in self.indices: - yield names[index] + if n_fields == 0: + yield from () + else: + yield from (names[idx] for idx in self.indices if abs(idx) < n_fields) else: - n_fields = len(names) - for index in self.indices: - positive_index = index + n_fields if index < 0 else index - if positive_index < 0 or positive_index >= n_fields: - raise column_index_error(index, schema) - yield names[index] + for idx in self.indices: + if abs(idx) < n_fields: + yield names[idx] + else: + raise column_index_error(idx, schema) class ByName(Selector): diff --git a/narwhals/_plan/expressions/strings.py b/narwhals/_plan/expressions/strings.py index 6e60a7b530..5478a7154c 100644 --- a/narwhals/_plan/expressions/strings.py +++ b/narwhals/_plan/expressions/strings.py @@ -84,30 +84,32 @@ class IRStringNamespace(IRNamespace): def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Replace: + ) -> Replace: # pragma: no cover return Replace(pattern=pattern, value=value, literal=literal, n=n) def replace_all( self, pattern: str, value: str, *, literal: bool = False - ) -> ReplaceAll: + ) -> ReplaceAll: # pragma: no cover return ReplaceAll(pattern=pattern, value=value, literal=literal) - def strip_chars(self, characters: str | None = None) -> StripChars: + def strip_chars( + self, characters: str | None = None + ) -> StripChars: # pragma: no cover return StripChars(characters=characters) def contains(self, pattern: str, *, literal: bool = False) -> Contains: return Contains(pattern=pattern, literal=literal) - def slice(self, offset: int, length: int | None = None) -> Slice: + def slice(self, offset: int, length: int | None = None) -> Slice: # pragma: no cover return Slice(offset=offset, length=length) - def head(self, n: int = 5) -> Slice: + def head(self, n: int = 5) -> Slice: # pragma: no cover return self.slice(0, n) - def tail(self, n: int = 5) -> Slice: + def tail(self, n: int = 5) -> Slice: # pragma: no cover return self.slice(-n) - def to_datetime(self, format: str | None = None) -> ToDatetime: + def to_datetime(self, format: str | None = None) -> ToDatetime: # pragma: no cover return ToDatetime(format=format) @@ -121,41 +123,43 @@ def len_chars(self) -> Expr: def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 - ) -> Expr: + ) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace(pattern, value, literal=literal, n=n)) - def replace_all(self, pattern: str, value: str, *, literal: bool = False) -> Expr: + def replace_all( + self, pattern: str, value: str, *, literal: bool = False + ) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace_all(pattern, value, literal=literal)) - def strip_chars(self, characters: str | None = None) -> Expr: + def strip_chars(self, characters: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.strip_chars(characters)) - def starts_with(self, prefix: str) -> Expr: + def starts_with(self, prefix: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.starts_with(prefix=prefix)) - def ends_with(self, suffix: str) -> Expr: + def ends_with(self, suffix: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.ends_with(suffix=suffix)) def contains(self, pattern: str, *, literal: bool = False) -> Expr: return self._with_unary(self._ir.contains(pattern, literal=literal)) - def slice(self, offset: int, length: int | None = None) -> Expr: + def slice(self, offset: int, length: int | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.slice(offset, length)) - def head(self, n: int = 5) -> Expr: + def head(self, n: int = 5) -> Expr: # pragma: no cover return self._with_unary(self._ir.head(n)) - def tail(self, n: int = 5) -> Expr: + def tail(self, n: int = 5) -> Expr: # pragma: no cover return self._with_unary(self._ir.tail(n)) - def split(self, by: str) -> Expr: + def split(self, by: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.split(by=by)) - def to_datetime(self, format: str | None = None) -> Expr: + def to_datetime(self, format: str | None = None) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_datetime(format)) - def to_lowercase(self) -> Expr: + def to_lowercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_lowercase()) - def to_uppercase(self) -> Expr: + def to_uppercase(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_uppercase()) diff --git a/narwhals/_plan/expressions/temporal.py b/narwhals/_plan/expressions/temporal.py index 11a87599ab..35a622ebd2 100644 --- a/narwhals/_plan/expressions/temporal.py +++ b/narwhals/_plan/expressions/temporal.py @@ -64,7 +64,7 @@ class Timestamp(TemporalFunction): def from_time_unit(time_unit: TimeUnit = "us", /) -> Timestamp: if not _is_polars_time_unit(time_unit): msg = f"invalid `time_unit` \n\nExpected one of ['ns', 'us', 'ms'], got {time_unit!r}." - raise ValueError(msg) + raise TypeError(msg) return Timestamp(time_unit=time_unit) def __repr__(self) -> str: @@ -115,64 +115,64 @@ class ExprDateTimeNamespace(ExprNamespace[IRDateTimeNamespace]): def _ir_namespace(self) -> type[IRDateTimeNamespace]: return IRDateTimeNamespace - def date(self) -> Expr: + def date(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.date()) - def year(self) -> Expr: + def year(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.year()) - def month(self) -> Expr: + def month(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.month()) - def day(self) -> Expr: + def day(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.day()) - def hour(self) -> Expr: + def hour(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.hour()) - def minute(self) -> Expr: + def minute(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.minute()) - def second(self) -> Expr: + def second(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.second()) - def millisecond(self) -> Expr: + def millisecond(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.millisecond()) - def microsecond(self) -> Expr: + def microsecond(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.microsecond()) - def nanosecond(self) -> Expr: + def nanosecond(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.nanosecond()) - def ordinal_day(self) -> Expr: + def ordinal_day(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.ordinal_day()) - def weekday(self) -> Expr: + def weekday(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.weekday()) def total_minutes(self) -> Expr: return self._with_unary(self._ir.total_minutes()) - def total_seconds(self) -> Expr: + def total_seconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_seconds()) - def total_milliseconds(self) -> Expr: + def total_milliseconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_milliseconds()) - def total_microseconds(self) -> Expr: + def total_microseconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_microseconds()) - def total_nanoseconds(self) -> Expr: + def total_nanoseconds(self) -> Expr: # pragma: no cover return self._with_unary(self._ir.total_nanoseconds()) - def to_string(self, format: str) -> Expr: + def to_string(self, format: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.to_string(format=format)) - def replace_time_zone(self, time_zone: str | None) -> Expr: + def replace_time_zone(self, time_zone: str | None) -> Expr: # pragma: no cover return self._with_unary(self._ir.replace_time_zone(time_zone=time_zone)) - def convert_time_zone(self, time_zone: str) -> Expr: + def convert_time_zone(self, time_zone: str) -> Expr: # pragma: no cover return self._with_unary(self._ir.convert_time_zone(time_zone=time_zone)) def timestamp(self, time_unit: TimeUnit = "us") -> Expr: diff --git a/narwhals/_plan/functions.py b/narwhals/_plan/functions.py index 06e43b9b18..8e930d1044 100644 --- a/narwhals/_plan/functions.py +++ b/narwhals/_plan/functions.py @@ -2,6 +2,7 @@ import builtins import typing as t +from typing import TYPE_CHECKING from narwhals._plan import _guards, _parse, common, expressions as ir, selectors as cs from narwhals._plan.expressions import functions as F @@ -11,7 +12,7 @@ from narwhals._plan.when_then import When from narwhals._utils import Version, flatten -if t.TYPE_CHECKING: +if TYPE_CHECKING: from narwhals._plan.expr import Expr from narwhals._plan.series import Series from narwhals._plan.typing import IntoExpr, IntoExprColumn, NativeSeriesT diff --git a/narwhals/_plan/meta.py b/narwhals/_plan/meta.py index c6192f681d..41487b503b 100644 --- a/narwhals/_plan/meta.py +++ b/narwhals/_plan/meta.py @@ -84,9 +84,6 @@ def as_selector(self) -> Selector: Raises if the underlying expressions is not a column or selector. """ - if not self.is_column_selection(): - msg = f"cannot turn `{self._ir!r}` into a selector" - raise InvalidOperationError(msg) return self._ir.to_selector_ir().to_narwhals() diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 739654e4bf..9d93a6627d 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -59,7 +59,7 @@ def returns_scalar(self) -> bool: return FunctionFlags.RETURNS_SCALAR in self def is_length_preserving(self) -> bool: - return FunctionFlags.LENGTH_PRESERVING in self + return FunctionFlags.LENGTH_PRESERVING in self # pragma: no cover def is_row_separable(self) -> bool: return FunctionFlags.ROW_SEPARABLE in self @@ -92,7 +92,7 @@ def returns_scalar(self) -> bool: return self.flags.returns_scalar() def is_length_preserving(self) -> bool: - return self.flags.is_length_preserving() + return self.flags.is_length_preserving() # pragma: no cover def is_row_separable(self) -> bool: return self.flags.is_row_separable() @@ -102,8 +102,8 @@ def is_input_wildcard_expansion(self) -> bool: def with_flags(self, flags: FunctionFlags, /) -> FunctionOptions: if (FunctionFlags.RETURNS_SCALAR | FunctionFlags.LENGTH_PRESERVING) in flags: - msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." - raise TypeError(msg) + msg = "A function cannot both return a scalar and preserve length, they are mutually exclusive." # pragma: no cover + raise TypeError(msg) # pragma: no cover obj = FunctionOptions.__new__(FunctionOptions) object.__setattr__(obj, "flags", self.flags | flags) return obj @@ -155,10 +155,6 @@ def __repr__(self) -> str: args = f"descending={self.descending!r}, nulls_last={self.nulls_last!r}" return f"{type(self).__name__}({args})" - @staticmethod - def default() -> SortOptions: - return SortOptions(descending=False, nulls_last=False) - def to_arrow(self) -> pc.ArraySortOptions: import pyarrow.compute as pc @@ -171,7 +167,7 @@ def to_multiple(self, n_repeat: int = 1, /) -> SortMultipleOptions: if n_repeat == 1: desc: Seq[bool] = (self.descending,) nulls: Seq[bool] = (self.nulls_last,) - else: + else: # pragma: no cover desc = tuple(repeat(self.descending, n_repeat)) nulls = tuple(repeat(self.nulls_last, n_repeat)) return SortMultipleOptions(descending=desc, nulls_last=nulls) @@ -201,7 +197,7 @@ def _to_arrow_args( ) -> tuple[Sequence[tuple[str, Order]], NullPlacement]: first = self.nulls_last[0] if len(self.nulls_last) != 1 and any(x != first for x in self.nulls_last[1:]): - msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" + msg = f"pyarrow doesn't support multiple values for `nulls_last`, got: {self.nulls_last!r}" # pragma: no cover raise NotImplementedError(msg) if len(self.descending) == 1: descending: Iterable[bool] = repeat(self.descending[0], len(by)) @@ -219,7 +215,9 @@ def to_arrow(self, by: Sequence[str]) -> pc.SortOptions: sort_keys, placement = self._to_arrow_args(by) return pc.SortOptions(sort_keys=sort_keys, null_placement=placement) - def to_arrow_acero(self, by: Sequence[str]) -> pyarrow.acero.Declaration: + def to_arrow_acero( + self, by: Sequence[str] + ) -> pyarrow.acero.Declaration: # pragma: no cover from narwhals._plan.arrow import acero sort_keys, placement = self._to_arrow_args(by) @@ -291,7 +289,7 @@ def __repr__(self) -> str: return self.__str__() @classmethod - def default(cls) -> Self: + def default(cls) -> Self: # pragma: no cover[abstract] return cls(is_namespaced=False, override_name="") @classmethod diff --git a/narwhals/_plan/schema.py b/narwhals/_plan/schema.py index 10c6665d08..2c15cf7804 100644 --- a/narwhals/_plan/schema.py +++ b/narwhals/_plan/schema.py @@ -4,7 +4,7 @@ from functools import lru_cache from itertools import chain from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, final, overload from narwhals._plan._expr_ir import NamedIR from narwhals._plan._immutable import Immutable @@ -12,7 +12,7 @@ from narwhals.dtypes import Unknown if TYPE_CHECKING: - from collections.abc import ItemsView, Iterator, KeysView, ValuesView + from collections.abc import ItemsView, Iterable, Iterator, KeysView, ValuesView from typing_extensions import Never, TypeAlias, TypeIs @@ -22,7 +22,7 @@ IntoFrozenSchema: TypeAlias = ( - "IntoSchema | Iterator[tuple[str, DType]] | FrozenSchema | HasSchema" + "IntoSchema | Iterable[tuple[str, DType]] | FrozenSchema | HasSchema" ) """A schema to freeze, or an already frozen one. @@ -35,6 +35,7 @@ _T2 = TypeVar("_T2") +@final class FrozenSchema(Immutable): """Use `freeze_schema(...)` constructor to trigger caching!""" @@ -42,7 +43,7 @@ class FrozenSchema(Immutable): _mapping: MappingProxyType[str, DType] def __init_subclass__(cls, *_: Never, **__: Never) -> Never: - msg = f"Cannot subclass {cls.__name__!r}" + msg = f"Cannot subclass {FrozenSchema.__name__!r}" raise TypeError(msg) def merge(self, other: FrozenSchema, /) -> FrozenSchema: @@ -69,7 +70,7 @@ def select(self, exprs: Seq[NamedIR]) -> FrozenSchema: def select_irs(self, exprs: Seq[NamedIR]) -> Seq[NamedIR]: return exprs - def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: + def with_columns(self, exprs: Seq[NamedIR]) -> FrozenSchema: # pragma: no cover # similar to `merge`, but preserving known `DType`s names = (e.name for e in exprs) default = Unknown() diff --git a/narwhals/_plan/selectors.py b/narwhals/_plan/selectors.py index f831136b69..3560827074 100644 --- a/narwhals/_plan/selectors.py +++ b/narwhals/_plan/selectors.py @@ -8,7 +8,7 @@ from narwhals._plan import expressions as ir from narwhals._plan._guards import is_column from narwhals._plan.common import flatten_hash_safe -from narwhals._plan.expr import Expr +from narwhals._plan.expr import Expr, ExprV1 from narwhals._plan.expressions import operators as ops, selectors as s_ir from narwhals._utils import Version from narwhals.dtypes import DType @@ -39,13 +39,8 @@ def _from_ir(cls, selector_ir: ir.SelectorIR, /) -> Self: # type: ignore[overri return obj def as_expr(self) -> Expr: - if self.version is Version.MAIN: - return Expr._from_ir(self._ir) - if self.version is Version.V1: - from narwhals._plan.expr import ExprV1 - - return ExprV1._from_ir(self._ir) - raise NotImplementedError(self.version) + tp = Expr if self.version is Version.MAIN else ExprV1 + return tp._from_ir(self._ir) def exclude(self, *names: OneOrIterable[str]) -> Selector: return self - by_name(*names) # pyright: ignore[reportReturnType] @@ -59,9 +54,6 @@ def __add__(self, other: Any) -> Expr: # type: ignore[override] return self.as_expr().__add__(other) def __radd__(self, other: Any) -> Expr: # type: ignore[override] - if isinstance(other, type(self)): - msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')" - raise TypeError(msg) return self.as_expr().__radd__(other) @overload # type: ignore[override] diff --git a/narwhals/_plan/series.py b/narwhals/_plan/series.py index 56985c120b..255a8442e0 100644 --- a/narwhals/_plan/series.py +++ b/narwhals/_plan/series.py @@ -56,8 +56,9 @@ def from_iterable( ) ) raise NotImplementedError(implementation) - msg = f"{implementation} support in Narwhals is lazy-only" - raise ValueError(msg) + else: # pragma: no cover # noqa: RET506 + msg = f"{implementation} support in Narwhals is lazy-only" + raise ValueError(msg) @classmethod def from_native( @@ -76,7 +77,7 @@ def to_native(self) -> NativeSeriesT_co: def to_list(self) -> list[Any]: return self._compliant.to_list() - def __iter__(self) -> Iterator[Any]: + def __iter__(self) -> Iterator[Any]: # pragma: no cover yield from self.to_native() def alias(self, name: str) -> Self: diff --git a/narwhals/_plan/typing.py b/narwhals/_plan/typing.py index d8a55eb455..005b50c115 100644 --- a/narwhals/_plan/typing.py +++ b/narwhals/_plan/typing.py @@ -1,10 +1,11 @@ from __future__ import annotations import typing as t +from typing import TYPE_CHECKING from narwhals._typing_compat import TypeVar -if t.TYPE_CHECKING: +if TYPE_CHECKING: from collections.abc import Callable, Iterable from typing_extensions import TypeAlias diff --git a/narwhals/_plan/when_then.py b/narwhals/_plan/when_then.py index 20c1f2c72b..ce51e19087 100644 --- a/narwhals/_plan/when_then.py +++ b/narwhals/_plan/when_then.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan._guards import is_expr from narwhals._plan._immutable import Immutable from narwhals._plan._parse import ( parse_into_expr_ir as _parse_into_expr_ir, @@ -35,10 +34,6 @@ class When(Immutable): def then(self, expr: IntoExpr, /) -> Then: return Then(condition=self.condition, statement=parse_into_expr_ir(expr)) - @staticmethod - def _from_expr(expr: Expr, /) -> When: - return When(condition=expr._ir) - @staticmethod def _from_ir(expr_ir: ExprIR, /) -> When: return When(condition=expr_ir) @@ -71,10 +66,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] return Expr._from_ir(expr_ir) - def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] - if is_expr(value): - return super(Expr, self).__eq__(value) - return super().__eq__(value) + def __eq__(self, other: IntoExpr) -> Expr: # type: ignore[override] + return Expr.__eq__(self, other) class ChainedWhen(Immutable): @@ -119,10 +112,8 @@ def _ir(self) -> ExprIR: # type: ignore[override] def _from_ir(cls, expr_ir: ExprIR, /) -> Expr: # type: ignore[override] return Expr._from_ir(expr_ir) - def __eq__(self, value: object) -> Expr | bool: # type: ignore[override] - if is_expr(value): - return super(Expr, self).__eq__(value) - return super().__eq__(value) + def __eq__(self, other: IntoExpr) -> Expr: # type: ignore[override] + return Expr.__eq__(self, other) def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> TernaryExpr: diff --git a/pyproject.toml b/pyproject.toml index c248f3be6e..09b0dece08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -300,8 +300,10 @@ omit = [ 'narwhals/_ibis/typing.py', # Remove after finishing eager sub-protocols 'narwhals/_compliant/namespace.py', - # Doesn't have a full impl yet - 'narwhals/_plan/*' + # NOTE: Gradually adding as things become more stable + 'narwhals/_plan/arrow/*', + 'narwhals/_plan/compliant/*', + 'narwhals/_plan/**/typing.py', ] exclude_also = [ "if sys.version_info() <", @@ -321,7 +323,14 @@ exclude_also = [ 'if ".*" in str\(constructor', 'pytest.skip\(', 'assert_never\(', - 'PANDAS_VERSION < \(' + 'PANDAS_VERSION < \(', + 'def __repr__', + 'def __str__', + # Extends a `covdefaults` pattern to account for `EM10{1,2}` + # https://github.com/asottile/covdefaults/blob/a5228df597ffc7933bb2fb5b7bad94119a40a896/covdefaults.py#L90-L92 + # https://docs.astral.sh/ruff/rules/raw-string-in-exception/ + # https://docs.astral.sh/ruff/rules/f-string-in-exception/ + '^\s*msg = .+\n\s*raise NotImplementedError\b' ] [tool.mypy] diff --git a/tests/plan/expr_parsing_test.py b/tests/plan/expr_parsing_test.py index 6d510acabc..fd3b51f5e5 100644 --- a/tests/plan/expr_parsing_test.py +++ b/tests/plan/expr_parsing_test.py @@ -15,14 +15,16 @@ from narwhals._plan._parse import parse_into_seq_of_expr_ir from narwhals._plan.expressions import functions as F, operators as ops from narwhals._plan.expressions.literal import SeriesLiteral +from narwhals._plan.expressions.ranges import IntRange from narwhals.exceptions import ( ComputeError, InvalidIntoExprError, InvalidOperationError, InvalidOperationError as LengthChangingExprError, + MultiOutputExpressionError, ShapeError, ) -from tests.plan.utils import assert_expr_ir_equal +from tests.plan.utils import assert_expr_ir_equal, re_compile if TYPE_CHECKING: from contextlib import AbstractContextManager @@ -127,7 +129,7 @@ def test_valid_windows() -> None: assert nwp.sum_horizontal(a.diff().abs(), a.cum_sum()).over(order_by="i") -def test_invalid_repeat_agg() -> None: +def test_repeat_agg_invalid() -> None: with pytest.raises(InvalidOperationError): nwp.col("a").mean().mean() with pytest.raises(InvalidOperationError): @@ -144,7 +146,7 @@ def test_invalid_repeat_agg() -> None: # NOTE: Previously multiple different errors, but they can be reduced to the same thing # Once we are scalar, only elementwise is allowed -def test_invalid_agg_non_elementwise() -> None: +def test_agg_non_elementwise_invalid() -> None: pattern = re.compile(r"cannot use.+rank.+aggregated.+mean", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").mean().rank() @@ -167,7 +169,7 @@ def test_agg_non_elementwise_range_special() -> None: assert isinstance(e_ir.expr.input[1], ir.Len) -def test_invalid_int_range() -> None: +def test_int_range_invalid() -> None: pattern = re.compile(r"scalar.+agg", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.int_range(nwp.col("a")) @@ -177,16 +179,37 @@ def test_invalid_int_range() -> None: nwp.int_range(0, nwp.col("a").abs()) with pytest.raises(InvalidOperationError, match=pattern): nwp.int_range(nwp.col("a") + 1) + with pytest.raises(InvalidOperationError, match=pattern): + nwp.int_range((1 + nwp.col("b")).name.keep()) + int_range = IntRange(step=1, dtype=nw.Int64()) + with pytest.raises(InvalidOperationError, match=r"at least 2 inputs.+int_range"): + int_range.to_function_expr(ir.col("a")) + + +@pytest.mark.xfail( + reason="Not implemented `int_range(eager=True)`", raises=NotImplementedError +) +def test_int_range_series() -> None: + assert isinstance(nwp.int_range(50, eager=True), nwp.Series) -# NOTE: Non-`polars`` rule -def test_invalid_over() -> None: +def test_over_invalid() -> None: + with pytest.raises(TypeError, match=r"one of.+partition_by.+or.+order_by"): + nwp.col("a").last().over() + + # NOTE: Non-`polars` rule pattern = re.compile(r"cannot use.+over.+elementwise", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").fill_null(3).over("b") + # NOTE: This version isn't elementwise + expr_ir = nwp.col("a").fill_null(strategy="backward").over("b")._ir + assert isinstance(expr_ir, ir.WindowExpr) + assert isinstance(expr_ir.expr, ir.FunctionExpr) + assert isinstance(expr_ir.expr.function, F.FillNullWithStrategy) + -def test_nested_over() -> None: +def test_over_nested() -> None: pattern = re.compile(r"cannot nest.+over", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").mean().over("b").over("c") @@ -196,7 +219,7 @@ def test_nested_over() -> None: # NOTE: This *can* error in polars, but only if the length **actually changes** # The rule then breaks down to needing the same length arrays in all parts of the over -def test_filtration_over() -> None: +def test_over_filtration() -> None: pattern = re.compile(r"cannot use.+over.+change length", re.IGNORECASE) with pytest.raises(InvalidOperationError, match=pattern): nwp.col("a").drop_nulls().over("b") @@ -206,9 +229,9 @@ def test_filtration_over() -> None: nwp.col("a").diff().drop_nulls().over("b", order_by="i") -def test_invalid_binary_expr_length_changing() -> None: +def test_binary_expr_length_changing_invalid() -> None: a = nwp.col("a") - b = nwp.col("b") + b = nwp.col("b").exp() with pytest.raises(LengthChangingExprError): a.unique() + b.unique() @@ -247,7 +270,7 @@ def test_binary_expr_length_changing_agg() -> None: ) -def test_invalid_binary_expr_shape() -> None: +def test_binary_expr_shape_invalid() -> None: pattern = re.compile( re.escape("Cannot combine length-changing expressions with length-preserving"), re.IGNORECASE, @@ -261,6 +284,8 @@ def test_invalid_binary_expr_shape() -> None: a.map_batches(lambda x: x, is_elementwise=True) * b.gather_every(1, 0) with pytest.raises(ShapeError, match=pattern): a / b.drop_nulls() + with pytest.raises(ShapeError, match=pattern): + a.fill_null(1) // b.rolling_mean(5) @pytest.mark.parametrize("into_iter", [list, tuple, deque, iter, dict.fromkeys, set]) @@ -306,7 +331,7 @@ def test_is_in_series() -> None: ), ], ) -def test_invalid_is_in(other: Any, context: AbstractContextManager[Any]) -> None: +def test_is_in_invalid(other: Any, context: AbstractContextManager[Any]) -> None: with context: nwp.col("a").is_in(other) @@ -508,3 +533,80 @@ def test_hist_invalid() -> None: a.hist(deque((3, 2, 1))) with pytest.raises(TypeError): a.hist(1) # type: ignore[arg-type] + + +def test_into_expr_invalid() -> None: + pytest.importorskip("polars") + import polars as pl + + with pytest.raises( + TypeError, match=re_compile(r"expected.+narwhals.+got.+polars.+hint") + ): + nwp.col("a").max().over(pl.col("b")) # type: ignore[arg-type] + + +def test_when_invalid() -> None: + pattern = re_compile(r"multi-output expr.+not supported in.+when.+context") + + when = nwp.when(nwp.col("a", "b", "c").is_finite()) + when_then = when.then(nwp.col("d").is_unique()) + when_then_when = when_then.when( + (nwp.median("a", "b", "c") > 2) | nwp.col("d").is_nan() + ) + with pytest.raises(MultiOutputExpressionError, match=pattern): + when.then(nwp.max("c", "d")) + with pytest.raises(MultiOutputExpressionError, match=pattern): + when_then.otherwise(nwp.min("h", "i", "j")) + with pytest.raises(MultiOutputExpressionError, match=pattern): + when_then_when.then(nwp.col(["b", "y", "e"])) + + +# NOTE: `Then`, `ChainedThen` use multi-inheritance, but **need** to use `Expr.__eq__` +def test_then_equal() -> None: + expr = nwp.col("a").clip(nwp.col("a").kurtosis(), nwp.col("a").log()) + other = "other" + then = nwp.when(a="b").then(nwp.col("c").skew()) + chained_then = then.when("d").then("e") + + assert isinstance(then == expr, nwp.Expr) + assert isinstance(then == other, nwp.Expr) + + assert isinstance(chained_then == expr, nwp.Expr) + assert isinstance(chained_then == other, nwp.Expr) + + assert isinstance(then == chained_then, nwp.Expr) + + +def test_dt_timestamp_invalid() -> None: + assert nwp.col("a").dt.timestamp() + with pytest.raises( + TypeError, match=re_compile(r"invalid.+time_unit.+expected.+got 's'") + ): + nwp.col("a").dt.timestamp("s") + + +def test_dt_truncate_invalid() -> None: + assert nwp.col("a").dt.truncate("1d") + with pytest.raises(ValueError, match=re_compile(r"invalid.+every.+abcd")): + nwp.col("a").dt.truncate("abcd") + + +def test_replace_strict() -> None: + a = nwp.col("a") + remapping = a.replace_strict({1: 3, 2: 4}, return_dtype=nw.Int8) + sequences = a.replace_strict(old=[1, 2], new=[3, 4], return_dtype=nw.Int8()) + assert_expr_ir_equal(remapping, sequences) + + +def test_replace_strict_invalid() -> None: + with pytest.raises( + TypeError, + match="`new` argument is required if `old` argument is not a Mapping type", + ): + nwp.col("a").replace_strict("b") + + with pytest.raises( + TypeError, + match="`new` argument cannot be used if `old` argument is a Mapping type", + ): + nwp.col("a").replace_strict(old={1: 2, 3: 4}, new=[5, 6, 7]) diff --git a/tests/plan/frame_partition_by_test.py b/tests/plan/frame_partition_by_test.py index ea0565ac94..429a78fb20 100644 --- a/tests/plan/frame_partition_by_test.py +++ b/tests/plan/frame_partition_by_test.py @@ -8,8 +8,8 @@ import narwhals as nw from narwhals._plan import Selector, selectors as ncs from narwhals._utils import zip_strict -from narwhals.exceptions import ColumnNotFoundError, ComputeError -from tests.plan.utils import assert_equal_data, dataframe +from narwhals.exceptions import ColumnNotFoundError, ComputeError, DuplicateError +from tests.plan.utils import assert_equal_data, dataframe, re_compile if TYPE_CHECKING: from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable @@ -126,6 +126,12 @@ def test_partition_by_missing_names(data: Data) -> None: df.partition_by("c", "e") +def test_partition_by_duplicate_names(data: Data) -> None: + df = dataframe(data) + with pytest.raises(DuplicateError, match=re_compile(r"expected.+unique.+got.+'c'")): + df.partition_by("c", ncs.numeric()) + + def test_partition_by_fully_empty_selector(data: Data) -> None: df = dataframe(data) with pytest.raises( diff --git a/tests/plan/immutable_test.py b/tests/plan/immutable_test.py index 8e60759d0b..7a286520e5 100644 --- a/tests/plan/immutable_test.py +++ b/tests/plan/immutable_test.py @@ -212,3 +212,25 @@ def test_immutable___slots___(immutable_type: type[Immutable]) -> None: slots = immutable_type.__slots__ if slots: assert len(slots) != 0, slots + + +def test_immutable_str() -> None: + class MixedFields(Immutable): + __slots__ = ("name", "unique_id", "aliases") # noqa: RUF023 + name: str + unique_id: int + aliases: tuple[str, str, str] + + class Parent(Immutable): + __slots__ = ("children",) + children: tuple[MixedFields, ...] + + bob = MixedFields(name="bob", unique_id=123, aliases=("robert", "bobert", "Bob")) + parent = Parent(children=(bob,)) + + expected_child = ( + "MixedFields(name='bob', unique_id=123, aliases=['robert', 'bobert', 'Bob'])" + ) + expected_parent = f"Parent(children=[{expected_child}])" + assert str(bob) == expected_child + assert str(parent) == expected_parent diff --git a/tests/plan/repr_test.py b/tests/plan/repr_test.py new file mode 100644 index 0000000000..325ed52621 --- /dev/null +++ b/tests/plan/repr_test.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import narwhals._plan as nwp + + +def test_repr() -> None: + nwp.col("a").meta.as_selector() + expr = nwp.col("a") + selector = expr.meta.as_selector() + + expr_repr_html = expr._repr_html_() + expr_ir_repr_html = expr._ir._repr_html_() + selector_repr_html = selector._repr_html_() + selector_ir_repr_html = selector._ir._repr_html_() + expr_repr = expr.__repr__() + expr_ir_repr = expr._ir.__repr__() + selector_repr = selector.__repr__() + selector_ir_repr = selector._ir.__repr__() + + # In a notebook, both `Expr` and `ExprIR` are displayed the same + assert expr_repr_html == expr_ir_repr_html + # The actual repr (for debugging) has more information + assert expr_repr != expr_repr_html + # Currently, all extra information is *before* the part which matches + assert expr_repr.endswith(expr_repr_html) + # But these guys should not deviate + assert expr_ir_repr == expr_ir_repr_html + # The same invariants should hold for `Selector` and `SelectorIR` + assert selector_repr_html == selector_ir_repr_html + assert selector_repr != selector_repr_html + assert selector_repr.endswith(selector_repr_html) + assert selector_ir_repr == selector_ir_repr_html + # But they must still be visually different from `Expr` and `ExprIR` + assert selector_repr_html != expr_repr_html + assert selector_repr != expr_repr diff --git a/tests/plan/schema_test.py b/tests/plan/schema_test.py new file mode 100644 index 0000000000..19ecedac77 --- /dev/null +++ b/tests/plan/schema_test.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import pytest + +import narwhals as nw +from narwhals._plan.schema import FrozenSchema, freeze_schema +from tests.plan.utils import dataframe + + +def test_schema() -> None: + mapping = {"a": nw.Int64(), "b": nw.String()} + schema = nw.Schema(mapping) + frozen_schema = freeze_schema(mapping) + + assert frozen_schema.keys() == schema.keys() + assert tuple(frozen_schema.values()) == tuple(schema.values()) + + # NOTE: Would type-check if `Schema.__init__` didn't make liskov unhappy + assert schema == nw.Schema(frozen_schema) # type: ignore[arg-type] + assert mapping == dict(frozen_schema) + + assert frozen_schema == freeze_schema(mapping) + assert frozen_schema == freeze_schema(**mapping) + assert frozen_schema == freeze_schema(a=nw.Int64(), b=nw.String()) + assert frozen_schema == freeze_schema(schema) + assert frozen_schema == freeze_schema(frozen_schema) + assert frozen_schema == freeze_schema(frozen_schema.items()) + + # NOTE: Using `**` unpacking, despite not inheriting from `Mapping` or `dict` + assert frozen_schema == freeze_schema(**frozen_schema) + + # NOTE: Using `HasSchema` + df = dataframe({"a": [1, 2, 3], "b": ["c", "d", "e"]}) + assert frozen_schema == freeze_schema(df) + + # NOTE: In case this all looks *too good* to be true + assert frozen_schema != freeze_schema(**mapping, c=nw.Float64()) + + assert frozen_schema["a"] == schema["a"] + + assert frozen_schema.get("c") is None + assert frozen_schema.get("c", nw.Unknown) is nw.Unknown + assert frozen_schema.get("c", nw.Unknown()) == nw.Unknown() + + assert "b" in frozen_schema + assert "e" not in frozen_schema + + with pytest.raises(TypeError, match="Cannot subclass 'FrozenSchema'"): + + class MutableSchema(FrozenSchema): ... # type: ignore[misc] diff --git a/tests/plan/selectors_test.py b/tests/plan/selectors_test.py index 6e00ba7e22..f1170db354 100644 --- a/tests/plan/selectors_test.py +++ b/tests/plan/selectors_test.py @@ -5,6 +5,7 @@ from __future__ import annotations +import operator from datetime import timezone from typing import TYPE_CHECKING @@ -13,14 +14,23 @@ import narwhals as nw import narwhals.stable.v1 as nw_v1 from narwhals import _plan as nwp -from narwhals._plan import Expr, Selector, selectors as ncs +from narwhals._plan import Selector, selectors as ncs +from narwhals._plan._guards import is_expr, is_selector from narwhals._utils import zip_strict from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError -from tests.plan.utils import Frame, assert_expr_ir_equal, named_ir, re_compile +from tests.plan.utils import ( + Frame, + assert_expr_ir_equal, + assert_not_selector, + is_expr_ir_equal, + named_ir, + re_compile, +) if TYPE_CHECKING: from collections.abc import Iterable + from narwhals._plan.typing import IntoExpr, OperatorFn from narwhals.dtypes import DType @@ -234,6 +244,10 @@ def test_selector_by_index(schema_non_nested: nw.Schema) -> None: ~ncs.by_index(range(0, df.width, 2)), "bbb", "def", "fgg", "JJK", "opp" ) + df.assert_selects(ncs.by_index(0, 999, require_all=False), "abc") + df.assert_selects(ncs.by_index(-1, -999, require_all=False), "qqR") + df.assert_selects(ncs.by_index(1234, 5678, require_all=False)) + def test_selector_by_index_invalid_input() -> None: with pytest.raises(TypeError): @@ -249,12 +263,20 @@ def test_selector_by_index_not_found(schema_non_nested: nw.Schema) -> None: with pytest.raises(ColumnNotFoundError): df.project(ncs.by_index(999)) + df.assert_selects(ncs.by_index(999, -50, require_all=False)) + + df = Frame(nw.Schema()) + df.assert_selects(ncs.by_index(111, -112, require_all=False)) + def test_selector_by_index_reordering(schema_non_nested: nw.Schema) -> None: df = Frame(schema_non_nested) df.assert_selects(ncs.by_index(-3, -2, -1), "Lmn", "opp", "qqR") df.assert_selects(ncs.by_index(range(-3, 0)), "Lmn", "opp", "qqR") + df.assert_selects( + ncs.by_index(-3, 999, -2, -1, -48, require_all=False), "Lmn", "opp", "qqR" + ) def test_selector_by_name(schema_non_nested: nw.Schema) -> None: @@ -435,7 +457,7 @@ def test_selector_expansion() -> None: nwp.col("a").max().meta.as_selector() -def test_selector_sets(schema_non_nested: nw.Schema, schema_mixed: nw.Schema) -> None: +def test_selector_set_ops(schema_non_nested: nw.Schema, schema_mixed: nw.Schema) -> None: df = Frame(schema_non_nested) # NOTE: `cs.temporal` is used a lot in this tests, but `narwhals` doesn't have it @@ -456,12 +478,11 @@ def test_selector_sets(schema_non_nested: nw.Schema, schema_mixed: nw.Schema) -> # Would allow: `str | Expr | DType | type[DType] | Selector | Collection[str | Expr | DType | type[DType] | Selector]` selector = ncs.all() - (~temporal | ncs.matches(r"opp|JJK")) df.assert_selects(selector, "ghi", "Lmn") + selector = nwp.all().exclude("opp", "JJK").meta.as_selector() - (~temporal) + df.assert_selects(selector, "ghi", "Lmn") sub_expr = ncs.matches("[yz]$") - nwp.col("colx") - assert not isinstance(sub_expr, Selector), ( - "('Selector' - 'Expr') shouldn't behave as set" - ) - assert isinstance(sub_expr, Expr) + assert_not_selector(sub_expr) with pytest.raises(TypeError, match=r"unsupported .* \('Expr' - 'Selector'\)"): nwp.col("colx") - ncs.matches("[yz]$") @@ -482,6 +503,54 @@ def test_selector_sets(schema_non_nested: nw.Schema, schema_mixed: nw.Schema) -> df.assert_selects(selector, "l", "m", "n", "o", "p", "q", "r", "s", "u") +def _is_binary_operator(function: OperatorFn) -> bool: + return function in {operator.and_, operator.or_, operator.xor} + + +def _is_selector_operator(function: OperatorFn) -> bool: + return function in {operator.and_, operator.or_, operator.xor, operator.sub} + + +@pytest.mark.parametrize( + "arg_2", + [1, nwp.col("a"), nwp.col("a").max(), ncs.numeric()], + ids=["Scalar", "Column", "Expr", "Selector"], +) +@pytest.mark.parametrize( + "function", [operator.and_, operator.or_, operator.xor, operator.add, operator.sub] +) +def test_selector_arith_binary_ops( + arg_2: IntoExpr | Selector, function: OperatorFn +) -> None: + # NOTE: These are the `polars.selectors` semantics + # Parts of it may change with `polars>=2.0`, due to how confusing they are + arg_1 = ncs.string() + result_1 = function(arg_1, arg_2) + if ( + _is_binary_operator(function) + and is_expr(arg_2) + and is_expr_ir_equal(arg_2, nwp.col("a")) + ) or (_is_selector_operator(function) and is_selector(arg_2)): + assert is_selector(result_1) + else: + assert_not_selector(result_1) + + if _is_binary_operator(function) and is_selector(arg_2): + result_2 = function(arg_2, arg_1) + assert is_selector(result_2) + # `__sub__` is allowed, but `__rsub__` is not ... + elif function is not operator.sub: + result_2 = function(arg_2, arg_1) + assert_not_selector(result_2) + # ... unless both are `Selector` + elif is_selector(arg_2): + result_2 = function(arg_2, arg_1) + assert is_selector(result_2) + else: + with pytest.raises(TypeError): + function(arg_2, arg_1) + + @pytest.mark.parametrize( "selector", [ @@ -499,10 +568,20 @@ def test_selector_result_order(schema_non_nested: nw.Schema, selector: Selector) def test_selector_list(schema_nested_1: nw.Schema) -> None: df = Frame(schema_nested_1) + + # inner None df.assert_selects(ncs.list(), "b", "c", "e") + # Inner All (as a DTypeSelector) df.assert_selects(ncs.list(ncs.all()), "b", "c", "e") - df.assert_selects(ncs.list(inner=ncs.numeric()), "b", "c") + # inner DTypeSelector + df.assert_selects(ncs.list(ncs.numeric()), "b", "c") df.assert_selects(ncs.list(inner=ncs.string()), "e") + # inner BinarySelector + df.assert_selects( + ncs.list(ncs.by_dtype(nw.Int32) | ncs.by_dtype(nw.UInt32)), "b", "c" + ) + # inner InvertSelector + df.assert_selects(ncs.list(~ncs.all())) def test_selector_array(schema_nested_2: nw.Schema) -> None: diff --git a/tests/plan/utils.py b/tests/plan/utils.py index a78ed0a17e..23bfcff219 100644 --- a/tests/plan/utils.py +++ b/tests/plan/utils.py @@ -8,6 +8,7 @@ import narwhals as nw from narwhals import _plan as nwp from narwhals._plan import Expr, Selector, _expansion, _parse, expressions as ir +from narwhals._utils import qualified_type_name from tests.utils import assert_equal_data as _assert_equal_data pytest.importorskip("pyarrow") @@ -182,6 +183,25 @@ def assert_expr_ir_equal( assert lhs == rhs, f"\nlhs:\n {lhs!r}\n\nrhs:\n {rhs!r}" +def assert_not_selector(actual: Expr | Selector, /) -> None: + """Assert that `actual` was converted into an `Expr`.""" + assert isinstance(actual, Expr), ( + f"Didn't expect you to pass a {qualified_type_name(actual)!r} here, got: {actual!r}" + ) + assert not isinstance(actual, Selector), ( + f"This operation should have returned `Expr`, but got {qualified_type_name(actual)!r}\n{actual!r}" + ) + + +def is_expr_ir_equal(actual: Expr | ir.ExprIR, expected: Expr | ir.ExprIR, /) -> bool: + """Return True if `actual` is equivalent to `expected`. + + Note: + Prefer `assert_expr_ir_equal` unless you need a `bool` for branching. + """ + return _unwrap_ir(actual) == _unwrap_ir(expected) + + def named_ir(name: str, expr: nwp.Expr | ir.ExprIR, /) -> ir.NamedIR[ir.ExprIR]: """Helper constructor for test compare.""" return ir.NamedIR(expr=expr._ir if isinstance(expr, nwp.Expr) else expr, name=name)