Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ repos:
entry: "self: Self"
language: pygrep
files: ^narwhals/
# mypy needs `Self` for `ExprIR.dispatch`
exclude: ^narwhals/_plan/.*\.py
- id: dtypes-import
name: don't import from narwhals.dtypes (use `Version.dtypes` instead)
entry: |
Expand Down
188 changes: 188 additions & 0 deletions narwhals/_plan/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from __future__ import annotations

import re
from collections.abc import Callable
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, final, overload

from narwhals._plan._guards import is_function_expr
from narwhals._plan.compliant.typing import FrameT_contra, R_co
from narwhals._typing_compat import TypeVar

if TYPE_CHECKING:
from typing_extensions import Never, TypeAlias

from narwhals._plan.compliant.typing import Ctx
from narwhals._plan.expressions import ExprIR, FunctionExpr
from narwhals._plan.typing import ExprIRT, FunctionT

__all__ = ["Dispatcher", "get_dispatch_name"]


# Node type a `Dispatcher` is parametrized by; bounded to `ExprIR | FunctionExpr[Any]`.
Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]")
# Contravariant variant of the same bound, used to parametrize the `Binder` and
# `BoundMethod` protocols (where the node only appears in argument position).
Node_contra = TypeVar(
    "Node_contra", bound="ExprIR | FunctionExpr[Any]", contravariant=True
)
# Any callable annotated as never returning normally.
# Used as the return type of the getter built by `Dispatcher._make_no_dispatch_error`.
Raiser: TypeAlias = Callable[..., "Never"]


class Binder(Protocol[Node_contra]):
    """Callable protocol: resolve an implementation from a context.

    Called with `ctx` (positional-only) and returns a `BoundMethod` that can then be
    invoked with `(node, frame, name)`.
    """

    def __call__(
        self, ctx: Ctx[FrameT_contra, R_co], /
    ) -> BoundMethod[Node_contra, FrameT_contra, R_co]: ...


class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]):
    """Callable protocol: an implementation already bound to its context.

    Called with `(node, frame, name)` (all positional-only) and returns `R_co`.
    """

    def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_co: ...


@final
class Dispatcher(Generic[Node]):
    """Translate class definitions into error-wrapped method calls.

    Operates over `ExprIR` and `Function` nodes.

    By default, we dispatch to the compliant-level by calling a method that is the
    **snake_case**-equivalent of the class name:

        class BinaryExpr(ExprIR): ...

        class CompliantExpr(Protocol):
            def binary_expr(self, *args: Any): ...

    Instances are created through `from_expr_ir` / `from_function`, which fill in
    the two slots directly (there is no `__init__`).
    """

    __slots__ = ("_bind", "_name")
    # Resolves the implementation for a given `ctx` (see `bind`).
    _bind: Binder[Node]
    # Name reported by `name`, `__repr__` and in error messages.
    _name: str

    @property
    def name(self) -> str:
        """The name this dispatcher was created with."""
        return self._name

    def __repr__(self) -> str:
        return f"{type(self).__name__}<{self.name}>"

    def bind(
        self, ctx: Ctx[FrameT_contra, R_co], /
    ) -> BoundMethod[Node, FrameT_contra, R_co]:
        """Retrieve the implementation of this expression from `ctx`.

        Binds an instance method, most commonly via:

            expr: CompliantExpr
            method = getattr(expr, "method_name")

        Raises:
            NotImplementedError: If the lookup raises `AttributeError`.
        """
        try:
            return self._bind(ctx)
        except AttributeError:
            # Hide the low-level `AttributeError`; the replacement message names the
            # missing compliant-level method instead.
            raise self._not_implemented_error(ctx, "compliant") from None

    def __call__(
        self,
        node: Node,
        ctx: Ctx[FrameT_contra, R_co],
        frame: FrameT_contra,
        name: str,
        /,
    ) -> R_co:
        """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.

        Raises:
            NotImplementedError: If no implementation can be bound from `ctx`, or the
                bound implementation returns `None`.
        """
        method = self.bind(ctx)
        result = method(node, frame, name)
        # Compare against `None` explicitly instead of relying on truthiness:
        # a valid result may be falsy (e.g. empty) or may not support `bool()` at all.
        if result is not None:
            return result
        raise self._not_implemented_error(ctx, "context")

    @staticmethod
    def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]:
        """Build the dispatcher for an `ExprIR` class.

        When `tp.__expr_ir_config__.allow_dispatch` is falsy, the returned dispatcher
        raises `TypeError` when called, instead of looking up a method.
        """
        if not tp.__expr_ir_config__.allow_dispatch:
            return Dispatcher._no_dispatch(tp)
        return Dispatcher._from_type(tp)

    @staticmethod
    def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]:
        """Build the dispatcher for a `Function` class."""
        return Dispatcher._from_type(tp)

    @staticmethod
    @overload
    def _from_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ...
    @staticmethod
    @overload
    def _from_type(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: ...
    @staticmethod
    def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]:
        """Build a dispatcher that looks up the method named by `_method_name(tp)`."""
        # Bypass `__init__` (none is defined) and populate the slots directly.
        obj = Dispatcher.__new__(Dispatcher)
        obj._name = _method_name(tp)
        getter = attrgetter(obj._name)
        is_namespaced = tp.__expr_ir_config__.is_namespaced
        # Namespaced nodes are looked up on `ctx.__narwhals_namespace__()`, not on `ctx`.
        obj._bind = _via_namespace(getter) if is_namespaced else getter
        return obj

    @staticmethod
    def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]:
        """Build a dispatcher, named after the class itself, that always raises `TypeError`."""
        obj = Dispatcher.__new__(Dispatcher)
        obj._name = tp.__name__
        obj._bind = obj._make_no_dispatch_error()
        return obj

    def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]:
        """Return a binder whose bound method raises `TypeError` for any node."""

        def _no_dispatch_error(node: Node, *_: Any) -> Never:
            msg = (
                f"{self.name!r} should not appear at the compliant-level.\n\n"
                f"Make sure to expand all expressions first, got:\n{node!r}"
            )
            raise TypeError(msg)

        # The context is irrelevant here: every `ctx` maps to the same raiser.
        def getter(_: Any, /) -> Raiser:
            return _no_dispatch_error

        return getter

    def _not_implemented_error(
        self, ctx: object, /, missing: Literal["compliant", "context"]
    ) -> NotImplementedError:
        """Create (but do not raise) the error describing what is missing.

        - `"context"`: the method was bound, but produced no result for this `ctx` type.
        - `"compliant"`: the method could not be found at all.
        """
        if missing == "context":
            msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}"
        else:
            msg = (
                f"`{self.name}` has not been implemented at the compliant-level.\n"
                f"Hint: Try adding `CompliantExpr.{self.name}()` or `CompliantNamespace.{self.name}()`"
            )
        return NotImplementedError(msg)


def _via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]:
def _(ctx: Any, /) -> Any:
return getter(ctx.__narwhals_namespace__())

return _


def _pascal_to_snake_case(s: str) -> str:
"""Convert a PascalCase string to snake_case.

Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62
"""
# Handle the sequence of uppercase letters followed by a lowercase letter
snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s)
# Insert an underscore between a lowercase letter and an uppercase letter
return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower()


_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])")
_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])")


def _re_repl_snake(match: re.Match[str], /) -> str:
return f"{match.group(1)}_{match.group(2)}"


def _method_name(tp: type[ExprIRT | FunctionT]) -> str:
config = tp.__expr_ir_config__
name = config.override_name or _pascal_to_snake_case(tp.__name__)
return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name


def get_dispatch_name(expr: ExprIR, /) -> str:
    """Return the synthesized method name for `expr`.

    For function expressions this is the `repr` of the wrapped function;
    for everything else it is the name stored on the node's dispatcher.
    """
    if is_function_expr(expr):
        return repr(expr.function)
    return expr.__expr_ir_dispatch__.name
44 changes: 10 additions & 34 deletions narwhals/_plan/_expr_ir.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, cast
from typing import TYPE_CHECKING, Generic

from narwhals._plan._dispatch import Dispatcher
from narwhals._plan._guards import is_function_expr, is_literal
from narwhals._plan._immutable import Immutable
from narwhals._plan.common import dispatch_getter, replace
from narwhals._plan.common import replace
from narwhals._plan.options import ExprIROptions
from narwhals._plan.typing import ExprIRT
from narwhals.utils import Version

if TYPE_CHECKING:
from collections.abc import Callable, Iterator
from collections.abc import Iterator
from typing import Any, ClassVar

from typing_extensions import Self, TypeAlias
from typing_extensions import Self

from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
from narwhals._plan.expr import Expr, Selector
Expand All @@ -22,29 +23,6 @@
from narwhals._plan.typing import ExprIRT2, MapIR, Seq
from narwhals.dtypes import DType

Incomplete: TypeAlias = "Any"


def _dispatch_generate(
tp: type[ExprIRT], /
) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]:
if not tp.__expr_ir_config__.allow_dispatch:

def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any:
msg = (
f"{tp.__name__!r} should not appear at the compliant-level.\n\n"
f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}"
)
raise TypeError(msg)

return _
getter = dispatch_getter(tp)

def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any:
return getter(ctx)(node, frame, name)

return _


class ExprIR(Immutable):
"""Anything that can be a node on a graph of expressions."""
Expand All @@ -53,9 +31,7 @@ class ExprIR(Immutable):
"""Nested node names, in iteration order."""

__expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default()
__expr_ir_dispatch__: ClassVar[
staticmethod[[Incomplete, Self, Incomplete, str], Incomplete]
]
__expr_ir_dispatch__: ClassVar[Dispatcher[Self]]

def __init_subclass__(
cls: type[Self],
Expand All @@ -69,13 +45,13 @@ def __init_subclass__(
cls._child = child
if config:
cls.__expr_ir_config__ = config
cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls))
cls.__expr_ir_dispatch__ = Dispatcher.from_expr_ir(cls)

def dispatch(
self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, /
self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, /
) -> R_co:
"""Evaluate expression in `frame`, using `ctx` for implementation(s)."""
return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return]
"""Evaluate this expression in `frame`, using implementation(s) provided by `ctx`."""
return self.__expr_ir_dispatch__(self, ctx, frame, name)

def to_narwhals(self, version: Version = Version.MAIN) -> Expr:
from narwhals._plan import expr
Expand Down
28 changes: 7 additions & 21 deletions narwhals/_plan/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,21 @@

from typing import TYPE_CHECKING

from narwhals._plan._dispatch import Dispatcher
from narwhals._plan._immutable import Immutable
from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace
from narwhals._plan.common import replace
from narwhals._plan.options import FEOptions, FunctionOptions

if TYPE_CHECKING:
from typing import Any, Callable, ClassVar

from typing_extensions import Self, TypeAlias
from typing_extensions import Self

from narwhals._plan.expressions import ExprIR, FunctionExpr
from narwhals._plan.typing import Accessor, FunctionT
from narwhals._plan.typing import Accessor

__all__ = ["Function", "HorizontalFunction"]

Incomplete: TypeAlias = "Any"


def _dispatch_generate_function(
tp: type[FunctionT], /
) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]:
getter = dispatch_getter(tp)

def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any:
return getter(ctx)(node, frame, name)

return _


class Function(Immutable):
"""Shared by expr functions and namespace functions.
Expand All @@ -40,9 +28,7 @@ class Function(Immutable):
FunctionOptions.default
)
__expr_ir_config__: ClassVar[FEOptions] = FEOptions.default()
__expr_ir_dispatch__: ClassVar[
staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete]
]
Comment on lines -43 to -45
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Truly wild that I wrote ^^^ 😂

The new version looks a bit more understandable from the outside, and it can now correctly scope all the things which were previously typed as `Incomplete`.

__expr_ir_dispatch__: ClassVar[Dispatcher[FunctionExpr[Self]]]

__expr_ir_dispatch__: ClassVar[Dispatcher[FunctionExpr[Self]]]

@property
def function_options(self) -> FunctionOptions:
Expand Down Expand Up @@ -72,10 +58,10 @@ def __init_subclass__(
cls._function_options = staticmethod(options)
if config:
cls.__expr_ir_config__ = config
cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls))
cls.__expr_ir_dispatch__ = Dispatcher.from_function(cls)

def __repr__(self) -> str:
return dispatch_method_name(type(self))
return self.__expr_ir_dispatch__.name


class HorizontalFunction(
Expand Down
9 changes: 3 additions & 6 deletions narwhals/_plan/arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import pyarrow.compute as pc # ignore-banned-import

from narwhals._plan import expressions as ir
from narwhals._plan._dispatch import get_dispatch_name
from narwhals._plan._guards import is_agg_expr, is_function_expr
from narwhals._plan.arrow import acero, functions as fn, options
from narwhals._plan.common import dispatch_method_name, temp
from narwhals._plan.common import temp
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
from narwhals._plan.expressions import aggregation as agg
from narwhals._utils import Implementation
Expand Down Expand Up @@ -132,11 +133,7 @@ def group_by_error(
if reason == "too complex":
msg = "Non-trivial complex aggregation found, which"
else:
if is_function_expr(expr):
func_name = repr(expr.function)
else:
func_name = dispatch_method_name(type(expr))
msg = f"`{func_name}()`"
msg = f"`{get_dispatch_name(expr)}()`"
msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}"
return InvalidOperationError(msg)

Expand Down
Loading
Loading