-
Notifications
You must be signed in to change notification settings - Fork 170
refactor(expr-ir): Improve function dispatch #3215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
5fd705d
chore: Un-stash `DispatchGetter` idea
dangotbanned d70ae4e
refactor: Move everything for the 97th time
dangotbanned 23cbff0
fix: update import
dangotbanned 4b08a6e
rewrite as `Dispatcher`
dangotbanned 02d7e41
refactor: un-expose `_pascal_to_snake_case`
dangotbanned 40504f0
less-bad no_dispatch
dangotbanned c07972b
tweak tweak tweak
dangotbanned b26e833
refactor(typing): Less no-any-return
dangotbanned cc73b4a
refactor: `origin` -> `is_namespaced`
dangotbanned 04225e7
chore: Update TODO
dangotbanned deaf691
refactor: `_from_configured_type` -> `_from_type`
dangotbanned 51e13f7
chore(typing): Give `Node` an upper bound
dangotbanned 7dfb800
docs: remove camelCase mention
dangotbanned 23d3e92
now we're talking
dangotbanned 93bdd6e
docs: Add `Dispatcher` doc
dangotbanned 725b316
feat: Better-distinguish dev-error from backend not implemented
dangotbanned 0265d3d
more docs
dangotbanned 2d9b037
refactor: align signatures of `dispatch`, `__expr_ir_dispatch__`, `Di…
dangotbanned d66777a
chore(typing): Remove now-unused `Incomplete`s 🥳
dangotbanned File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import re | ||
| from collections.abc import Callable | ||
| from operator import attrgetter | ||
| from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol, final, overload | ||
|
|
||
| from narwhals._plan._guards import is_function_expr | ||
| from narwhals._plan.compliant.typing import FrameT_contra, R_co | ||
| from narwhals._typing_compat import TypeVar | ||
|
|
||
| if TYPE_CHECKING: | ||
| from typing_extensions import Never, TypeAlias | ||
|
|
||
| from narwhals._plan.compliant.typing import Ctx | ||
| from narwhals._plan.expressions import ExprIR, FunctionExpr | ||
| from narwhals._plan.typing import ExprIRT, FunctionT | ||
|
|
||
| __all__ = ["Dispatcher", "get_dispatch_name"] | ||
|
|
||
|
|
||
| Node = TypeVar("Node", bound="ExprIR | FunctionExpr[Any]") | ||
| Node_contra = TypeVar( | ||
| "Node_contra", bound="ExprIR | FunctionExpr[Any]", contravariant=True | ||
| ) | ||
| Raiser: TypeAlias = Callable[..., "Never"] | ||
|
|
||
|
|
||
| class Binder(Protocol[Node_contra]): | ||
| def __call__( | ||
| self, ctx: Ctx[FrameT_contra, R_co], / | ||
| ) -> BoundMethod[Node_contra, FrameT_contra, R_co]: ... | ||
|
|
||
|
|
||
| class BoundMethod(Protocol[Node_contra, FrameT_contra, R_co]): | ||
| def __call__(self, node: Node_contra, frame: FrameT_contra, name: str, /) -> R_co: ... | ||
|
|
||
|
|
||
| @final | ||
| class Dispatcher(Generic[Node]): | ||
| """Translate class definitions into error-wrapped method calls. | ||
|
|
||
| Operates over `ExprIR` and `Function` nodes. | ||
|
|
||
| By default, we dispatch to the compliant-level by calling a method that is the | ||
| **snake_case**-equivalent of the class name: | ||
|
|
||
| class BinaryExpr(ExprIR): ... | ||
|
|
||
| class CompliantExpr(Protocol): | ||
| def binary_expr(self, *args: Any): ... | ||
| """ | ||
|
|
||
| __slots__ = ("_bind", "_name") | ||
| _bind: Binder[Node] | ||
| _name: str | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| return self._name | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"{type(self).__name__}<{self.name}>" | ||
|
|
||
| def bind( | ||
| self, ctx: Ctx[FrameT_contra, R_co], / | ||
| ) -> BoundMethod[Node, FrameT_contra, R_co]: | ||
| """Retrieve the implementation of this expression from `ctx`. | ||
|
|
||
| Binds an instance method, most commonly via: | ||
|
|
||
| expr: CompliantExpr | ||
| method = getattr(expr, "method_name") | ||
| """ | ||
| try: | ||
| return self._bind(ctx) | ||
| except AttributeError: | ||
| raise self._not_implemented_error(ctx, "compliant") from None | ||
|
|
||
| def __call__( | ||
| self, | ||
| node: Node, | ||
| ctx: Ctx[FrameT_contra, R_co], | ||
| frame: FrameT_contra, | ||
| name: str, | ||
| /, | ||
| ) -> R_co: | ||
| """Evaluate this expression in `frame`, using implementation(s) provided by `ctx`.""" | ||
| method = self.bind(ctx) | ||
| if result := method(node, frame, name): | ||
| return result | ||
| raise self._not_implemented_error(ctx, "context") | ||
|
|
||
| @staticmethod | ||
| def from_expr_ir(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: | ||
| if not tp.__expr_ir_config__.allow_dispatch: | ||
| return Dispatcher._no_dispatch(tp) | ||
| return Dispatcher._from_type(tp) | ||
|
|
||
| @staticmethod | ||
| def from_function(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: | ||
| return Dispatcher._from_type(tp) | ||
|
|
||
| @staticmethod | ||
| @overload | ||
| def _from_type(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: ... | ||
| @staticmethod | ||
| @overload | ||
| def _from_type(tp: type[FunctionT], /) -> Dispatcher[FunctionExpr[FunctionT]]: ... | ||
| @staticmethod | ||
| def _from_type(tp: type[ExprIRT | FunctionT], /) -> Dispatcher[Any]: | ||
| obj = Dispatcher.__new__(Dispatcher) | ||
| obj._name = _method_name(tp) | ||
| getter = attrgetter(obj._name) | ||
| is_namespaced = tp.__expr_ir_config__.is_namespaced | ||
| obj._bind = _via_namespace(getter) if is_namespaced else getter | ||
| return obj | ||
|
|
||
| @staticmethod | ||
| def _no_dispatch(tp: type[ExprIRT], /) -> Dispatcher[ExprIRT]: | ||
| obj = Dispatcher.__new__(Dispatcher) | ||
| obj._name = tp.__name__ | ||
| obj._bind = obj._make_no_dispatch_error() | ||
| return obj | ||
|
|
||
| def _make_no_dispatch_error(self) -> Callable[[Any], Raiser]: | ||
| def _no_dispatch_error(node: Node, *_: Any) -> Never: | ||
| msg = ( | ||
| f"{self.name!r} should not appear at the compliant-level.\n\n" | ||
| f"Make sure to expand all expressions first, got:\n{node!r}" | ||
| ) | ||
| raise TypeError(msg) | ||
|
|
||
| def getter(_: Any, /) -> Raiser: | ||
| return _no_dispatch_error | ||
|
|
||
| return getter | ||
|
|
||
| def _not_implemented_error( | ||
| self, ctx: object, /, missing: Literal["compliant", "context"] | ||
| ) -> NotImplementedError: | ||
| if missing == "context": | ||
| msg = f"`{self.name}` is not yet implemented for {type(ctx).__name__!r}" | ||
| else: | ||
| msg = ( | ||
| f"`{self.name}` has not been implemented at the compliant-level.\n" | ||
| f"Hint: Try adding `CompliantExpr.{self.name}()` or `CompliantNamespace.{self.name}()`" | ||
| ) | ||
| return NotImplementedError(msg) | ||
|
|
||
|
|
||
| def _via_namespace(getter: Callable[[Any], Any], /) -> Callable[[Any], Any]: | ||
| def _(ctx: Any, /) -> Any: | ||
| return getter(ctx.__narwhals_namespace__()) | ||
|
|
||
| return _ | ||
|
|
||
|
|
||
| def _pascal_to_snake_case(s: str) -> str: | ||
| """Convert a PascalCase string to snake_case. | ||
|
|
||
| Adapted from https://github.com/pydantic/pydantic/blob/f7a9b73517afecf25bf898e3b5f591dffe669778/pydantic/alias_generators.py#L43-L62 | ||
| """ | ||
| # Handle the sequence of uppercase letters followed by a lowercase letter | ||
| snake = _PATTERN_UPPER_LOWER.sub(_re_repl_snake, s) | ||
| # Insert an underscore between a lowercase letter and an uppercase letter | ||
| return _PATTERN_LOWER_UPPER.sub(_re_repl_snake, snake).lower() | ||
|
|
||
|
|
||
| _PATTERN_UPPER_LOWER = re.compile(r"([A-Z]+)([A-Z][a-z])") | ||
| _PATTERN_LOWER_UPPER = re.compile(r"([a-z])([A-Z])") | ||
|
|
||
|
|
||
| def _re_repl_snake(match: re.Match[str], /) -> str: | ||
| return f"{match.group(1)}_{match.group(2)}" | ||
|
|
||
|
|
||
| def _method_name(tp: type[ExprIRT | FunctionT]) -> str: | ||
| config = tp.__expr_ir_config__ | ||
| name = config.override_name or _pascal_to_snake_case(tp.__name__) | ||
| return f"{ns}.{name}" if (ns := getattr(config, "accessor_name", "")) else name | ||
|
|
||
|
|
||
| def get_dispatch_name(expr: ExprIR, /) -> str: | ||
| """Return the synthesized method name for `expr`.""" | ||
| return ( | ||
| repr(expr.function) if is_function_expr(expr) else expr.__expr_ir_dispatch__.name | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Truly wild that I wrote ^^^ 😂
The new version looks a bit more understandable from the outside - but it now can correctly scope all the things which were
Incomplete.