From 5fd705db647636de965dbeee02845c41fe4f9de6 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 15 Oct 2025 18:54:48 +0000 Subject: [PATCH 01/19] chore: Un-stash `DispatchGetter` idea --- narwhals/_plan/_expr_ir.py | 14 ++------ narwhals/_plan/_function.py | 4 +-- narwhals/_plan/common.py | 71 +++++++++++++++++++++++++++++++++---- 3 files changed, 69 insertions(+), 20 deletions(-) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index f0c9a08548..b6bd412acf 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -4,7 +4,7 @@ from narwhals._plan._guards import is_function_expr, is_literal from narwhals._plan._immutable import Immutable -from narwhals._plan.common import dispatch_getter, replace +from narwhals._plan.common import DispatchGetter, replace from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ExprIRT from narwhals.utils import Version @@ -28,17 +28,7 @@ def _dispatch_generate( tp: type[ExprIRT], / ) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - if not tp.__expr_ir_config__.allow_dispatch: - - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - msg = ( - f"{tp.__name__!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" - ) - raise TypeError(msg) - - return _ - getter = dispatch_getter(tp) + getter = DispatchGetter.from_expr_ir(tp) def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: return getter(ctx)(node, frame, name) diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 332dbfc085..f87c5251d1 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable -from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace +from narwhals._plan.common import DispatchGetter, dispatch_method_name, replace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: @@ -22,7 +22,7 @@ def _dispatch_generate_function( tp: type[FunctionT], / ) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = dispatch_getter(tp) + getter = DispatchGetter.from_function(tp) def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: return getter(ctx)(node, frame, name) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index c29e8e8071..fff55f10ce 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -20,7 +20,7 @@ from collections.abc import Iterator from typing import Any, Callable, ClassVar, TypeVar - from typing_extensions import TypeIs + from typing_extensions import Self, TypeIs from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.series import Series @@ -76,11 +76,70 @@ def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name -def dispatch_getter(tp: type[ExprIRT | FunctionT]) -> Callable[[Any], Any]: - getter = attrgetter(dispatch_method_name(tp)) - if tp.__expr_ir_config__.origin == "expr": - return getter - return lambda ctx: getter(ctx.__narwhals_namespace__()) +def _dispatch_via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: + def _(ctx: Any, /) -> Any: + return getter(ctx.__narwhals_namespace__()) + + return _ + + +class DispatchGetter: + __slots__ = ("_fn", "_name") + _fn: Callable[[Any], Any] + _name: str + + def __call__(self, ctx: Any, /) -> Any: + result = self._fn(ctx) + # this can be `None` iff the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` + if result is not None: + # but the issue I have is `result` exists, but returns `None` when called like `result(node, frame, name)` + return result + raise self._not_implemented_error(ctx) + + @classmethod + def no_dispatch(cls, tp: type[ExprIRT]) -> Self: + tp_name = tp.__name__ + obj = cls.__new__(cls) + obj._name = tp_name + + # NOTE: Temp weirdness until fixing the original issue with signature that this had (but never got triggered) + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: + raise obj._no_dispatch_error(ctx, node, name) + + obj._fn = lambda _ctx: _ + return obj + + @classmethod + def from_expr_ir(cls, tp: type[ExprIRT]) -> Self: + if not tp.__expr_ir_config__.allow_dispatch: + return cls.no_dispatch(tp) + return cls._from_configured_type(tp) + + @classmethod + def from_function(cls, tp: type[FunctionT]) -> Self: + return cls._from_configured_type(tp) + + @classmethod + def _from_configured_type(cls, tp: type[ExprIRT | FunctionT]) -> Self: + name = dispatch_method_name(tp) + getter = attrgetter(name) + origin = tp.__expr_ir_config__.origin + fn = getter if origin == "expr" else _dispatch_via_namespace(getter) + obj = cls.__new__(cls) + obj._fn = fn + obj._name = name + return obj + + def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: + msg = f"`{self._name}` is not yet implemented for {type(ctx).__name__!r}" + return NotImplementedError(msg) + + def _no_dispatch_error(self, ctx: Any, node: ExprIRT, name: str, /) -> TypeError: + msg = ( + f"{self._name!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" + ) + return TypeError(msg) def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: From d70ae4e3a0c91d510f054365ebdd02ab7e418c8b Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 15 Oct 2025 19:30:37 +0000 Subject: [PATCH 02/19] refactor: Move everything for the 97th time this api still needs a LOT of work --- narwhals/_plan/_dispatch.py | 126 ++++++++++++++++++++++ narwhals/_plan/_expr_ir.py | 18 +--- narwhals/_plan/_function.py | 18 +--- narwhals/_plan/common.py | 106 +----------------- narwhals/_plan/expressions/aggregation.py | 2 +- 5 files changed, 138 insertions(+), 132 deletions(-) create mode 100644 narwhals/_plan/_dispatch.py diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py new file mode 100644 index 0000000000..39e38c8abf --- /dev/null +++ b/narwhals/_plan/_dispatch.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import re +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + + from narwhals._plan.expressions import FunctionExpr + from narwhals._plan.typing import ExprIRT, FunctionT + +Incomplete: TypeAlias = "Any" + + +class DispatchGetter: + __slots__ = ("_fn", "_name") + _fn: Callable[[Any], Any] + _name: str + + def __call__(self, ctx: Any, /) -> Any: + result = self._fn(ctx) + # this can be `None` iff the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` + if result is not None: + # but the issue I have is `result` exists, but returns `None` when called like `result(node, frame, name)` + return result + raise self._not_implemented_error(ctx) + + @classmethod + def no_dispatch(cls, tp: type[ExprIRT]) -> Self: + tp_name = tp.__name__ + obj = cls.__new__(cls) + obj._name = tp_name + + # NOTE: Temp weirdness until fixing the original issue with signature that this had (but never got triggered) + def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: + raise obj._no_dispatch_error(ctx, node, name) + + obj._fn = lambda _ctx: _ + return obj + + @classmethod + def from_expr_ir(cls, tp: type[ExprIRT]) -> Self: + if not tp.__expr_ir_config__.allow_dispatch: + return cls.no_dispatch(tp) + return cls._from_configured_type(tp) + + @classmethod + def from_function(cls, tp: type[FunctionT]) -> Self: + return cls._from_configured_type(tp) + + @classmethod + def _from_configured_type(cls, tp: type[ExprIRT | FunctionT]) -> Self: + name = dispatch_method_name(tp) + getter = attrgetter(name) + origin = tp.__expr_ir_config__.origin + fn = getter if origin == "expr" else _dispatch_via_namespace(getter) + obj = cls.__new__(cls) + obj._fn = fn + obj._name = name + return obj + + def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: + msg = f"`{self._name}` is not yet implemented for {type(ctx).__name__!r}" + return NotImplementedError(msg) + + def _no_dispatch_error(self, ctx: Any, node: ExprIRT, name: str, /) -> TypeError: + msg = ( + f"{self._name!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" + ) + return TypeError(msg) + + +def _dispatch_via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: + def _(ctx: Any, /) -> Any: + return getter(ctx.__narwhals_namespace__()) + + return _ + + +def dispatch_generate( + tp: type[ExprIRT], / +) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: + getter = DispatchGetter.from_expr_ir(tp) + + def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + + +def dispatch_generate_function( + tp: type[FunctionT], / +) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: + getter = DispatchGetter.from_function(tp) + + def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: + return getter(ctx)(node, frame, name) + + return _ + + +def pascal_to_snake_case(s: str) -> str: + """Convert a PascalCase, camelCase string to snake_case. + + Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) + # Insert an underscore between a lowercase letter and an uppercase letter + return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() + + +_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") +_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") + + +def _re_repl_snake(match: re.Match[str], /) -> str: + return f"{match.group(1)}_{match.group(2)}" + + +def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: + config = tp.__expr_ir_config__ + name = config.override_name or pascal_to_snake_case(tp.__name__) + return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index b6bd412acf..0b37456e74 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -2,15 +2,16 @@ from typing import TYPE_CHECKING, Generic, cast +from narwhals._plan._dispatch import dispatch_generate from narwhals._plan._guards import is_function_expr, is_literal from narwhals._plan._immutable import Immutable -from narwhals._plan.common import DispatchGetter, replace +from narwhals._plan.common import replace from narwhals._plan.options import ExprIROptions from narwhals._plan.typing import ExprIRT from narwhals.utils import Version if TYPE_CHECKING: - from collections.abc import Callable, Iterator + from collections.abc import Iterator from typing import Any, ClassVar from typing_extensions import Self, TypeAlias @@ -25,17 +26,6 @@ Incomplete: TypeAlias = "Any" -def _dispatch_generate( - tp: type[ExprIRT], / -) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - getter = DispatchGetter.from_expr_ir(tp) - - def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" @@ -59,7 +49,7 @@ def __init_subclass__( cls._child = child if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls)) + cls.__expr_ir_dispatch__ = staticmethod(dispatch_generate(cls)) def dispatch( self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index f87c5251d1..865a487529 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -2,8 +2,9 @@ from typing import TYPE_CHECKING +from narwhals._plan._dispatch import dispatch_generate_function, dispatch_method_name from narwhals._plan._immutable import Immutable -from narwhals._plan.common import DispatchGetter, dispatch_method_name, replace +from narwhals._plan.common import replace from narwhals._plan.options import FEOptions, FunctionOptions if TYPE_CHECKING: @@ -12,24 +13,13 @@ from typing_extensions import Self, TypeAlias from narwhals._plan.expressions import ExprIR, FunctionExpr - from narwhals._plan.typing import Accessor, FunctionT + from narwhals._plan.typing import Accessor __all__ = ["Function", "HorizontalFunction"] Incomplete: TypeAlias = "Any" -def _dispatch_generate_function( - tp: type[FunctionT], / -) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = DispatchGetter.from_function(tp) - - def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - class Function(Immutable): """Shared by expr functions and namespace functions. @@ -72,7 +62,7 @@ def __init_subclass__( cls._function_options = staticmethod(options) if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls)) + cls.__expr_ir_dispatch__ = staticmethod(dispatch_generate_function(cls)) def __repr__(self) -> str: return dispatch_method_name(type(self)) diff --git a/narwhals/_plan/common.py b/narwhals/_plan/common.py index fff55f10ce..8ac2084034 100644 --- a/narwhals/_plan/common.py +++ b/narwhals/_plan/common.py @@ -1,11 +1,9 @@ from __future__ import annotations import datetime as dt -import re import sys from collections.abc import Iterable from decimal import Decimal -from operator import attrgetter from secrets import token_hex from typing import TYPE_CHECKING, cast, overload @@ -18,20 +16,13 @@ if TYPE_CHECKING: import reprlib from collections.abc import Iterator - from typing import Any, Callable, ClassVar, TypeVar + from typing import Any, ClassVar, TypeVar - from typing_extensions import Self, TypeIs + from typing_extensions import TypeIs from narwhals._plan.compliant.series import CompliantSeries from narwhals._plan.series import Series - from narwhals._plan.typing import ( - DTypeT, - ExprIRT, - FunctionT, - NonNestedDTypeT, - OneOrIterable, - Seq, - ) + from narwhals._plan.typing import DTypeT, NonNestedDTypeT, OneOrIterable, Seq from narwhals._utils import _StoresColumns from narwhals.typing import NonNestedDType, NonNestedLiteral @@ -51,97 +42,6 @@ def replace(obj: T, /, **changes: Any) -> T: return func(obj, **changes) # type: ignore[no-any-return] -def pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. - - Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 - """ - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) - # Insert an underscore between a lowercase letter and an uppercase letter - return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() - - -_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") -_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") - - -def _re_repl_snake(match: re.Match[str], /) -> str: - return f"{match.group(1)}_{match.group(2)}" - - -def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: - config = tp.__expr_ir_config__ - name = config.override_name or pascal_to_snake_case(tp.__name__) - return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name - - -def _dispatch_via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: - def _(ctx: Any, /) -> Any: - return getter(ctx.__narwhals_namespace__()) - - return _ - - -class DispatchGetter: - __slots__ = ("_fn", "_name") - _fn: Callable[[Any], Any] - _name: str - - def __call__(self, ctx: Any, /) -> Any: - result = self._fn(ctx) - # this can be `None` iff the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` - if result is not None: - # but the issue I have is `result` exists, but returns `None` when called like `result(node, frame, name)` - return result - raise self._not_implemented_error(ctx) - - @classmethod - def no_dispatch(cls, tp: type[ExprIRT]) -> Self: - tp_name = tp.__name__ - obj = cls.__new__(cls) - obj._name = tp_name - - # NOTE: Temp weirdness until fixing the original issue with signature that this had (but never got triggered) - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - raise obj._no_dispatch_error(ctx, node, name) - - obj._fn = lambda _ctx: _ - return obj - - @classmethod - def from_expr_ir(cls, tp: type[ExprIRT]) -> Self: - if not tp.__expr_ir_config__.allow_dispatch: - return cls.no_dispatch(tp) - return cls._from_configured_type(tp) - - @classmethod - def from_function(cls, tp: type[FunctionT]) -> Self: - return cls._from_configured_type(tp) - - @classmethod - def _from_configured_type(cls, tp: type[ExprIRT | FunctionT]) -> Self: - name = dispatch_method_name(tp) - getter = attrgetter(name) - origin = tp.__expr_ir_config__.origin - fn = getter if origin == "expr" else _dispatch_via_namespace(getter) - obj = cls.__new__(cls) - obj._fn = fn - obj._name = name - return obj - - def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: - msg = f"`{self._name}` is not yet implemented for {type(ctx).__name__!r}" - return NotImplementedError(msg) - - def _no_dispatch_error(self, ctx: Any, node: ExprIRT, name: str, /) -> TypeError: - msg = ( - f"{self._name!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" - ) - return TypeError(msg) - - def py_to_narwhals_dtype(obj: NonNestedLiteral, version: Version = Version.MAIN) -> DType: dtypes = version.dtypes mapping: dict[type[NonNestedLiteral], type[NonNestedDType]] = { diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index 3b9ae41d11..cb4b6c512d 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING, Any +from narwhals._plan._dispatch import pascal_to_snake_case from narwhals._plan._expr_ir import ExprIR -from narwhals._plan.common import pascal_to_snake_case from narwhals._plan.exceptions import agg_scalar_error if TYPE_CHECKING: From 23cbff03ce210525570602bca233fe4961242c2d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 15 Oct 2025 19:42:22 +0000 Subject: [PATCH 03/19] fix: update import --- narwhals/_plan/arrow/group_by.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index b519261c53..fb60ce5dca 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,9 +6,10 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir +from narwhals._plan._dispatch import dispatch_method_name from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options -from narwhals._plan.common import dispatch_method_name, temp +from narwhals._plan.common import temp from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy from narwhals._plan.expressions import aggregation as agg from narwhals._utils import Implementation From 4b08a6e3c07f52cd7e80988ac140b5ac67437e35 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:04:41 +0000 Subject: [PATCH 04/19] rewrite as `Dispatcher` much closer now, need to move to metaclass soon --- .pre-commit-config.yaml | 2 + narwhals/_plan/_dispatch.py | 134 +++++++++++++++++------------ narwhals/_plan/_expr_ir.py | 14 ++- narwhals/_plan/_function.py | 10 +-- narwhals/_plan/arrow/group_by.py | 8 +- narwhals/_plan/compliant/expr.py | 3 + narwhals/_plan/compliant/scalar.py | 6 ++ narwhals/_plan/expressions/expr.py | 8 +- tests/plan/compliant_test.py | 3 +- tests/plan/dispatch_test.py | 75 ++++++++++++++++ 10 files changed, 181 insertions(+), 82 deletions(-) create mode 100644 tests/plan/dispatch_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d342428700..961e48e892 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,6 +79,8 @@ repos: entry: "self: Self" language: pygrep files: ^narwhals/ + # mypy needs `Self` for `ExprIR.dispatch` + exclude: ^narwhals/_plan/.*\.py - id: dtypes-import name: don't import from narwhals.dtypes (use `Version.dtypes` instead) entry: | diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 39e38c8abf..99095876da 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -2,74 +2,109 @@ import re from operator import attrgetter -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Generic, final, overload + +from narwhals._plan._guards import is_function_expr +from narwhals._typing_compat import TypeVar if TYPE_CHECKING: - from typing_extensions import Self, TypeAlias + from typing_extensions import Never, TypeAlias - from narwhals._plan.expressions import FunctionExpr + from narwhals._plan.expressions import ExprIR, FunctionExpr from narwhals._plan.typing import ExprIRT, FunctionT +__all__ = ["Dispatcher", "get_dispatch_name", "pascal_to_snake_case"] + Incomplete: TypeAlias = "Any" +Node = TypeVar("Node") -class DispatchGetter: - __slots__ = ("_fn", "_name") - _fn: Callable[[Any], Any] + +@final +class Dispatcher(Generic[Node]): + __slots__ = ("_method_getter", "_name") + _method_getter: Callable[[Any], Any] _name: str - def __call__(self, ctx: Any, /) -> Any: - result = self._fn(ctx) - # this can be `None` iff the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` - if result is not None: - # but the issue I have is `result` exists, but returns `None` when called like `result(node, frame, name)` + @property + def name(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"{type(self).__name__}<{self.name}>" + + def __call__( + self, ctx: Incomplete, node: Node, frame: Any, name: str, / + ) -> Incomplete: + # raises when the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` + # gives a more helpful error for things that are namespaced like `col("a").str.replace` + try: + bound_method = self._method_getter(ctx) + except AttributeError: + raise self._not_implemented_error(ctx) from None + + if result := bound_method(node, frame, name): return result + # here if is defined on `CompliantExpr`, but not on ctx raise self._not_implemented_error(ctx) @classmethod - def no_dispatch(cls, tp: type[ExprIRT]) -> Self: + def no_dispatch(cls: type[Dispatcher[Any]], tp: type[ExprIRT]) -> Dispatcher[ExprIRT]: tp_name = tp.__name__ obj = cls.__new__(cls) obj._name = tp_name # NOTE: Temp weirdness until fixing the original issue with signature that this had (but never got triggered) - def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any: - raise obj._no_dispatch_error(ctx, node, name) - - obj._fn = lambda _ctx: _ + obj._method_getter = lambda _ctx: obj._no_dispatch_error return obj @classmethod - def from_expr_ir(cls, tp: type[ExprIRT]) -> Self: + def from_expr_ir( + cls: type[Dispatcher[Any]], tp: type[ExprIRT] + ) -> Dispatcher[ExprIRT]: if not tp.__expr_ir_config__.allow_dispatch: return cls.no_dispatch(tp) - return cls._from_configured_type(tp) + return Dispatcher._from_configured_type(tp) @classmethod - def from_function(cls, tp: type[FunctionT]) -> Self: - return cls._from_configured_type(tp) - - @classmethod - def _from_configured_type(cls, tp: type[ExprIRT | FunctionT]) -> Self: - name = dispatch_method_name(tp) + def from_function( + cls: type[Dispatcher[Any]], tp: type[FunctionT] + ) -> Dispatcher[FunctionExpr[FunctionT]]: + return Dispatcher._from_configured_type(tp) + + @staticmethod + @overload + def _from_configured_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ... + + @staticmethod + @overload + def _from_configured_type( + tp: type[FunctionT], / + ) -> Dispatcher[FunctionExpr[FunctionT]]: ... + + @staticmethod + def _from_configured_type( + tp: type[ExprIRT | FunctionT], / + ) -> Dispatcher[ExprIRT] | Dispatcher[FunctionExpr[FunctionT]]: + name = _dispatch_method_name(tp) getter = attrgetter(name) origin = tp.__expr_ir_config__.origin fn = getter if origin == "expr" else _dispatch_via_namespace(getter) - obj = cls.__new__(cls) - obj._fn = fn + obj = Dispatcher.__new__(Dispatcher) + obj._method_getter = fn obj._name = name return obj - def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: - msg = f"`{self._name}` is not yet implemented for {type(ctx).__name__!r}" - return NotImplementedError(msg) - - def _no_dispatch_error(self, ctx: Any, node: ExprIRT, name: str, /) -> TypeError: + def _no_dispatch_error(self, node: Node, frame: Any, name: str, /) -> Never: msg = ( - f"{self._name!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}" + f"{self.name!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{node!r}" ) - return TypeError(msg) + raise TypeError(msg) + + def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: + msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" + return NotImplementedError(msg) def _dispatch_via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: @@ -79,28 +114,6 @@ def _(ctx: Any, /) -> Any: return _ -def dispatch_generate( - tp: type[ExprIRT], / -) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]: - getter = DispatchGetter.from_expr_ir(tp) - - def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - -def dispatch_generate_function( - tp: type[FunctionT], / -) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]: - getter = DispatchGetter.from_function(tp) - - def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any: - return getter(ctx)(node, frame, name) - - return _ - - def pascal_to_snake_case(s: str) -> str: """Convert a PascalCase, camelCase string to snake_case. @@ -120,7 +133,14 @@ def _re_repl_snake(match: re.Match[str], /) -> str: return f"{match.group(1)}_{match.group(2)}" -def dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: +def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: config = tp.__expr_ir_config__ name = config.override_name or pascal_to_snake_case(tp.__name__) return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name + + +def get_dispatch_name(expr: ExprIR, /) -> str: + """Return the synthesized method name for `expr`.""" + return ( + repr(expr.function) if is_function_expr(expr) else expr.__expr_ir_dispatch__.name + ) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 0b37456e74..4f14f7f988 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, cast +from typing import TYPE_CHECKING, Generic -from narwhals._plan._dispatch import dispatch_generate +from narwhals._plan._dispatch import Dispatcher from narwhals._plan._guards import is_function_expr, is_literal from narwhals._plan._immutable import Immutable from narwhals._plan.common import replace @@ -33,9 +33,7 @@ class ExprIR(Immutable): """Nested node names, in iteration order.""" __expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, Self, Incomplete, str], Incomplete] - ] + __expr_ir_dispatch__: ClassVar[Dispatcher[Self]] def __init_subclass__( cls: type[Self], @@ -49,13 +47,13 @@ def __init_subclass__( cls._child = child if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(dispatch_generate(cls)) + cls.__expr_ir_dispatch__ = Dispatcher.from_expr_ir(cls) def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / ) -> R_co: """Evaluate expression in `frame`, using `ctx` for implementation(s).""" - return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.__expr_ir_dispatch__(ctx, self, frame, name) # type: ignore[no-any-return] def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import expr diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index 865a487529..d319f68493 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from narwhals._plan._dispatch import dispatch_generate_function, dispatch_method_name +from narwhals._plan._dispatch import Dispatcher from narwhals._plan._immutable import Immutable from narwhals._plan.common import replace from narwhals._plan.options import FEOptions, FunctionOptions @@ -30,9 +30,7 @@ class Function(Immutable): FunctionOptions.default ) __expr_ir_config__: ClassVar[FEOptions] = FEOptions.default() - __expr_ir_dispatch__: ClassVar[ - staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete] - ] + __expr_ir_dispatch__: ClassVar[Dispatcher[FunctionExpr[Self]]] @property def function_options(self) -> FunctionOptions: @@ -62,10 +60,10 @@ def __init_subclass__( cls._function_options = staticmethod(options) if config: cls.__expr_ir_config__ = config - cls.__expr_ir_dispatch__ = staticmethod(dispatch_generate_function(cls)) + cls.__expr_ir_dispatch__ = Dispatcher.from_function(cls) def __repr__(self) -> str: - return dispatch_method_name(type(self)) + return self.__expr_ir_dispatch__.name class HorizontalFunction( diff --git a/narwhals/_plan/arrow/group_by.py b/narwhals/_plan/arrow/group_by.py index fb60ce5dca..df57f781c1 100644 --- a/narwhals/_plan/arrow/group_by.py +++ b/narwhals/_plan/arrow/group_by.py @@ -6,7 +6,7 @@ import pyarrow.compute as pc # ignore-banned-import from narwhals._plan import expressions as ir -from narwhals._plan._dispatch import dispatch_method_name +from narwhals._plan._dispatch import get_dispatch_name from narwhals._plan._guards import is_agg_expr, is_function_expr from narwhals._plan.arrow import acero, functions as fn, options from narwhals._plan.common import temp @@ -133,11 +133,7 @@ def group_by_error( if reason == "too complex": msg = "Non-trivial complex aggregation found, which" else: - if is_function_expr(expr): - func_name = repr(expr.function) - else: - func_name = dispatch_method_name(type(expr)) - msg = f"`{func_name}()`" + msg = f"`{get_dispatch_name(expr)}()`" msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}" return InvalidOperationError(msg) diff --git a/narwhals/_plan/compliant/expr.py b/narwhals/_plan/compliant/expr.py index 5defde7d61..e8d0dffbed 100644 --- a/narwhals/_plan/compliant/expr.py +++ b/narwhals/_plan/compliant/expr.py @@ -54,6 +54,9 @@ def name(self) -> str: ... def abs(self, node: FunctionExpr[F.Abs], frame: FrameT_contra, name: str) -> Self: ... def binary_expr(self, node: BinaryExpr, frame: FrameT_contra, name: str) -> Self: ... def cast(self, node: ir.Cast, frame: FrameT_contra, name: str) -> Self: ... + def ewm_mean( + self, node: FunctionExpr[F.EwmMean], frame: FrameT_contra, name: str + ) -> Self: ... def fill_null( self, node: FunctionExpr[F.FillNull], frame: FrameT_contra, name: str ) -> Self: ... diff --git a/narwhals/_plan/compliant/scalar.py b/narwhals/_plan/compliant/scalar.py index abb873aa5e..25c07d7de7 100644 --- a/narwhals/_plan/compliant/scalar.py +++ b/narwhals/_plan/compliant/scalar.py @@ -11,6 +11,7 @@ from narwhals._plan import expressions as ir from narwhals._plan.expressions import FunctionExpr, aggregation as agg from narwhals._plan.expressions.boolean import IsFirstDistinct, IsLastDistinct + from narwhals._plan.expressions.functions import EwmMean from narwhals._utils import Version from narwhals.typing import IntoDType, PythonLiteral @@ -58,6 +59,11 @@ def count(self, node: agg.Count, frame: FrameT_contra, name: str) -> Self: """Returns 0 if null, else 1.""" ... + def ewm_mean( + self, node: FunctionExpr[EwmMean], frame: FrameT_contra, name: str + ) -> Self: + return self._cast_float(node.input[0], frame, name) + def first(self, node: agg.First, frame: FrameT_contra, name: str) -> Self: return self._with_evaluated(self._evaluated, name) diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 42d00bdbd7..a020b48ac3 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -314,9 +314,9 @@ def __init__( super().__init__(**dict(input=input, function=function, options=options, **kwds)) def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.function.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.function.__expr_ir_dispatch__(ctx, self, frame, name) # type: ignore[no-any-return] class RollingExpr(FunctionExpr[RollingT_co]): ... @@ -328,9 +328,9 @@ class AnonymousExpr( """https://github.com/pola-rs/polars/blob/dafd0a2d0e32b52bcfa4273bffdd6071a0d5977a/crates/polars-plan/src/dsl/expr.rs#L158-L166.""" def dispatch( - self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str + self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.__expr_ir_dispatch__(ctx, t.cast("Self", self), frame, name) # type: ignore[no-any-return] + return self.__expr_ir_dispatch__(ctx, self, frame, name) # type: ignore[no-any-return] class RangeExpr(FunctionExpr[RangeT_co]): diff --git a/tests/plan/compliant_test.py b/tests/plan/compliant_test.py index 2d63cb99f2..e011e5c547 100644 --- a/tests/plan/compliant_test.py +++ b/tests/plan/compliant_test.py @@ -549,7 +549,8 @@ def test_protocol_expr() -> None: from narwhals._plan.arrow.expr import ArrowExpr, ArrowScalar from narwhals._plan.arrow.series import ArrowSeries - expr = ArrowExpr() + # NOTE: Intentionally leaving `ewm_mean` without a `not_implemented()` for another test + expr = ArrowExpr() # type: ignore[abstract] scalar = ArrowScalar() df = ArrowDataFrame() ser = ArrowSeries() diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py new file mode 100644 index 0000000000..e28be6e10c --- /dev/null +++ b/tests/plan/dispatch_test.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Any + +import pytest + +pytest.importorskip("pyarrow") +import narwhals as nw +from narwhals import _plan as nwp +from narwhals._plan._dispatch import get_dispatch_name +from tests.plan.utils import assert_equal_data, dataframe, named_ir + +if TYPE_CHECKING: + import pyarrow as pa + + from narwhals._plan.dataframe import DataFrame + + +@pytest.fixture +def data() -> dict[str, Any]: + return { + "a": [12.1, None, 4.0], + "b": [42, 10, None], + "c": [4, 5, 6], + "d": ["play", "swim", "walk"], + } + + +@pytest.fixture +def df(data: dict[str, Any]) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: + return dataframe(data) + + +def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: + implemented_full = nwp.col("a").is_null() + only_at_compliant_level = nwp.col("c").ewm_mean() + only_at_narwhals_level = nwp.col("d").str.contains("a") + forgot_to_expand = (named_ir("howdy", nwp.nth(3, 4).first()),) + pattern_expand = re.compile( + r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first", + re.DOTALL | re.IGNORECASE, + ) + + assert_equal_data(df.select(implemented_full), {"a": [False, True, False]}) + + with pytest.raises(NotImplementedError, match=r"ewm_mean"): + df.select(only_at_compliant_level) + + with pytest.raises(NotImplementedError, match=r"str\.contains"): + df.select(only_at_narwhals_level) + + with pytest.raises(TypeError, match=pattern_expand): + df._compliant.select(forgot_to_expand) + + # Not a narwhals method, to make sure this doesn't allow arbitrary calls + with pytest.raises(AttributeError): + nwp.col("a").max().to_physical() # type: ignore[attr-defined] + + +@pytest.mark.parametrize( + ("expr", "expected"), + [ + (nwp.col("a"), "col"), + (nwp.col("a").min().over("b"), "over"), + (nwp.col("a").first().over(order_by="b"), "over_ordered"), + (nwp.all_horizontal("a", "b", nwp.nth(4, 5, 6)), "all_horizontal"), + (nwp.int_range(10), "int_range"), + (nwp.col("a") + nwp.col("b") + 10, "binary_expr"), + (nwp.when(nwp.col("c")).then(5).when(nwp.col("d")).then(20), "ternary_expr"), + (nwp.col("a").cast(nw.String).str.starts_with("something"), ("str.starts_with")), + ], +) +def test_dispatch_name(expr: nwp.Expr, expected: str) -> None: + assert get_dispatch_name(expr._ir) == expected From 02d7e41b555c77da78e2bd9f812be0c6fbad9df8 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 16 Oct 2025 17:19:55 +0000 Subject: [PATCH 05/19] refactor: un-expose `_pascal_to_snake_case` No need for re-computing when the name is stored on the class --- narwhals/_plan/_dispatch.py | 6 +++--- narwhals/_plan/expressions/aggregation.py | 3 +-- tests/plan/dispatch_test.py | 4 ++++ 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 99095876da..bd73d9f316 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -13,7 +13,7 @@ from narwhals._plan.expressions import ExprIR, FunctionExpr from narwhals._plan.typing import ExprIRT, FunctionT -__all__ = ["Dispatcher", "get_dispatch_name", "pascal_to_snake_case"] +__all__ = ["Dispatcher", "get_dispatch_name"] Incomplete: TypeAlias = "Any" @@ -114,7 +114,7 @@ def _(ctx: Any, /) -> Any: return _ -def pascal_to_snake_case(s: str) -> str: +def _pascal_to_snake_case(s: str) -> str: """Convert a PascalCase, camelCase string to snake_case. Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 @@ -135,7 +135,7 @@ def _re_repl_snake(match: re.Match[str], /) -> str: def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: config = tp.__expr_ir_config__ - name = config.override_name or pascal_to_snake_case(tp.__name__) + name = config.override_name or _pascal_to_snake_case(tp.__name__) return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name diff --git a/narwhals/_plan/expressions/aggregation.py b/narwhals/_plan/expressions/aggregation.py index cb4b6c512d..92a563f586 100644 --- a/narwhals/_plan/expressions/aggregation.py +++ b/narwhals/_plan/expressions/aggregation.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING, Any -from narwhals._plan._dispatch import pascal_to_snake_case from narwhals._plan._expr_ir import ExprIR from narwhals._plan.exceptions import agg_scalar_error @@ -21,7 +20,7 @@ def is_scalar(self) -> bool: return True def __repr__(self) -> str: - return f"{self.expr!r}.{pascal_to_snake_case(type(self).__name__)}()" + return f"{self.expr!r}.{self.__expr_ir_dispatch__.name}()" def iter_output_name(self) -> Iterator[ExprIR]: yield from self.expr.iter_output_name() diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index e28be6e10c..631676692c 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -69,6 +69,10 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: (nwp.col("a") + nwp.col("b") + 10, "binary_expr"), (nwp.when(nwp.col("c")).then(5).when(nwp.col("d")).then(20), "ternary_expr"), (nwp.col("a").cast(nw.String).str.starts_with("something"), ("str.starts_with")), + (nwp.mean("a"), "mean"), + (nwp.nth(1).first(), "first"), + (nwp.col("a").sum(), "sum"), + (nwp.col("a").drop_nulls().arg_min(), "arg_min"), ], ) def test_dispatch_name(expr: nwp.Expr, expected: str) -> None: From 40504f019b60bd67f82d20a0e563943b46b1eadf Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:16:03 +0000 Subject: [PATCH 06/19] less-bad no_dispatch --- narwhals/_plan/_dispatch.py | 45 +++++++++++++++++++++---------------- tests/plan/dispatch_test.py | 14 ++++++++++++ 2 files changed, 40 insertions(+), 19 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index bd73d9f316..1e6544fa34 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -20,10 +20,14 @@ Node = TypeVar("Node") +Getter: TypeAlias = Callable[[Any], Any] +Raiser: TypeAlias = Callable[..., "Never"] + + @final class Dispatcher(Generic[Node]): __slots__ = ("_method_getter", "_name") - _method_getter: Callable[[Any], Any] + _method_getter: Getter _name: str @property @@ -48,22 +52,12 @@ def __call__( # here if is defined on `CompliantExpr`, but not on ctx raise self._not_implemented_error(ctx) - @classmethod - def no_dispatch(cls: type[Dispatcher[Any]], tp: type[ExprIRT]) -> Dispatcher[ExprIRT]: - tp_name = tp.__name__ - obj = cls.__new__(cls) - obj._name = tp_name - - # NOTE: Temp weirdness until fixing the original issue with signature that this had (but never got triggered) - obj._method_getter = lambda _ctx: obj._no_dispatch_error - return obj - @classmethod def from_expr_ir( cls: type[Dispatcher[Any]], tp: type[ExprIRT] ) -> Dispatcher[ExprIRT]: if not tp.__expr_ir_config__.allow_dispatch: - return cls.no_dispatch(tp) + return cls._no_dispatch(tp) return Dispatcher._from_configured_type(tp) @classmethod @@ -95,19 +89,32 @@ def _from_configured_type( obj._name = name return obj - def _no_dispatch_error(self, node: Node, frame: Any, name: str, /) -> Never: - msg = ( - f"{self.name!r} should not appear at the compliant-level.\n\n" - f"Make sure to expand all expressions first, got:\n{node!r}" - ) - raise TypeError(msg) + @staticmethod + def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: + obj = Dispatcher.__new__(Dispatcher) + obj._name = tp.__name__ + obj._method_getter = obj._make_no_dispatch_error() + return obj + + def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]: + def _no_dispatch_error(node: Node, *_: Any) -> Never: + msg = ( + f"{self.name!r} should not appear at the compliant-level.\n\n" + f"Make sure to expand all expressions first, got:\n{node!r}" + ) + raise TypeError(msg) + + def getter(_: Any, /) -> Raiser: + return _no_dispatch_error + + return getter def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" return NotImplementedError(msg) -def _dispatch_via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: +def _dispatch_via_namespace(getter: Getter, /) -> Getter: def _(ctx: Any, /) -> Any: return getter(ctx.__narwhals_namespace__()) diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index 631676692c..4d3fcb2c79 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -8,6 +8,7 @@ pytest.importorskip("pyarrow") import narwhals as nw from narwhals import _plan as nwp +from narwhals._plan import expressions as ir, selectors as ncs from narwhals._plan._dispatch import get_dispatch_name from tests.plan.utils import assert_equal_data, dataframe, named_ir @@ -37,10 +38,18 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: only_at_compliant_level = nwp.col("c").ewm_mean() only_at_narwhals_level = nwp.col("d").str.contains("a") forgot_to_expand = (named_ir("howdy", nwp.nth(3, 4).first()),) + aliased_after_expand: tuple[ir.NamedIR[Any]] = ( + ir.NamedIR.from_ir(ir.col("a").alias("b")), + ) + pattern_expand = re.compile( r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first", re.DOTALL | re.IGNORECASE, ) + bad = re.escape("col('a').alias('b')") + pattern_aliased_after_expand = re.compile( + rf"Alias.+not.+appear.+got.+{bad}", re.DOTALL | re.IGNORECASE + ) assert_equal_data(df.select(implemented_full), {"a": [False, True, False]}) @@ -53,6 +62,9 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: with pytest.raises(TypeError, match=pattern_expand): df._compliant.select(forgot_to_expand) + with pytest.raises(TypeError, match=pattern_aliased_after_expand): + df._compliant.select(aliased_after_expand) + # Not a narwhals method, to make sure this doesn't allow arbitrary calls with pytest.raises(AttributeError): nwp.col("a").max().to_physical() # type: ignore[attr-defined] @@ -73,6 +85,8 @@ def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: (nwp.nth(1).first(), "first"), (nwp.col("a").sum(), "sum"), (nwp.col("a").drop_nulls().arg_min(), "arg_min"), + pytest.param(nwp.col("a").alias("b"), "Alias", id="no_dispatch-Alias"), + pytest.param(ncs.string(), "RootSelector", id="no_dispatch-RootSelector"), ], ) def test_dispatch_name(expr: nwp.Expr, expected: str) -> None: From c07972b49576f69451770bb832c40d2d4e4056a0 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:37:08 +0000 Subject: [PATCH 07/19] tweak tweak tweak --- narwhals/_plan/_dispatch.py | 29 ++++++++++++----------------- narwhals/_plan/options.py | 4 ++++ 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 1e6544fa34..a636d42016 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -52,18 +52,14 @@ def __call__( # here if is defined on `CompliantExpr`, but not on ctx raise self._not_implemented_error(ctx) - @classmethod - def from_expr_ir( - cls: type[Dispatcher[Any]], tp: type[ExprIRT] - ) -> Dispatcher[ExprIRT]: + @staticmethod + def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: if not tp.__expr_ir_config__.allow_dispatch: - return cls._no_dispatch(tp) + return Dispatcher._no_dispatch(tp) return Dispatcher._from_configured_type(tp) - @classmethod - def from_function( - cls: type[Dispatcher[Any]], tp: type[FunctionT] - ) -> Dispatcher[FunctionExpr[FunctionT]]: + @staticmethod + def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: return Dispatcher._from_configured_type(tp) @staticmethod @@ -76,17 +72,16 @@ def _from_configured_type( tp: type[FunctionT], / ) -> Dispatcher[FunctionExpr[FunctionT]]: ... + # TODO @dangotbanned: Can this be done without overloads? @staticmethod def _from_configured_type( tp: type[ExprIRT | FunctionT], / ) -> Dispatcher[ExprIRT] | Dispatcher[FunctionExpr[FunctionT]]: - name = _dispatch_method_name(tp) - getter = attrgetter(name) - origin = tp.__expr_ir_config__.origin - fn = getter if origin == "expr" else _dispatch_via_namespace(getter) obj = Dispatcher.__new__(Dispatcher) - obj._method_getter = fn - obj._name = name + obj._name = _method_name(tp) + getter = attrgetter(obj._name) + is_namespaced = tp.__expr_ir_config__.is_namespaced + obj._method_getter = _via_namespace(getter) if is_namespaced else getter return obj @staticmethod @@ -114,7 +109,7 @@ def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: return NotImplementedError(msg) -def _dispatch_via_namespace(getter: Getter, /) -> Getter: +def _via_namespace(getter: Getter, /) -> Getter: def _(ctx: Any, /) -> Any: return getter(ctx.__narwhals_namespace__()) @@ -140,7 +135,7 @@ def _re_repl_snake(match: re.Match[str], /) -> str: return f"{match.group(1)}_{match.group(2)}" -def _dispatch_method_name(tp: type[ExprIRT | FunctionT]) -> str: +def _method_name(tp: type[ExprIRT | FunctionT]) -> str: config = tp.__expr_ir_config__ name = config.override_name or _pascal_to_snake_case(tp.__name__) return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index f8ee347cac..0b8692e92c 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -310,6 +310,10 @@ def namespaced(cls, override_name: str = "", /) -> Self: cls.default(), origin="__narwhals_namespace__", override_name=override_name ) + @property + def is_namespaced(self) -> bool: + return self.origin == "__narwhals_namespace__" + class ExprIROptions(_BaseIROptions): __slots__ = ("allow_dispatch",) From b26e833089f6dfb4b83a15d75eee81d891fa764a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 16 Oct 2025 18:44:39 +0000 Subject: [PATCH 08/19] refactor(typing): Less no-any-return --- narwhals/_plan/_dispatch.py | 12 +++++++++--- narwhals/_plan/_expr_ir.py | 2 +- narwhals/_plan/expressions/expr.py | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index a636d42016..e2ff1b1ccc 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from typing_extensions import Never, TypeAlias + from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expressions import ExprIR, FunctionExpr from narwhals._plan.typing import ExprIRT, FunctionT @@ -38,8 +39,13 @@ def __repr__(self) -> str: return f"{type(self).__name__}<{self.name}>" def __call__( - self, ctx: Incomplete, node: Node, frame: Any, name: str, / - ) -> Incomplete: + self, + ctx: Ctx[FrameT_contra, R_co], + node: Node, + frame: FrameT_contra, + name: str, + /, + ) -> R_co: # raises when the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` # gives a more helpful error for things that are namespaced like `col("a").str.replace` try: @@ -48,7 +54,7 @@ def __call__( raise self._not_implemented_error(ctx) from None if result := bound_method(node, frame, name): - return result + return result # type: ignore[no-any-return] # here if is defined on `CompliantExpr`, but not on ctx raise self._not_implemented_error(ctx) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 4f14f7f988..40ba704e13 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -53,7 +53,7 @@ def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / ) -> R_co: """Evaluate expression in `frame`, using `ctx` for implementation(s).""" - return self.__expr_ir_dispatch__(ctx, self, frame, name) # type: ignore[no-any-return] + return self.__expr_ir_dispatch__(ctx, self, frame, name) def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import expr diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index a020b48ac3..82736e1092 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -316,7 +316,7 @@ def __init__( def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.function.__expr_ir_dispatch__(ctx, self, frame, name) # type: ignore[no-any-return] + return self.function.__expr_ir_dispatch__(ctx, self, frame, name) class RollingExpr(FunctionExpr[RollingT_co]): ... @@ -330,7 +330,7 @@ class AnonymousExpr( def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.__expr_ir_dispatch__(ctx, self, frame, name) # type: ignore[no-any-return] + return self.__expr_ir_dispatch__(ctx, self, frame, name) class RangeExpr(FunctionExpr[RangeT_co]): From cc73b4a6c11a57a845e5c4730d9833b85e6efd62 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 19:07:43 +0000 Subject: [PATCH 09/19] refactor: `origin` -> `is_namespaced` --- narwhals/_plan/options.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/narwhals/_plan/options.py b/narwhals/_plan/options.py index 0b8692e92c..739654e4bf 100644 --- a/narwhals/_plan/options.py +++ b/narwhals/_plan/options.py @@ -2,7 +2,7 @@ import enum from itertools import repeat -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from narwhals._plan._immutable import Immutable @@ -11,14 +11,12 @@ import pyarrow.acero import pyarrow.compute as pc - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.arrow.typing import NullPlacement from narwhals._plan.typing import Accessor, OneOrIterable, Order, Seq from narwhals.typing import RankMethod -DispatchOrigin: TypeAlias = Literal["expr", "__narwhals_namespace__"] - class FunctionFlags(enum.Flag): ALLOW_GROUP_AWARE = 1 << 0 @@ -285,8 +283,8 @@ def rolling_options( class _BaseIROptions(Immutable): - __slots__ = ("origin", "override_name") - origin: DispatchOrigin + __slots__ = ("is_namespaced", "override_name") + is_namespaced: bool override_name: str def __repr__(self) -> str: @@ -294,7 +292,7 @@ def __repr__(self) -> str: @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="") + return cls(is_namespaced=False, override_name="") @classmethod def renamed(cls, name: str, /) -> Self: @@ -306,13 +304,7 @@ def renamed(cls, name: str, /) -> Self: def namespaced(cls, override_name: str = "", /) -> Self: from narwhals._plan.common import replace - return replace( - cls.default(), origin="__narwhals_namespace__", override_name=override_name - ) - - @property - def is_namespaced(self) -> bool: - return self.origin == "__narwhals_namespace__" + return replace(cls.default(), is_namespaced=True, override_name=override_name) class ExprIROptions(_BaseIROptions): @@ -321,11 +313,11 @@ class ExprIROptions(_BaseIROptions): @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="", allow_dispatch=True) + return cls(is_namespaced=False, override_name="", allow_dispatch=True) @staticmethod def no_dispatch() -> ExprIROptions: - return ExprIROptions(origin="expr", override_name="", allow_dispatch=False) + return ExprIROptions(is_namespaced=False, override_name="", allow_dispatch=False) class FunctionExprOptions(_BaseIROptions): @@ -335,7 +327,7 @@ class FunctionExprOptions(_BaseIROptions): @classmethod def default(cls) -> Self: - return cls(origin="expr", override_name="", accessor_name=None) + return cls(is_namespaced=False, override_name="", accessor_name=None) FEOptions = FunctionExprOptions From 04225e7a9759972d4a1866226c6a3ccc7679049e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 19:10:40 +0000 Subject: [PATCH 10/19] chore: Update TODO --- narwhals/_plan/_dispatch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index e2ff1b1ccc..b2a51acd98 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -16,15 +16,16 @@ __all__ = ["Dispatcher", "get_dispatch_name"] -Incomplete: TypeAlias = "Any" Node = TypeVar("Node") - - Getter: TypeAlias = Callable[[Any], Any] Raiser: TypeAlias = Callable[..., "Never"] +# TODO @dangotbanned: Can this be done without overloads? +# TODO @dangotbanned: Clean up `__call__` comments +# TODO @dangotbanned: Bound `Node`? +# TODO @dangotbanned: Rename `_from_configured_type` @final class Dispatcher(Generic[Node]): __slots__ = ("_method_getter", "_name") From deaf691df640571308e1f2748b8ffae4166af79d Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 21:45:20 +0000 Subject: [PATCH 11/19] refactor: `_from_configured_type` -> `_from_type` --- narwhals/_plan/_dispatch.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index b2a51acd98..a1e3a8148d 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -22,10 +22,8 @@ Raiser: TypeAlias = Callable[..., "Never"] -# TODO @dangotbanned: Can this be done without overloads? # TODO @dangotbanned: Clean up `__call__` comments # TODO @dangotbanned: Bound `Node`? -# TODO @dangotbanned: Rename `_from_configured_type` @final class Dispatcher(Generic[Node]): __slots__ = ("_method_getter", "_name") @@ -63,27 +61,20 @@ def __call__( def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: if not tp.__expr_ir_config__.allow_dispatch: return Dispatcher._no_dispatch(tp) - return Dispatcher._from_configured_type(tp) + return Dispatcher._from_type(tp) @staticmethod def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: - return Dispatcher._from_configured_type(tp) + return Dispatcher._from_type(tp) @staticmethod @overload - def _from_configured_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ... - + def _from_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ... @staticmethod @overload - def _from_configured_type( - tp: type[FunctionT], / - ) -> Dispatcher[FunctionExpr[FunctionT]]: ... - - # TODO @dangotbanned: Can this be done without overloads? + def _from_type(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: ... @staticmethod - def _from_configured_type( - tp: type[ExprIRT | FunctionT], / - ) -> Dispatcher[ExprIRT] | Dispatcher[FunctionExpr[FunctionT]]: + def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]: obj = Dispatcher.__new__(Dispatcher) obj._name = _method_name(tp) getter = attrgetter(obj._name) From 51e13f7467bb2d319ce93e9241a66eb06a4d2024 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 21:51:26 +0000 Subject: [PATCH 12/19] chore(typing): Give `Node` an upper bound --- narwhals/_plan/_dispatch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index a1e3a8148d..eb0301e341 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -17,13 +17,12 @@ __all__ = ["Dispatcher", "get_dispatch_name"] -Node = TypeVar("Node") +Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]") Getter: TypeAlias = Callable[[Any], Any] Raiser: TypeAlias = Callable[..., "Never"] # TODO @dangotbanned: Clean up `__call__` comments -# TODO @dangotbanned: Bound `Node`? @final class Dispatcher(Generic[Node]): __slots__ = ("_method_getter", "_name") From 7dfb800f687371d982b36ce5a4abbb09e521d4ff Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 21:54:38 +0000 Subject: [PATCH 13/19] docs: remove camelCase mention oops --- narwhals/_plan/_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index eb0301e341..36dbf83c5a 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -114,7 +114,7 @@ def _(ctx: Any, /) -> Any: def _pascal_to_snake_case(s: str) -> str: - """Convert a PascalCase, camelCase string to snake_case. + """Convert a PascalCase string to snake_case. Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 """ From 23d3e924657cd98fa1cba8e1b7d463741ac5e16c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:37:12 +0000 Subject: [PATCH 14/19] now we're talking --- narwhals/_plan/_dispatch.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 36dbf83c5a..12d4ed0db5 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -2,15 +2,16 @@ import re from operator import attrgetter -from typing import TYPE_CHECKING, Any, Callable, Generic, final, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, final, overload from narwhals._plan._guards import is_function_expr +from narwhals._plan.compliant.typing import FrameT_contra, R_co from narwhals._typing_compat import TypeVar if TYPE_CHECKING: from typing_extensions import Never, TypeAlias - from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co + from narwhals._plan.compliant.typing import Ctx from narwhals._plan.expressions import ExprIR, FunctionExpr from narwhals._plan.typing import ExprIRT, FunctionT @@ -18,15 +19,27 @@ Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]") -Getter: TypeAlias = Callable[[Any], Any] +Node_contra = TypeVar( + "Node_contra", bound="ExprIR | FunctionExpr[Any]", contravariant=True +) Raiser: TypeAlias = Callable[..., "Never"] +class Binder(Protocol[Node_contra]): + def __call__( + self, ctx: Ctx[FrameT_contra, R_co], / + ) -> BoundMethod[Node_contra, FrameT_contra, R_co]: ... + + +class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]): + def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_co: ... + + # TODO @dangotbanned: Clean up `__call__` comments @final class Dispatcher(Generic[Node]): - __slots__ = ("_method_getter", "_name") - _method_getter: Getter + __slots__ = ("_bind", "_name") + _bind: Binder[Node] _name: str @property @@ -47,12 +60,11 @@ def __call__( # raises when the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` # gives a more helpful error for things that are namespaced like `col("a").str.replace` try: - bound_method = self._method_getter(ctx) + method = self._bind(ctx) except AttributeError: raise self._not_implemented_error(ctx) from None - - if result := bound_method(node, frame, name): - return result # type: ignore[no-any-return] + if result := method(node, frame, name): + return result # here if is defined on `CompliantExpr`, but not on ctx raise self._not_implemented_error(ctx) @@ -78,14 +90,14 @@ def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]: obj._name = _method_name(tp) getter = attrgetter(obj._name) is_namespaced = tp.__expr_ir_config__.is_namespaced - obj._method_getter = _via_namespace(getter) if is_namespaced else getter + obj._bind = _via_namespace(getter) if is_namespaced else getter return obj @staticmethod def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: obj = Dispatcher.__new__(Dispatcher) obj._name = tp.__name__ - obj._method_getter = obj._make_no_dispatch_error() + obj._bind = obj._make_no_dispatch_error() return obj def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]: @@ -106,7 +118,7 @@ def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: return NotImplementedError(msg) -def _via_namespace(getter: Getter, /) -> Getter: +def _via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: def _(ctx: Any, /) -> Any: return getter(ctx.__narwhals_namespace__()) From 93bdd6ef5437930474a79a4640cdc62ce336ea51 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 17 Oct 2025 23:01:45 +0000 Subject: [PATCH 15/19] docs: Add `Dispatcher` doc it ain't much, but its a start --- narwhals/_plan/_dispatch.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 12d4ed0db5..85d70a2919 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -38,6 +38,18 @@ def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_c # TODO @dangotbanned: Clean up `__call__` comments @final class Dispatcher(Generic[Node]): + """Translate class definitions into error-wrapped method calls. + + Operates over `ExprIR` and `Function` nodes. + By default, we dispatch to the compliant-level by calling a method that is the + **snake_case**-equivalent of the class name: + + class BinaryExpr(ExprIR): ... + + class CompliantExpr(Protocol): + def binary_expr(self, *args: Any): ... + """ + __slots__ = ("_bind", "_name") _bind: Binder[Node] _name: str From 725b316b5e9f99e34102161420f5f8832f155440 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 18 Oct 2025 12:56:53 +0000 Subject: [PATCH 16/19] feat: Better-distinguish dev-error from backend not implemented --- narwhals/_plan/_dispatch.py | 34 +++++++++++++++++---------- tests/plan/dispatch_test.py | 46 +++++++++++++++++++++++-------------- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 85d70a2919..5b79f8faa1 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -1,8 +1,9 @@ from __future__ import annotations import re +from collections.abc import Callable from operator import attrgetter -from typing import TYPE_CHECKING, Any, Callable, Generic, Protocol, final, overload +from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, final, overload from narwhals._plan._guards import is_function_expr from narwhals._plan.compliant.typing import FrameT_contra, R_co @@ -35,7 +36,6 @@ class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]): def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_co: ... -# TODO @dangotbanned: Clean up `__call__` comments @final class Dispatcher(Generic[Node]): """Translate class definitions into error-wrapped method calls. @@ -61,6 +61,14 @@ def name(self) -> str: def __repr__(self) -> str: return f"{type(self).__name__}<{self.name}>" + def bind( + self, ctx: Ctx[FrameT_contra, R_co], / + ) -> BoundMethod[Node, FrameT_contra, R_co]: + try: + return self._bind(ctx) + except AttributeError: + raise self._not_implemented_error(ctx, "compliant") from None + def __call__( self, ctx: Ctx[FrameT_contra, R_co], @@ -69,16 +77,10 @@ def __call__( name: str, /, ) -> R_co: - # raises when the method isn't implemented on `CompliantExpr`, but exists as a method on `Expr` - # gives a more helpful error for things that are namespaced like `col("a").str.replace` - try: - method = self._bind(ctx) - except AttributeError: - raise self._not_implemented_error(ctx) from None + method = self.bind(ctx) if result := method(node, frame, name): return result - # here if is defined on `CompliantExpr`, but not on ctx - raise self._not_implemented_error(ctx) + raise self._not_implemented_error(ctx, "context") @staticmethod def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: @@ -125,8 +127,16 @@ def getter(_: Any, /) -> Raiser: return getter - def _not_implemented_error(self, ctx: object, /) -> NotImplementedError: - msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" + def _not_implemented_error( + self, ctx: object, /, missing: Literal["compliant", "context"] + ) -> NotImplementedError: + if missing == "context": + msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" + else: + msg = ( + f"`{self.name}` has not been implemented at the compliant-level.\n" + f"Hint: Try adding `CompliantExpr.{self.name}()` or `CompliantNamespace.{self.name}()`" + ) return NotImplementedError(msg) diff --git a/tests/plan/dispatch_test.py b/tests/plan/dispatch_test.py index 4d3fcb2c79..db65d468a5 100644 --- a/tests/plan/dispatch_test.py +++ b/tests/plan/dispatch_test.py @@ -13,10 +13,18 @@ from tests.plan.utils import assert_equal_data, dataframe, named_ir if TYPE_CHECKING: + import sys + import pyarrow as pa + from typing_extensions import TypeAlias from narwhals._plan.dataframe import DataFrame + if sys.version_info >= (3, 11): + _Flags: TypeAlias = "int | re.RegexFlag" + else: + _Flags: TypeAlias = int + @pytest.fixture def data() -> dict[str, Any]: @@ -33,36 +41,40 @@ def df(data: dict[str, Any]) -> DataFrame[pa.Table, pa.ChunkedArray[Any]]: return dataframe(data) +def re_compile( + pattern: str, flags: _Flags = re.DOTALL | re.IGNORECASE +) -> re.Pattern[str]: + return re.compile(pattern, flags) + + def test_dispatch(df: DataFrame[pa.Table, pa.ChunkedArray[Any]]) -> None: implemented_full = nwp.col("a").is_null() - only_at_compliant_level = nwp.col("c").ewm_mean() - only_at_narwhals_level = nwp.col("d").str.contains("a") forgot_to_expand = (named_ir("howdy", nwp.nth(3, 4).first()),) aliased_after_expand: tuple[ir.NamedIR[Any]] = ( ir.NamedIR.from_ir(ir.col("a").alias("b")), ) - pattern_expand = re.compile( - r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first", - re.DOTALL | re.IGNORECASE, - ) - bad = re.escape("col('a').alias('b')") - pattern_aliased_after_expand = re.compile( - rf"Alias.+not.+appear.+got.+{bad}", re.DOTALL | re.IGNORECASE - ) - assert_equal_data(df.select(implemented_full), {"a": [False, True, False]}) - with pytest.raises(NotImplementedError, match=r"ewm_mean"): - df.select(only_at_compliant_level) + missing_backend = r"ewm_mean.+is not yet implemented for" + with pytest.raises(NotImplementedError, match=missing_backend): + df.select(nwp.col("c").ewm_mean()) - with pytest.raises(NotImplementedError, match=r"str\.contains"): - df.select(only_at_narwhals_level) + missing_protocol = re_compile( + r"str\.contains.+has not been implemented.+compliant.+" + r"Hint.+try adding.+CompliantExpr\.str\.contains\(\)" + ) + with pytest.raises(NotImplementedError, match=missing_protocol): + df.select(nwp.col("d").str.contains("a")) - with pytest.raises(TypeError, match=pattern_expand): + with pytest.raises( + TypeError, + match=re_compile(r"IndexColumns.+not.+appear.+compliant.+expand.+expr.+first"), + ): df._compliant.select(forgot_to_expand) - with pytest.raises(TypeError, match=pattern_aliased_after_expand): + bad = re.escape("col('a').alias('b')") + with pytest.raises(TypeError, match=re_compile(rf"Alias.+not.+appear.+got.+{bad}")): df._compliant.select(aliased_after_expand) # Not a narwhals method, to make sure this doesn't allow arbitrary calls From 0265d3de0868342da8943c32f503533e9f81d137 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:28:26 +0000 Subject: [PATCH 17/19] more docs --- narwhals/_plan/_dispatch.py | 9 +++++++++ narwhals/_plan/_expr_ir.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index 5b79f8faa1..ed83c00f2e 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -41,6 +41,7 @@ class Dispatcher(Generic[Node]): """Translate class definitions into error-wrapped method calls. Operates over `ExprIR` and `Function` nodes. + By default, we dispatch to the compliant-level by calling a method that is the **snake_case**-equivalent of the class name: @@ -64,6 +65,13 @@ def __repr__(self) -> str: def bind( self, ctx: Ctx[FrameT_contra, R_co], / ) -> BoundMethod[Node, FrameT_contra, R_co]: + """Retrieve the implementation of this expression from `ctx`. + + Binds an instance method, most commonly via: + + expr: CompliantExpr + method = getattr(expr, "method_name") + """ try: return self._bind(ctx) except AttributeError: @@ -77,6 +85,7 @@ def __call__( name: str, /, ) -> R_co: + """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" method = self.bind(ctx) if result := method(node, frame, name): return result diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 40ba704e13..0ecf8f4180 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -52,7 +52,7 @@ def __init_subclass__( def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / ) -> R_co: - """Evaluate expression in `frame`, using `ctx` for implementation(s).""" + """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" return self.__expr_ir_dispatch__(ctx, self, frame, name) def to_narwhals(self, version: Version = Version.MAIN) -> Expr: From 2d9b0374d74d91cdae418b5b18ea7b67bd949872 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:43:23 +0000 Subject: [PATCH 18/19] refactor: align signatures of `dispatch`, `__expr_ir_dispatch__`, `Dispatcher.__call__` Pretty nice property to have --- narwhals/_plan/_dispatch.py | 2 +- narwhals/_plan/_expr_ir.py | 2 +- narwhals/_plan/expressions/expr.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/narwhals/_plan/_dispatch.py b/narwhals/_plan/_dispatch.py index ed83c00f2e..044b786e28 100644 --- a/narwhals/_plan/_dispatch.py +++ b/narwhals/_plan/_dispatch.py @@ -79,8 +79,8 @@ def bind( def __call__( self, - ctx: Ctx[FrameT_contra, R_co], node: Node, + ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, /, diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 0ecf8f4180..3faba2cc0e 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -53,7 +53,7 @@ def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, / ) -> R_co: """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" - return self.__expr_ir_dispatch__(ctx, self, frame, name) + return self.__expr_ir_dispatch__(self, ctx, frame, name) def to_narwhals(self, version: Version = Version.MAIN) -> Expr: from narwhals._plan import expr diff --git a/narwhals/_plan/expressions/expr.py b/narwhals/_plan/expressions/expr.py index 82736e1092..4fa2f6cf6e 100644 --- a/narwhals/_plan/expressions/expr.py +++ b/narwhals/_plan/expressions/expr.py @@ -316,7 +316,7 @@ def __init__( def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.function.__expr_ir_dispatch__(ctx, self, frame, name) + return self.function.__expr_ir_dispatch__(self, ctx, frame, name) class RollingExpr(FunctionExpr[RollingT_co]): ... @@ -330,7 +330,7 @@ class AnonymousExpr( def dispatch( self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str ) -> R_co: - return self.__expr_ir_dispatch__(ctx, self, frame, name) + return self.__expr_ir_dispatch__(self, ctx, frame, name) class RangeExpr(FunctionExpr[RangeT_co]): From d66777ae9417b1b387bd89b997d2f6b131861aeb Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:49:10 +0000 Subject: [PATCH 19/19] =?UTF-8?q?chore(typing):=20Remove=20now-unused=20`I?= =?UTF-8?q?ncomplete`s=20=F0=9F=A5=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- narwhals/_plan/_expr_ir.py | 4 +--- narwhals/_plan/_function.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/narwhals/_plan/_expr_ir.py b/narwhals/_plan/_expr_ir.py index 3faba2cc0e..8f66f2d9fa 100644 --- a/narwhals/_plan/_expr_ir.py +++ b/narwhals/_plan/_expr_ir.py @@ -14,7 +14,7 @@ from collections.abc import Iterator from typing import Any, ClassVar - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co from narwhals._plan.expr import Expr, Selector @@ -23,8 +23,6 @@ from narwhals._plan.typing import ExprIRT2, MapIR, Seq from narwhals.dtypes import DType - Incomplete: TypeAlias = "Any" - class ExprIR(Immutable): """Anything that can be a node on a graph of expressions.""" diff --git a/narwhals/_plan/_function.py b/narwhals/_plan/_function.py index d319f68493..8b71a4dd8d 100644 --- a/narwhals/_plan/_function.py +++ b/narwhals/_plan/_function.py @@ -10,15 +10,13 @@ if TYPE_CHECKING: from typing import Any, Callable, ClassVar - from typing_extensions import Self, TypeAlias + from typing_extensions import Self from narwhals._plan.expressions import ExprIR, FunctionExpr from narwhals._plan.typing import Accessor __all__ = ["Function", "HorizontalFunction"] -Incomplete: TypeAlias = "Any" - class Function(Immutable): """Shared by expr functions and namespace functions.