diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d342428700..961e48e892 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,6 +79,8 @@ repos: entry: "self: Self" language: pygrep files: ^narwhals/ + # mypy needs `Self` for `ExprIR.dispatch` + exclude: ^narwhals/_plan/.*\.py - id: dtypes-import name: don't import from narwhals.dtypes (use `Version.dtypes` instead) entry: | diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py new file mode 100644 index 0000000000..044b786e28 --- /dev/null +++ b/narwhals/_plan/_dispatch.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import re +from collections.abc import Callable +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, final, overload + +from narwhals._plan._guards import is_function_expr +from narwhals._plan.compliant.typing import FrameT_contra, R_co +from narwhals._typing_compat import TypeVar + +if TYPE_CHECKING: + from typing_extensions import Never, TypeAlias + + from narwhals._plan.compliant.typing import Ctx + from narwhals._plan.expressions import ExprIR, FunctionExpr + from narwhals._plan.typing import ExprIRT, FunctionT + +__all__ = ["Dispatcher", "get_dispatch_name"] + + +Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]") +Node_contra = TypeVar( + "Node_contra", bound="ExprIR | FunctionExpr[Any]", contravariant=True +) +Raiser: TypeAlias = Callable[..., "Never"] + + +class Binder(Protocol[Node_contra]): + def __call__( + self, ctx: Ctx[FrameT_contra, R_co], / + ) -> BoundMethod[Node_contra, FrameT_contra, R_co]: ... + + +class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]): + def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_co: ... + + +@final +class Dispatcher(Generic[Node]): + """Translate class definitions into error-wrapped method calls. + + Operates over `ExprIR` and `Function` nodes. + + By default, we dispatch to the compliant-level by calling a method that is the + **snake_case**-equivalent of the class name: + + class BinaryExpr(ExprIR): ... + + class CompliantExpr(Protocol): + def binary_expr(self, *args: Any): ... + """ + + __slots__ = ("_bind", "_name") + _bind: Binder[Node] + _name: str + + @property + def name(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"{type(self).__name__}<{self.name}>" + + def bind( + self, ctx: Ctx[FrameT_contra, R_co], / + ) -> BoundMethod[Node, FrameT_contra, R_co]: + """Retrieve the implementation of this expression from `ctx`. + + Binds an instance method, most commonly via: + + expr: CompliantExpr + method = getattr(expr, "method_name") + """ + try: + return self._bind(ctx) + except AttributeError: + raise self._not_implemented_error(ctx, "compliant") from None + + def __call__( + self, + node: Node, + ctx: Ctx[FrameT_contra, R_co], + frame: FrameT_contra, + name: str, + /, + ) -> R_co: + """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" + method = self.bind(ctx) + if result := method(node, frame, name): + return result + raise self._not_implemented_error(ctx, "context") + + @staticmethod + def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: + if not tp.__expr_ir_config__.allow_dispatch: + return Dispatcher._no_dispatch(tp) + return Dispatcher._from_type(tp) + + @staticmethod + def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: + return Dispatcher._from_type(tp) + + @staticmethod + @overload + def _from_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ... + @staticmethod + @overload + def _from_type(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: ... + @staticmethod + def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]: + obj = Dispatcher.__new__(Dispatcher) + obj._name = _method_name(tp) + getter = attrgetter(obj._name) + is_namespaced = tp.__expr_ir_config__.is_namespaced + obj._bind = _via_namespace(getter) if is_namespaced else getter + return obj + + @staticmethod + def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: + obj = Dispatcher.__new__(Dispatcher) + obj._name = tp.__name__ + obj._bind = obj._make_no_dispatch_error() + return obj + + def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]: + def _no_dispatch_error(node: Node, *_: Any) -> Never: + msg = ( + f"{self.name!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{node!r}" + ) + raise TypeError(msg) + + def getter(_: Any, /) -> Raiser: + return _no_dispatch_error + + return getter + + def _not_implemented_error( + self, ctx: object, /, missing: Literal["compliant", "context"] + ) -> NotImplementedError: + if missing == "context": + msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" + else: + msg = ( + f"`{self.name}` has not been implemented at the compliant-level.\n" + f"Hint: Try adding `CompliantExpr.{self.name}()` or `CompliantNamespace.{self.name}()`" + ) + return NotImplementedError(msg) + + +def _via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: + def _(ctx: Any, /) -> Any: + return getter(ctx.__narwhals_namespace__()) + + return _ + + +def _pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase string to snake_case. + + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() + + +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") + + +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" + + +def _method_name(tp: type[ExprIRT | FunctionT]) -> str: + config = tp.__expr_ir_config__ + name = config.override_name or _pascal_to_snake_case(tp.__name__) + return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name + + +def get_dispatch_name(expr: ExprIR, /) -> str: + """Return the synthesized method name for `expr`.""" + return ( + repr(expr.function) if is_function_expr(expr) else expr.__expr_ir_dispatch__.name + ) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index f0c9a08548..8f66f2d9fa 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -1,19 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, cast +from typing import TYPE_CHECKING, Generic +from narwhals._plan._dispatch import Dispatcher from narwhals._plan._guards import is_function_expr, is_literal from narwhals._plan._immutable import Immutable -from narwhals._plan.common import dispatch_getter, replace +from narwhals._plan.common import replace from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ExprIRT from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Iterator from typing import Any, ClassVar - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expr import Expr, Selector @@ -22,29 +23,6 @@ from narwhals._plan.typing import ExprIRT2, MapIR, Seq from narwhals.dtypes import DType - Incomplete: TypeAlias = "Any" - - -def _dispatch_generate( - tp: type[ExprIRT], / -) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - if not tp.__expr_ir_config__.allow_dispatch: - - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - msg = ( - f"{tp.__name__!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" - ) - raise TypeError(msg) - - return _ - getter = dispatch_getter(tp) - - def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" @@ -53,9 +31,7 @@ class ExprIR(Immutable): """Nested node names, in iteration order.""" __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] - ] + __expr_ir_dispatch__: ClassVar[Dispatcher[Self]] def __init_subclass__( cls: type[Self], @@ -69,13 +45,13 @@ def __init_subclass__( cls._child = child if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + cls.__expr_ir_dispatch__ = Dispatcher.from_expr_ir(cls) def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / ) -> R_co: - """Evaluate expression in `frame`, using `ctx` for implementation(s).""" - return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" + return self.__expr_ir_dispatch__(self, ctx, frame, name) def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import expr diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 332dbfc085..8b71a4dd8d 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -2,33 +2,21 @@ from typing import TYPE_CHECKING +from narwhals._plan._dispatch import Dispatcher from narwhals._plan._immutable import Immutable -from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace +from narwhals._plan.common import replace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: from typing import Any, Callable, ClassVar - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.expressions import ExprIR, FunctionExpr - from narwhals._plan.typing import Accessor, FunctionT + from narwhals._plan.typing import Accessor __all__ = ["Function", "HorizontalFunction"] -Incomplete: TypeAlias = "Any" - - -def _dispatch_generate_function( - tp: type[FunctionT], / -) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = dispatch_getter(tp) - - def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - class Function(Immutable): """Shared by expr functions and namespace functions. @@ -40,9 +28,7 @@ class Function(Immutable): FunctionOptions.default ) __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] - ] + __expr_ir_dispatch__: ClassVar[Dispatcher[FunctionExpr[Self]]] @property def function_options(self) -> FunctionOptions: @@ -72,10 +58,10 @@ def __init_subclass__( cls._function_options = staticmethod(options) if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) + cls.__expr_ir_dispatch__ = Dispatcher.from_function(cls) def __repr__(self) -> str: - return dispatch_method_name(type(self)) + return self.__expr_ir_dispatch__.name class HorizontalFunction( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index b519261c53..df57f781c1 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,9 +6,10 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir +from narwhals._plan._dispatch import get_dispatch_name from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options -from narwhals._plan.common import dispatch_method_name, temp +from narwhals._plan.common import temp from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy from narwhals._plan.expressions import aggregation as agg from narwhals._utils import Implementation @@ -132,11 +133,7 @@ def group_by_error( if reason == "too complex": msg = "Non-trivial complex aggregation found, which" else: - if is_function_expr(expr): - func_name = repr(expr.function) - else: - func_name = dispatch_method_name(type(expr)) - msg = f"`{func_name}()`" + msg = f"`{get_dispatch_name(expr)}()`" msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}" return InvalidOperationError(msg) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c29e8e8071..8ac2084034 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,11 +1,9 @@ from __future__ import annotations import datetime as dt -import re import sys from collections.abc import Iterable from decimal import Decimal -from operator import attrgetter from secrets import token_hex from typing import TYPE_CHECKING, cast, overload @@ -18,20 +16,13 @@ if TYPE_CHECKING: import reprlib from collections.abc import Iterator - from typing import Any, Callable, ClassVar, TypeVar + from typing import Any, ClassVar, TypeVar from typing_extensions import TypeIs from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.series import Series - from narwhals._plan.typing import ( - DTypeT, - ExprIRT, - FunctionT, - NonNestedDTypeT, - OneOrIterable, - Seq, - ) + from narwhals._plan.typing import DTypeT, NonNestedDTypeT, OneOrIterable, Seq from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -51,38 +42,6 @@ def replace(obj: T, /, **changes: Any) -> T: return func(obj, **changes) # type: ignore[no-any-return] -def pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. - - Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 - """ - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) - # Insert an underscore between a lowercase letter and an uppercase letter - return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - - -_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") -_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - - -def _re_repl_snake(match: re.Match[str], /) -> str: - return f"{match.group(1)}_{match.group(2)}" - - -def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: - config = tp.__expr_ir_config__ - name = config.override_name or pascal_to_snake_case(tp.__name__) - return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name - - -def dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: - getter = attrgetter(dispatch_method_name(tp)) - if tp.__expr_ir_config__.origin == "expr": - return getter - return lambda ctx: getter(ctx.__narwhals_namespace__()) - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 5defde7d61..e8d0dffbed 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -54,6 +54,9 @@ def name(self) -> str: ... def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... + def ewm_mean( + self, node: FunctionExpr[F.EwmMean], frame: FrameT_contra, name: str + ) -> Self: ... def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index abb873aa5e..25c07d7de7 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -11,6 +11,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, aggregation as agg from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct + from narwhals._plan.expressions.functions import EwmMean from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -58,6 +59,11 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... + def ewm_mean( + self, node: FunctionExpr[EwmMean], frame: FrameT_contra, name: str + ) -> Self: + return self._cast_float(node.input[0], frame, name) + def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 3b9ae41d11..92a563f586 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any from narwhals._plan._expr_ir import ExprIR -from narwhals._plan.common import pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: @@ -21,7 +20,7 @@ def is_scalar(self) -> bool: return True def __repr__(self) -> str: - return f"{self.expr!r}.{pascal_to_snake_case(type(self).__name__)}()" + return f"{self.expr!r}.{self.__expr_ir_dispatch__.name}()" def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 42d00bdbd7..4fa2f6cf6e 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -314,9 +314,9 @@ def __init__( super().__init__(**dict(input=input, function=function, options=options, **kwds)) def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.function.__expr_ir_dispatch__(self, ctx, frame, name) class RollingExpr(FunctionExpr[RollingT_co]): ... @@ -328,9 +328,9 @@ class AnonymousExpr( """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.__expr_ir_dispatch__(self, ctx, frame, name) class RangeExpr(FunctionExpr[RangeT_co]): diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index f8ee347cac..739654e4bf 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -2,7 +2,7 @@ import enum from itertools import repeat -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable @@ -11,14 +11,12 @@ import pyarrow.acero import pyarrow.compute as pc - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.arrow.typing import NullPlacement from narwhals._plan.typing import Accessor, OneOrIterable, Order, Seq from narwhals.typing import RankMethod -DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] - class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -285,8 +283,8 @@ def rolling_options( class _BaseIROptions(Immutable): - __slots__ = ("origin", "override_name") - origin: DispatchOrigin + __slots__ = ("is_namespaced", "override_name") + is_namespaced: bool override_name: str def __repr__(self) -> str: @@ -294,7 +292,7 @@ def __repr__(self) -> str: @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="") + return cls(is_namespaced=False, override_name="") @classmethod def renamed(cls, name: str, /) -> Self: @@ -306,9 +304,7 @@ def renamed(cls, name: str, /) -> Self: def namespaced(cls, override_name: str = "", /) -> Self: from narwhals._plan.common import replace - return replace( - cls.default(), origin="__narwhals_namespace__", override_name=override_name - ) + return replace(cls.default(), is_namespaced=True, override_name=override_name) class ExprIROptions(_BaseIROptions): @@ -317,11 +313,11 @@ class ExprIROptions(_BaseIROptions): @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="", allow_dispatch=True) + return cls(is_namespaced=False, override_name="", allow_dispatch=True) @staticmethod def no_dispatch() -> ExprIROptions: - return ExprIROptions(origin="expr", override_name="", allow_dispatch=False) + return ExprIROptions(is_namespaced=False, override_name="", allow_dispatch=False) class FunctionExprOptions(_BaseIROptions): @@ -331,7 +327,7 @@ class FunctionExprOptions(_BaseIROptions): @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="", accessor_name=None) + return cls(is_namespaced=False, override_name="", accessor_name=None) FEOptions = FunctionExprOptions diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 2d63cb99f2..e011e5c547 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -549,7 +549,8 @@ def test_protocol_expr() -> None: from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries - expr = ArrowExpr() + # NOTE: Intentionally leaving `ewm_mean` without a `not_implemented()` for another test + expr = ArrowExpr() # type: ignore[abstract] scalar = ArrowScalar() df = ArrowDataFrame() ser = ArrowSeries() diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py new file mode 100644 index 0000000000..db65d468a5 --- /dev/null +++ b/tests/plan/dispatch_test.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +pytest.importorskip("pyarrow") +import narwhals as nw +from narwhals import _plan as nwp +from narwhals._plan import expressions as ir, selectors as ncs +from narwhals._plan._dispatch import get_dispatch_name +from tests.plan.utils import assert_equal_data, dataframe, named_ir + +if TYPE_CHECKING: + import sys + + import pyarrow as pa + from typing_extensions import TypeAlias + + from narwhals._plan.dataframe import DataFrame + + if sys.version_info >= (3, 11): + _Flags: TypeAlias = "int | re.RegexFlag" + else: + _Flags: TypeAlias = int + + +@pytest.fixture +def data() -> dict[str, Any]: + return { + "a": [12.1, None, 4.0], + "b": [42, 10, None], + "c": [4, 5, 6], + "d": ["play", "swim", "walk"], + } + + +@pytest.fixture +def df(data: dict[str, Any]) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: + return dataframe(data) + + +def re_compile( + pattern: str, flags: _Flags = re.DOTALL | re.IGNORECASE +) -> re.Pattern[str]: + return re.compile(pattern, flags) + + +def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: + implemented_full = nwp.col("a").is_null() + forgot_to_expand = (named_ir("howdy", nwp.nth(3, 4).first()),) + aliased_after_expand: tuple[ir.NamedIR[Any]] = ( + ir.NamedIR.from_ir(ir.col("a").alias("b")), + ) + + assert_equal_data(df.select(implemented_full), {"a": [False, True, False]}) + + missing_backend = r"ewm_mean.+is not yet implemented for" + with pytest.raises(NotImplementedError, match=missing_backend): + df.select(nwp.col("c").ewm_mean()) + + missing_protocol = re_compile( + r"str\.contains.+has not been implemented.+compliant.+" + r"Hint.+try adding.+CompliantExpr\.str\.contains\(\)" + ) + with pytest.raises(NotImplementedError, match=missing_protocol): + df.select(nwp.col("d").str.contains("a")) + + with pytest.raises( + TypeError, + match=re_compile(r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first"), + ): + df._compliant.select(forgot_to_expand) + + bad = re.escape("col('a').alias('b')") + with pytest.raises(TypeError, match=re_compile(rf"Alias.+not.+appear.+got.+{bad}")): + df._compliant.select(aliased_after_expand) + + # Not a narwhals method, to make sure this doesn't allow arbitrary calls + with pytest.raises(AttributeError): + nwp.col("a").max().to_physical() # type: ignore[attr-defined] + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("a"), "col"), + (nwp.col("a").min().over("b"), "over"), + (nwp.col("a").first().over(order_by="b"), "over_ordered"), + (nwp.all_horizontal("a", "b", nwp.nth(4, 5, 6)), "all_horizontal"), + (nwp.int_range(10), "int_range"), + (nwp.col("a") + nwp.col("b") + 10, "binary_expr"), + (nwp.when(nwp.col("c")).then(5).when(nwp.col("d")).then(20), "ternary_expr"), + (nwp.col("a").cast(nw.String).str.starts_with("something"), ("str.starts_with")), + (nwp.mean("a"), "mean"), + (nwp.nth(1).first(), "first"), + (nwp.col("a").sum(), "sum"), + (nwp.col("a").drop_nulls().arg_min(), "arg_min"), + pytest.param(nwp.col("a").alias("b"), "Alias", id="no_dispatch-Alias"), + pytest.param(ncs.string(), "RootSelector", id="no_dispatch-RootSelector"), + ], +) +def test_dispatch_name(expr: nwp.Expr, expected: str) -> None: + assert get_dispatch_name(expr._ir) == expected