Skip to content

Commit 34a1fd0

Browse files
authored
refactor(expr-ir): Improve function dispatch (#3215)
1 parent 1ab1599 commit 34a1fd0

File tree

13 files changed

+342
-124
lines changed

13 files changed

+342
-124
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ repos:
7979
entry: "self: Self"
8080
language: pygrep
8181
files: ^narwhals/
82+
# mypy needs `Self` for `ExprIR.dispatch`
83+
exclude: ^narwhals/_plan/.*\.py
8284
- id: dtypes-import
8385
name: don't import from narwhals.dtypes (use `Version.dtypes` instead)
8486
entry: |

narwhals/_plan/_dispatch.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from collections.abc import Callable
5+
from operator import attrgetter
6+
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, final, overload
7+
8+
from narwhals._plan._guards import is_function_expr
9+
from narwhals._plan.compliant.typing import FrameT_contra, R_co
10+
from narwhals._typing_compat import TypeVar
11+
12+
if TYPE_CHECKING:
13+
from typing_extensions import Never, TypeAlias
14+
15+
from narwhals._plan.compliant.typing import Ctx
16+
from narwhals._plan.expressions import ExprIR, FunctionExpr
17+
from narwhals._plan.typing import ExprIRT, FunctionT
18+
19+
__all__ = ["Dispatcher", "get_dispatch_name"]
20+
21+
22+
# Type variable for the kind of node a `Dispatcher` is parameterized by.
Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]")
# Contravariant variant with the same bound, used to parameterize the callback protocols.
Node_contra = TypeVar(
    "Node_contra", bound="ExprIR | FunctionExpr[Any]", contravariant=True
)
# Any callable annotated as never returning normally (`Never`).
Raiser: TypeAlias = Callable[..., "Never"]
27+
28+
29+
class Binder(Protocol[Node_contra]):
    """Callable that resolves the implementation for a node type from a context."""

    def __call__(
        self,
        ctx: Ctx[FrameT_contra, R_co],
        /,
    ) -> BoundMethod[Node_contra, FrameT_contra, R_co]:
        """Return the callable obtained from `ctx` that evaluates the node."""
        ...
33+
34+
35+
class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]):
    """Callable that evaluates a node against a frame, given an output name."""

    def __call__(
        self,
        node: Node_contra,
        frame: FrameT_contra,
        name: str,
        /,
    ) -> R_co:
        """Evaluate `node` in `frame`; all three arguments are positional-only."""
        ...
37+
38+
39+
@final
40+
class Dispatcher(Generic[Node]):
41+
"""Translate class definitions into error-wrapped method calls.
42+
43+
Operates over `ExprIR` and `Function` nodes.
44+
45+
By default, we dispatch to the compliant-level by calling a method that is the
46+
**snake_case**-equivalent of the class name:
47+
48+
class BinaryExpr(ExprIR): ...
49+
50+
class CompliantExpr(Protocol):
51+
def binary_expr(self, *args: Any): ...
52+
"""
53+
54+
__slots__ = ("_bind", "_name")
55+
_bind: Binder[Node]
56+
_name: str
57+
58+
@property
59+
def name(self) -> str:
60+
return self._name
61+
62+
def __repr__(self) -> str:
63+
return f"{type(self).__name__}<{self.name}>"
64+
65+
def bind(
66+
self, ctx: Ctx[FrameT_contra, R_co], /
67+
) -> BoundMethod[Node, FrameT_contra, R_co]:
68+
"""Retrieve the implementation of this expression from `ctx`.
69+
70+
Binds an instance method, most commonly via:
71+
72+
expr: CompliantExpr
73+
method = getattr(expr, "method_name")
74+
"""
75+
try:
76+
return self._bind(ctx)
77+
except AttributeError:
78+
raise self._not_implemented_error(ctx, "compliant") from None
79+
80+
def __call__(
81+
self,
82+
node: Node,
83+
ctx: Ctx[FrameT_contra, R_co],
84+
frame: FrameT_contra,
85+
name: str,
86+
/,
87+
) -> R_co:
88+
"""Evaluate this expression in `frame`, using implementation(s) provided by `ctx`."""
89+
method = self.bind(ctx)
90+
if result := method(node, frame, name):
91+
return result
92+
raise self._not_implemented_error(ctx, "context")
93+
94+
@staticmethod
95+
def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]:
96+
if not tp.__expr_ir_config__.allow_dispatch:
97+
return Dispatcher._no_dispatch(tp)
98+
return Dispatcher._from_type(tp)
99+
100+
@staticmethod
101+
def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]:
102+
return Dispatcher._from_type(tp)
103+
104+
@staticmethod
105+
@overload
106+
def _from_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ...
107+
@staticmethod
108+
@overload
109+
def _from_type(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: ...
110+
@staticmethod
111+
def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]:
112+
obj = Dispatcher.__new__(Dispatcher)
113+
obj._name = _method_name(tp)
114+
getter = attrgetter(obj._name)
115+
is_namespaced = tp.__expr_ir_config__.is_namespaced
116+
obj._bind = _via_namespace(getter) if is_namespaced else getter
117+
return obj
118+
119+
@staticmethod
120+
def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]:
121+
obj = Dispatcher.__new__(Dispatcher)
122+
obj._name = tp.__name__
123+
obj._bind = obj._make_no_dispatch_error()
124+
return obj
125+
126+
def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]:
127+
def _no_dispatch_error(node: Node, *_: Any) -> Never:
128+
msg = (
129+
f"{self.name!r} should not appear at the compliant-level.\n\n"
130+
f"Make sure to expand all expressions first, got:\n{node!r}"
131+
)
132+
raise TypeError(msg)
133+
134+
def getter(_: Any, /) -> Raiser:
135+
return _no_dispatch_error
136+
137+
return getter
138+
139+
def _not_implemented_error(
140+
self, ctx: object, /, missing: Literal["compliant", "context"]
141+
) -> NotImplementedError:
142+
if missing == "context":
143+
msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}"
144+
else:
145+
msg = (
146+
f"`{self.name}` has not been implemented at the compliant-level.\n"
147+
f"Hint: Try adding `CompliantExpr.{self.name}()` or `CompliantNamespace.{self.name}()`"
148+
)
149+
return NotImplementedError(msg)
150+
151+
152+
def _via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]:
    """Wrap `getter` so it is applied to `ctx.__narwhals_namespace__()` rather than `ctx`."""

    def from_namespace(ctx: Any, /) -> Any:
        namespace = ctx.__narwhals_namespace__()
        return getter(namespace)

    return from_namespace
157+
158+
159+
# A run of uppercase letters followed by an uppercase+lowercase pair.
_PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])")
# A lowercase letter directly followed by an uppercase letter.
_PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])")


def _re_repl_snake(match: re.Match[str], /) -> str:
    """Join the first two captured groups of `match` with an underscore."""
    left, right = match.group(1, 2)
    return f"{left}_{right}"


def _pascal_to_snake_case(s: str) -> str:
    """Convert a PascalCase string to snake_case.

    Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62
    """
    # First pass: split a run of capitals from a following capitalized word.
    after_upper_runs = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s)
    # Second pass: split wherever a lowercase letter meets an uppercase letter.
    after_word_breaks = _PATTERN_LOWER_UPPER.sub(_re_repl_snake, after_upper_runs)
    return after_word_breaks.lower()
176+
177+
178+
def _method_name(tp: type[ExprIRT | FunctionT]) -> str:
    """Return the method name to dispatch `tp` to.

    Uses `override_name` from the type's config when set, otherwise the
    snake_case form of the class name. When the config has a non-empty
    `accessor_name`, the result is prefixed with it and a dot.
    """
    options = tp.__expr_ir_config__
    method = options.override_name
    if not method:
        method = _pascal_to_snake_case(tp.__name__)
    # Not every config object is required to define `accessor_name`.
    accessor = getattr(options, "accessor_name", "")
    if accessor:
        return f"{accessor}.{method}"
    return method
182+
183+
184+
def get_dispatch_name(expr: ExprIR, /) -> str:
    """Return the synthesized method name for `expr`.

    For function expressions this is the `repr` of the wrapped function;
    for every other node it is the name stored on its dispatcher.
    """
    if is_function_expr(expr):
        return repr(expr.function)
    return expr.__expr_ir_dispatch__.name

narwhals/_plan/_expr_ir.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Generic, cast
3+
from typing import TYPE_CHECKING, Generic
44

5+
from narwhals._plan._dispatch import Dispatcher
56
from narwhals._plan._guards import is_function_expr, is_literal
67
from narwhals._plan._immutable import Immutable
7-
from narwhals._plan.common import dispatch_getter, replace
8+
from narwhals._plan.common import replace
89
from narwhals._plan.options import ExprIROptions
910
from narwhals._plan.typing import ExprIRT
1011
from narwhals.utils import Version
1112

1213
if TYPE_CHECKING:
13-
from collections.abc import Callable, Iterator
14+
from collections.abc import Iterator
1415
from typing import Any, ClassVar
1516

16-
from typing_extensions import Self, TypeAlias
17+
from typing_extensions import Self
1718

1819
from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
1920
from narwhals._plan.expr import Expr, Selector
@@ -22,29 +23,6 @@
2223
from narwhals._plan.typing import ExprIRT2, MapIR, Seq
2324
from narwhals.dtypes import DType
2425

25-
Incomplete: TypeAlias = "Any"
26-
27-
28-
def _dispatch_generate(
29-
tp: type[ExprIRT], /
30-
) -> Callable[[Incomplete, ExprIRT, Incomplete, str], Incomplete]:
31-
if not tp.__expr_ir_config__.allow_dispatch:
32-
33-
def _(ctx: Any, /, node: ExprIRT, _: Any, name: str) -> Any:
34-
msg = (
35-
f"{tp.__name__!r} should not appear at the compliant-level.\n\n"
36-
f"Make sure to expand all expressions first, got:\n{ctx!r}\n{node!r}\n{name!r}"
37-
)
38-
raise TypeError(msg)
39-
40-
return _
41-
getter = dispatch_getter(tp)
42-
43-
def _(ctx: Any, /, node: ExprIRT, frame: Any, name: str) -> Any:
44-
return getter(ctx)(node, frame, name)
45-
46-
return _
47-
4826

4927
class ExprIR(Immutable):
5028
"""Anything that can be a node on a graph of expressions."""
@@ -53,9 +31,7 @@ class ExprIR(Immutable):
5331
"""Nested node names, in iteration order."""
5432

5533
__expr_ir_config__: ClassVar[ExprIROptions] = ExprIROptions.default()
56-
__expr_ir_dispatch__: ClassVar[
57-
staticmethod[[Incomplete, Self, Incomplete, str], Incomplete]
58-
]
34+
__expr_ir_dispatch__: ClassVar[Dispatcher[Self]]
5935

6036
def __init_subclass__(
6137
cls: type[Self],
@@ -69,13 +45,13 @@ def __init_subclass__(
6945
cls._child = child
7046
if config:
7147
cls.__expr_ir_config__ = config
72-
cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate(cls))
48+
cls.__expr_ir_dispatch__ = Dispatcher.from_expr_ir(cls)
7349

7450
def dispatch(
75-
self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, /
51+
self: Self, ctx: Ctx[FrameT_contra, R_co], frame: FrameT_contra, name: str, /
7652
) -> R_co:
77-
"""Evaluate expression in `frame`, using `ctx` for implementation(s)."""
78-
return self.__expr_ir_dispatch__(ctx, cast("Self", self), frame, name) # type: ignore[no-any-return]
53+
"""Evaluate this expression in `frame`, using implementation(s) provided by `ctx`."""
54+
return self.__expr_ir_dispatch__(self, ctx, frame, name)
7955

8056
def to_narwhals(self, version: Version = Version.MAIN) -> Expr:
8157
from narwhals._plan import expr

narwhals/_plan/_function.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,21 @@
22

33
from typing import TYPE_CHECKING
44

5+
from narwhals._plan._dispatch import Dispatcher
56
from narwhals._plan._immutable import Immutable
6-
from narwhals._plan.common import dispatch_getter, dispatch_method_name, replace
7+
from narwhals._plan.common import replace
78
from narwhals._plan.options import FEOptions, FunctionOptions
89

910
if TYPE_CHECKING:
1011
from typing import Any, Callable, ClassVar
1112

12-
from typing_extensions import Self, TypeAlias
13+
from typing_extensions import Self
1314

1415
from narwhals._plan.expressions import ExprIR, FunctionExpr
15-
from narwhals._plan.typing import Accessor, FunctionT
16+
from narwhals._plan.typing import Accessor
1617

1718
__all__ = ["Function", "HorizontalFunction"]
1819

19-
Incomplete: TypeAlias = "Any"
20-
21-
22-
def _dispatch_generate_function(
23-
tp: type[FunctionT], /
24-
) -> Callable[[Incomplete, FunctionExpr[FunctionT], Incomplete, str], Incomplete]:
25-
getter = dispatch_getter(tp)
26-
27-
def _(ctx: Any, /, node: FunctionExpr[FunctionT], frame: Any, name: str) -> Any:
28-
return getter(ctx)(node, frame, name)
29-
30-
return _
31-
3220

3321
class Function(Immutable):
3422
"""Shared by expr functions and namespace functions.
@@ -40,9 +28,7 @@ class Function(Immutable):
4028
FunctionOptions.default
4129
)
4230
__expr_ir_config__: ClassVar[FEOptions] = FEOptions.default()
43-
__expr_ir_dispatch__: ClassVar[
44-
staticmethod[[Incomplete, FunctionExpr[Self], Incomplete, str], Incomplete]
45-
]
31+
__expr_ir_dispatch__: ClassVar[Dispatcher[FunctionExpr[Self]]]
4632

4733
@property
4834
def function_options(self) -> FunctionOptions:
@@ -72,10 +58,10 @@ def __init_subclass__(
7258
cls._function_options = staticmethod(options)
7359
if config:
7460
cls.__expr_ir_config__ = config
75-
cls.__expr_ir_dispatch__ = staticmethod(_dispatch_generate_function(cls))
61+
cls.__expr_ir_dispatch__ = Dispatcher.from_function(cls)
7662

7763
def __repr__(self) -> str:
78-
return dispatch_method_name(type(self))
64+
return self.__expr_ir_dispatch__.name
7965

8066

8167
class HorizontalFunction(

narwhals/_plan/arrow/group_by.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import pyarrow.compute as pc # ignore-banned-import
77

88
from narwhals._plan import expressions as ir
9+
from narwhals._plan._dispatch import get_dispatch_name
910
from narwhals._plan._guards import is_agg_expr, is_function_expr
1011
from narwhals._plan.arrow import acero, functions as fn, options
11-
from narwhals._plan.common import dispatch_method_name, temp
12+
from narwhals._plan.common import temp
1213
from narwhals._plan.compliant.group_by import EagerDataFrameGroupBy
1314
from narwhals._plan.expressions import aggregation as agg
1415
from narwhals._utils import Implementation
@@ -132,11 +133,7 @@ def group_by_error(
132133
if reason == "too complex":
133134
msg = "Non-trivial complex aggregation found, which"
134135
else:
135-
if is_function_expr(expr):
136-
func_name = repr(expr.function)
137-
else:
138-
func_name = dispatch_method_name(type(expr))
139-
msg = f"`{func_name}()`"
136+
msg = f"`{get_dispatch_name(expr)}()`"
140137
msg = f"{msg} is not supported in a `group_by` context for {backend!r}:\n{column_name}={expr!r}"
141138
return InvalidOperationError(msg)
142139

0 commit comments

Comments
 (0)