Skip to content

Commit 87b8402

Browse files
committed
very wip selectors
None of this is functional yet - Most of the upstream stuff is written in `python` - The rust enum isn't very helpful for us - (`polars_plan::dsl::selector::Selector`) - It is opaque to the kind of selection
1 parent 309db1f commit 87b8402

File tree

6 files changed

+277
-13
lines changed

6 files changed

+277
-13
lines changed

narwhals/_plan/dummy.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from narwhals.utils import Version, _hasattr_static, flatten
2828

2929
if TYPE_CHECKING:
30-
from typing_extensions import Self
30+
from typing_extensions import Never, Self
3131

3232
from narwhals._plan.common import ExprIR, IntoExpr, IntoExprColumn, Seq, Udf
3333
from narwhals._plan.meta import ExprIRMetaNamespace
@@ -439,10 +439,77 @@ def meta(self) -> ExprIRMetaNamespace:
439439
return ExprIRMetaNamespace(self._ir)
440440

441441

442+
class DummySelector(DummyExpr):
443+
_ir: expr.SelectorIR
444+
445+
@classmethod
446+
def _from_ir(cls, ir: expr.SelectorIR, /) -> Self: # type: ignore[override]
447+
obj = cls.__new__(cls)
448+
obj._ir = ir
449+
return obj
450+
451+
def _to_expr(self) -> DummyExpr:
452+
return self._ir.to_narwhals(self.version)
453+
454+
# TODO @dangotbanned: Make a decision on selector root, binary op
455+
# Current typing warnings are accurate, this isn't valid yet
456+
def __or__(self, other: t.Any) -> Self | t.Any:
457+
if isinstance(other, type(self)):
458+
op = ops.Or()
459+
return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type]
460+
return self._to_expr() | other
461+
462+
def __and__(self, other: t.Any) -> Self | t.Any:
463+
if isinstance(other, type(self)):
464+
op = ops.And()
465+
return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type]
466+
return self._to_expr() & other
467+
468+
def __sub__(self, other: t.Any) -> Self | t.Any:
469+
if isinstance(other, type(self)):
470+
op = ops.Sub()
471+
return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type]
472+
return self._to_expr() - other
473+
474+
def __xor__(self, other: t.Any) -> Self | t.Any:
475+
if isinstance(other, type(self)):
476+
op = ops.ExclusiveOr()
477+
return self._from_ir(op.to_binary_selector(self._ir, other._ir)) # type: ignore[arg-type]
478+
return self._to_expr() ^ other
479+
480+
def __invert__(self) -> Never:
481+
raise NotImplementedError
482+
483+
def __add__(self, other: t.Any) -> DummyExpr: # type: ignore[override]
484+
if isinstance(other, type(self)):
485+
msg = "unsupported operand type(s) for op: ('Selector' + 'Selector')"
486+
raise TypeError(msg)
487+
return self._to_expr() + other # type: ignore[no-any-return]
488+
489+
def __rsub__(self, other: t.Any) -> Never:
490+
raise NotImplementedError
491+
492+
def __rand__(self, other: t.Any) -> Never:
493+
raise NotImplementedError
494+
495+
def __ror__(self, other: t.Any) -> Never:
496+
raise NotImplementedError
497+
498+
def __rxor__(self, other: t.Any) -> Never:
499+
raise NotImplementedError
500+
501+
def __radd__(self, other: t.Any) -> Never:
502+
raise NotImplementedError
503+
504+
442505
class DummyExprV1(DummyExpr):
443506
_version: t.ClassVar[Version] = Version.V1
444507

445508

509+
class DummySelectorV1(DummySelector):
510+
_version: t.ClassVar[Version] = Version.V1
511+
512+
446513
class DummyCompliantExpr:
447514
_ir: ExprIR
448515
_version: Version

narwhals/_plan/expr.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,25 @@
1313
import typing as t
1414

1515
from narwhals._plan.common import ExprIR
16-
from narwhals._plan.typing import FunctionT, LeftT, OperatorT, RightT, RollingT
16+
from narwhals._plan.typing import (
17+
FunctionT,
18+
LeftT,
19+
OperatorT,
20+
RightT,
21+
RollingT,
22+
SelectorOperatorT,
23+
)
24+
from narwhals.utils import Version
1725

1826
if t.TYPE_CHECKING:
1927
from typing_extensions import Self
2028

2129
from narwhals._plan.common import Seq
30+
from narwhals._plan.dummy import DummySelector
2231
from narwhals._plan.functions import MapBatches # noqa: F401
2332
from narwhals._plan.literal import LiteralValue
2433
from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions
34+
from narwhals._plan.selectors import Selector
2535
from narwhals._plan.window import Window
2636
from narwhals.dtypes import DType
2737

@@ -408,9 +418,32 @@ def __repr__(self) -> str:
408418
return "*"
409419

410420

411-
class Selector(ExprIR):
421+
class SelectorIR(ExprIR):
422+
"""Not sure on this separation.
423+
424+
- Need a cleaner way of including `BinarySelector`.
425+
- Like that there's easy access to operands
426+
- Dislike that it inherits node iteration, since upstream doesn't use it for selectors
427+
"""
428+
429+
__slots__ = ("selector",)
430+
431+
selector: Selector
412432
"""by_dtype, matches, numeric, boolean, string, categorical, datetime, all."""
413433

434+
def to_narwhals(self, version: Version = Version.MAIN) -> DummySelector:
435+
from narwhals._plan import dummy
436+
437+
if version is Version.MAIN:
438+
return dummy.DummySelector._from_ir(self)
439+
return dummy.DummySelectorV1._from_ir(self)
440+
441+
442+
class BinarySelector(
443+
BinaryExpr["SelectorIR", SelectorOperatorT, "SelectorIR"],
444+
t.Generic[SelectorOperatorT],
445+
): ...
446+
414447

415448
class Ternary(ExprIR):
416449
"""When-Then-Otherwise.

narwhals/_plan/meta.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _expr_output_name(ir: ExprIR) -> str | ComputeError:
114114
def _has_multiple_outputs(ir: ExprIR) -> bool:
115115
from narwhals._plan import expr
116116

117-
return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.Selector, expr.All))
117+
return isinstance(ir, (expr.Columns, expr.IndexColumns, expr.SelectorIR, expr.All))
118118

119119

120120
def _is_literal(ir: ExprIR, *, allow_aliasing: bool) -> bool:
@@ -145,7 +145,7 @@ def _is_column_selection(ir: ExprIR, *, allow_aliasing: bool) -> bool:
145145
expr.Exclude,
146146
expr.Nth,
147147
expr.IndexColumns,
148-
expr.Selector,
148+
expr.SelectorIR,
149149
expr.All,
150150
),
151151
):

narwhals/_plan/operators.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
if TYPE_CHECKING:
66
from typing_extensions import Self
77

8-
from narwhals._plan.expr import BinaryExpr
8+
from narwhals._plan.expr import BinaryExpr, BinarySelector, SelectorIR
99
from narwhals._plan.typing import LeftT, RightT
1010

1111
from narwhals._plan.common import Immutable
@@ -14,8 +14,8 @@
1414
class Operator(Immutable):
1515
def __repr__(self) -> str:
1616
tp = type(self)
17-
if tp is Operator:
18-
return "Operator"
17+
if tp in {Operator, SelectorOperator}:
18+
return tp.__name__
1919
m = {
2020
Eq: "==",
2121
NotEq: "!=",
@@ -43,6 +43,22 @@ def to_binary_expr(
4343
return BinaryExpr(left=left, op=self, right=right)
4444

4545

46+
class SelectorOperator(Operator):
47+
"""Operators that can *also* be used in selectors.
48+
49+
Remember that `Or` is named [`meta._selector_add`]!
50+
51+
[`meta._selector_add`]: https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/meta.rs#L113-L124
52+
"""
53+
54+
def to_binary_selector(
55+
self, left: SelectorIR, right: SelectorIR, /
56+
) -> BinarySelector[Self]:
57+
from narwhals._plan.expr import BinarySelector
58+
59+
return BinarySelector(left=left, op=self, right=right)
60+
61+
4662
class Eq(Operator): ...
4763

4864

@@ -64,7 +80,7 @@ class GtEq(Operator): ...
6480
class Add(Operator): ...
6581

6682

67-
class Sub(Operator): ...
83+
class Sub(SelectorOperator): ...
6884

6985

7086
class Multiply(Operator): ...
@@ -79,10 +95,10 @@ class FloorDivide(Operator): ...
7995
class Modulus(Operator): ...
8096

8197

82-
class And(Operator): ...
98+
class And(SelectorOperator): ...
8399

84100

85-
class Or(Operator): ...
101+
class Or(SelectorOperator): ...
86102

87103

88-
class ExclusiveOr(Operator): ...
104+
class ExclusiveOr(SelectorOperator): ...

narwhals/_plan/selectors.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Deviations from `polars`.
2+
3+
- A `Selector` corresponds to a `nw.selectors` function
4+
- Binary ops are represented as a subtype of `BinaryExpr`
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import re
10+
from typing import TYPE_CHECKING, Iterable
11+
12+
from narwhals._plan.common import Immutable, is_iterable_reject
13+
from narwhals.utils import _parse_time_unit_and_time_zone
14+
15+
if TYPE_CHECKING:
16+
from datetime import timezone
17+
from typing import Iterator, TypeVar
18+
19+
from narwhals._plan.dummy import DummySelector
20+
from narwhals._plan.expr import SelectorIR
21+
from narwhals.dtypes import DType
22+
from narwhals.typing import TimeUnit
23+
24+
T = TypeVar("T")
25+
26+
27+
class Selector(Immutable):
28+
def to_selector(self) -> SelectorIR:
29+
from narwhals._plan.expr import SelectorIR
30+
31+
return SelectorIR(selector=self)
32+
33+
34+
class All(Selector): ...
35+
36+
37+
class ByDType(Selector):
38+
__slots__ = ("dtypes",)
39+
40+
dtypes: frozenset[DType | type[DType]]
41+
42+
@staticmethod
43+
def from_dtypes(
44+
*dtypes: DType | type[DType] | Iterable[DType | type[DType]],
45+
) -> ByDType:
46+
return ByDType(dtypes=frozenset(_flatten_hash_safe(dtypes)))
47+
48+
49+
class Boolean(Selector): ...
50+
51+
52+
class Categorical(Selector): ...
53+
54+
55+
class Datetime(Selector):
56+
"""Should swallow the [`utils` functions].
57+
58+
Just re-wrapping them for now, since `CompliantSelectorNamespace` is still using them.
59+
60+
[`utils` functions]: https://github.com/narwhals-dev/narwhals/blob/6d524ba04fca6fe2d6d25bdd69f75fabf1d79039/narwhals/utils.py#L1565-L1596
61+
"""
62+
63+
__slots__ = ("time_units", "time_zones")
64+
65+
time_units: frozenset[TimeUnit]
66+
time_zones: frozenset[str | None]
67+
68+
@staticmethod
69+
def from_time_unit_and_time_zone(
70+
time_unit: TimeUnit | Iterable[TimeUnit] | None,
71+
time_zone: str | timezone | Iterable[str | timezone | None] | None,
72+
/,
73+
) -> Datetime:
74+
units, zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
75+
return Datetime(time_units=frozenset(units), time_zones=frozenset(zones))
76+
77+
78+
class Matches(Selector):
79+
__slots__ = ("pattern",)
80+
81+
pattern: re.Pattern[str]
82+
83+
@staticmethod
84+
def from_string(pattern: str, /) -> Matches:
85+
return Matches(pattern=re.compile(pattern))
86+
87+
88+
class Numeric(Selector): ...
89+
90+
91+
class String(Selector): ...
92+
93+
94+
def all() -> DummySelector:
95+
return All().to_selector().to_narwhals()
96+
97+
98+
def by_dtype(
99+
*dtypes: DType | type[DType] | Iterable[DType | type[DType]],
100+
) -> DummySelector:
101+
return ByDType.from_dtypes(*dtypes).to_selector().to_narwhals()
102+
103+
104+
def boolean() -> DummySelector:
105+
return Boolean().to_selector().to_narwhals()
106+
107+
108+
def categorical() -> DummySelector:
109+
return Categorical().to_selector().to_narwhals()
110+
111+
112+
def datetime(
113+
time_unit: TimeUnit | Iterable[TimeUnit] | None = None,
114+
time_zone: str | timezone | Iterable[str | timezone | None] | None = ("*", None),
115+
) -> DummySelector:
116+
return (
117+
Datetime.from_time_unit_and_time_zone(time_unit, time_zone)
118+
.to_selector()
119+
.to_narwhals()
120+
)
121+
122+
123+
def matches(pattern: str) -> DummySelector:
124+
return Matches.from_string(pattern).to_selector().to_narwhals()
125+
126+
127+
def numeric() -> DummySelector:
128+
return Numeric().to_selector().to_narwhals()
129+
130+
131+
def string() -> DummySelector:
132+
return String().to_selector().to_narwhals()
133+
134+
135+
def _flatten_hash_safe(iterable: Iterable[T | Iterable[T]], /) -> Iterator[T]:
136+
"""Fully unwrap all levels of nesting.
137+
138+
Aiming to reduce the chances of passing an unhashable argument.
139+
"""
140+
for element in iterable:
141+
if isinstance(element, Iterable) and not is_iterable_reject(element):
142+
yield from _flatten_hash_safe(element)
143+
else:
144+
yield element # type: ignore[misc]

narwhals/_plan/typing.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
from narwhals._plan.common import ExprIR, Function
1010
from narwhals._plan.functions import RollingWindow
1111

12-
__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT"]
12+
__all__ = ["FunctionT", "LeftT", "OperatorT", "RightT", "RollingT", "SelectorOperatorT"]
1313

1414

1515
FunctionT = TypeVar("FunctionT", bound="Function")
1616
RollingT = TypeVar("RollingT", bound="RollingWindow")
1717
LeftT = TypeVar("LeftT", bound="ExprIR", default="ExprIR")
1818
OperatorT = TypeVar("OperatorT", bound="ops.Operator", default="ops.Operator")
1919
RightT = TypeVar("RightT", bound="ExprIR", default="ExprIR")
20+
21+
SelectorOperatorT = TypeVar(
22+
"SelectorOperatorT", bound="ops.SelectorOperator", default="ops.SelectorOperator"
23+
)

0 commit comments

Comments
 (0)