Skip to content

Commit 15c87ea

Browse files
committed
feat(expr-ir): Add meta.as_selector, parse_into_selector_ir
Still have some translations missing `by_index` will mean updating `matches_column` to *also* pass in the schema index
1 parent ab7330a commit 15c87ea

File tree

6 files changed

+80
-21
lines changed

6 files changed

+80
-21
lines changed

narwhals/_plan/_expr_ir.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from narwhals._plan.common import replace
99
from narwhals._plan.options import ExprIROptions
1010
from narwhals._plan.typing import ExprIRT
11+
from narwhals.exceptions import InvalidOperationError
1112
from narwhals.utils import Version
1213

1314
if TYPE_CHECKING:
@@ -59,6 +60,10 @@ def to_narwhals(self, version: Version = Version.MAIN) -> Expr:
5960
tp = expr.Expr if version is Version.MAIN else expr.ExprV1
6061
return tp._from_ir(self)
6162

63+
def to_selector_ir(self) -> SelectorIR:
64+
msg = f"cannot turn `{self!r}` into a selector"
65+
raise InvalidOperationError(msg)
66+
6267
@property
6368
def is_scalar(self) -> bool:
6469
return False
@@ -201,6 +206,9 @@ def matches_column(self, name: str, dtype: DType) -> bool:
201206
"""
202207
raise NotImplementedError(type(self))
203208

209+
def to_selector_ir(self) -> Self:
210+
return self
211+
204212

205213
class NamedIR(Immutable, Generic[ExprIRT]):
206214
"""Post-projection expansion wrapper for `ExprIR`.

narwhals/_plan/_guards.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
if TYPE_CHECKING:
1212
from typing_extensions import TypeIs
1313

14-
from narwhals._plan import expressions as ir
14+
from narwhals._plan import expr, expressions as ir
1515
from narwhals._plan.compliant.series import CompliantSeries
1616
from narwhals._plan.expr import Expr
1717
from narwhals._plan.series import Series
@@ -58,6 +58,10 @@ def is_expr(obj: Any) -> TypeIs[Expr]:
5858
return isinstance(obj, _expr().Expr)
5959

6060

61+
def is_selector(obj: Any) -> TypeIs[expr.Selector]:
62+
return isinstance(obj, _expr().Selector)
63+
64+
6165
def is_column(obj: Any) -> TypeIs[Expr]:
6266
"""Indicate if the given object is a basic/unaliased column."""
6367
return is_expr(obj) and obj.meta.is_column()

narwhals/_plan/_parse.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66
from itertools import chain
77
from typing import TYPE_CHECKING
88

9-
from narwhals._plan._guards import is_expr, is_into_expr_column, is_iterable_reject
9+
from narwhals._plan._guards import (
10+
is_expr,
11+
is_into_expr_column,
12+
is_iterable_reject,
13+
is_selector,
14+
)
1015
from narwhals._plan.exceptions import (
1116
invalid_into_expr_error,
1217
is_iterable_pandas_error,
1318
is_iterable_polars_error,
1419
)
20+
from narwhals._utils import qualified_type_name
1521
from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series
1622
from narwhals.exceptions import InvalidOperationError
1723

@@ -22,8 +28,10 @@
2228
import polars as pl
2329
from typing_extensions import TypeAlias, TypeIs
2430

25-
from narwhals._plan.expressions import ExprIR
31+
from narwhals._plan.expr import Expr
32+
from narwhals._plan.expressions import ExprIR, SelectorIR
2633
from narwhals._plan.typing import (
34+
ColumnNameOrSelector,
2735
IntoExpr,
2836
IntoExprColumn,
2937
OneOrIterable,
@@ -124,6 +132,21 @@ def parse_into_expr_ir(
124132
return expr._ir
125133

126134

135+
def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR:
136+
if is_selector(input):
137+
selector = input
138+
elif isinstance(input, str):
139+
from narwhals._plan import selectors as cs
140+
141+
selector = cs.by_name(input)
142+
elif is_expr(input):
143+
selector = input.meta.as_selector()
144+
else:
145+
msg = f"cannot turn {qualified_type_name(input)!r} into selector"
146+
raise TypeError(msg)
147+
return selector._ir
148+
149+
127150
def parse_into_seq_of_expr_ir(
128151
first_input: OneOrIterable[IntoExpr] = (),
129152
*more_inputs: IntoExpr | _RaisesInvalidIntoExprError,

narwhals/_plan/expressions/expr.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from narwhals._plan._expr_ir import ExprIR, SelectorIR
1010
from narwhals._plan.common import flatten_hash_safe
1111
from narwhals._plan.exceptions import function_expr_invalid_operation_error
12+
from narwhals._plan.expressions import selectors as cs
1213
from narwhals._plan.options import ExprIROptions
1314
from narwhals._plan.typing import (
1415
FunctionT_co,
@@ -32,7 +33,6 @@
3233
from narwhals._plan.compliant.typing import Ctx, FrameT_contra, R_co
3334
from narwhals._plan.expressions.functions import MapBatches # noqa: F401
3435
from narwhals._plan.expressions.literal import LiteralValue
35-
from narwhals._plan.expressions.selectors import Selector
3636
from narwhals._plan.expressions.window import Window
3737
from narwhals._plan.options import FunctionOptions, SortMultipleOptions, SortOptions
3838
from narwhals.dtypes import DType
@@ -100,6 +100,9 @@ class Column(ExprIR, config=ExprIROptions.namespaced("col")):
100100
def __repr__(self) -> str:
101101
return f"col({self.name!r})"
102102

103+
def to_selector_ir(self) -> RootSelector:
104+
return cs.ByName.from_name(self.name).to_selector_ir()
105+
103106

104107
class _ColumnSelection(ExprIR, config=ExprIROptions.no_dispatch()):
105108
"""Nodes which can resolve to `Column`(s) with a `Schema`."""
@@ -112,7 +115,11 @@ class Columns(_ColumnSelection):
112115
def __repr__(self) -> str:
113116
return f"cols({list(self.names)!r})"
114117

118+
def to_selector_ir(self) -> RootSelector:
119+
return cs.ByName.from_names(*self.names).to_selector_ir()
120+
115121

122+
# TODO @dangotbanned: Add `selectors.by_index`
116123
class Nth(_ColumnSelection):
117124
__slots__ = ("index",)
118125
index: int
@@ -121,6 +128,7 @@ def __repr__(self) -> str:
121128
return f"nth({self.index})"
122129

123130

131+
# TODO @dangotbanned: Add `selectors.by_index`
124132
class IndexColumns(_ColumnSelection):
125133
__slots__ = ("indices",)
126134
indices: Seq[int]
@@ -133,7 +141,11 @@ class All(_ColumnSelection):
133141
def __repr__(self) -> str:
134142
return "all()"
135143

144+
def to_selector_ir(self) -> RootSelector:
145+
return cs.All().to_selector_ir()
146+
136147

148+
# TODO @dangotbanned: Add `selectors.exclude`
137149
class Exclude(_ColumnSelection, child=("expr",)):
138150
__slots__ = ("expr", "names")
139151
expr: ExprIR
@@ -450,7 +462,7 @@ class RootSelector(SelectorIR):
450462
"""A single selector expression."""
451463

452464
__slots__ = ("selector",)
453-
selector: Selector
465+
selector: cs.Selector
454466

455467
def __repr__(self) -> str:
456468
return f"{self.selector!r}"

narwhals/_plan/expressions/selectors.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939

4040
class Selector(Immutable):
41-
def to_selector(self) -> RootSelector:
41+
def to_selector_ir(self) -> RootSelector:
4242
from narwhals._plan.expressions.expr import RootSelector
4343

4444
return RootSelector(selector=self)
@@ -248,14 +248,14 @@ def matches_column(self, name: str, dtype: DType) -> bool:
248248

249249

250250
def all() -> expr.Selector:
251-
return All().to_selector().to_narwhals()
251+
return All().to_selector_ir().to_narwhals()
252252

253253

254254
def array(
255255
inner: expr.Selector | None = None, *, size: int | None = None
256256
) -> expr.Selector:
257257
s_ir = inner._ir if inner is not None else None
258-
return Array(inner=s_ir, size=size).to_selector().to_narwhals()
258+
return Array(inner=s_ir, size=size).to_selector_ir().to_narwhals()
259259

260260

261261
def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selector:
@@ -267,15 +267,15 @@ def by_name(*names: OneOrIterable[str]) -> expr.Selector:
267267
sel = ByName.from_name(names[0])
268268
else:
269269
sel = ByName.from_names(*names)
270-
return sel.to_selector().to_narwhals()
270+
return sel.to_selector_ir().to_narwhals()
271271

272272

273273
def boolean() -> expr.Selector:
274-
return Boolean().to_selector().to_narwhals()
274+
return Boolean().to_selector_ir().to_narwhals()
275275

276276

277277
def categorical() -> expr.Selector:
278-
return Categorical().to_selector().to_narwhals()
278+
return Categorical().to_selector_ir().to_narwhals()
279279

280280

281281
def datetime(
@@ -284,38 +284,38 @@ def datetime(
284284
) -> expr.Selector:
285285
return (
286286
Datetime.from_time_unit_and_time_zone(time_unit, time_zone)
287-
.to_selector()
287+
.to_selector_ir()
288288
.to_narwhals()
289289
)
290290

291291

292292
def list(inner: expr.Selector | None = None) -> expr.Selector:
293293
s_ir = inner._ir if inner is not None else None
294-
return List(inner=s_ir).to_selector().to_narwhals()
294+
return List(inner=s_ir).to_selector_ir().to_narwhals()
295295

296296

297297
def duration(time_unit: OneOrIterable[TimeUnit] | None = None) -> expr.Selector:
298-
return Duration.from_time_unit(time_unit).to_selector().to_narwhals()
298+
return Duration.from_time_unit(time_unit).to_selector_ir().to_narwhals()
299299

300300

301301
def enum() -> expr.Selector:
302-
return Enum().to_selector().to_narwhals()
302+
return Enum().to_selector_ir().to_narwhals()
303303

304304

305305
def matches(pattern: str) -> expr.Selector:
306-
return Matches.from_string(pattern).to_selector().to_narwhals()
306+
return Matches.from_string(pattern).to_selector_ir().to_narwhals()
307307

308308

309309
def numeric() -> expr.Selector:
310-
return Numeric().to_selector().to_narwhals()
310+
return Numeric().to_selector_ir().to_narwhals()
311311

312312

313313
def string() -> expr.Selector:
314-
return String().to_selector().to_narwhals()
314+
return String().to_selector_ir().to_narwhals()
315315

316316

317317
def struct() -> expr.Selector:
318-
return Struct().to_selector().to_narwhals()
318+
return Struct().to_selector_ir().to_narwhals()
319319

320320

321321
_HASH_SENSITIVE_TO_SELECTOR: Mapping[type[DType], Callable[[], expr.Selector]] = {
@@ -343,7 +343,7 @@ def _from_dtypes(*by_dtypes: OneOrIterable[DType | type[DType]]) -> expr.Selecto
343343
else:
344344
dtypes.append(dtype) # type: ignore[arg-type]
345345
if dtypes:
346-
dtype_selector = ByDType(dtypes=frozenset(dtypes)).to_selector().to_narwhals()
346+
dtype_selector = ByDType(dtypes=frozenset(dtypes)).to_selector_ir().to_narwhals()
347347
selectors.appendleft(dtype_selector)
348348
it = iter(selectors)
349349
return reduce(operator.or_, it, next(it))

narwhals/_plan/meta.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
from narwhals._plan._guards import is_literal
1515
from narwhals._plan.expressions.literal import is_literal_scalar
1616
from narwhals._plan.expressions.namespace import IRNamespace
17-
from narwhals.exceptions import ComputeError
17+
from narwhals.exceptions import ComputeError, InvalidOperationError
1818
from narwhals.utils import Version
1919

2020
if TYPE_CHECKING:
2121
from collections.abc import Iterable, Iterator
2222

23+
from narwhals._plan import expr
24+
2325

2426
class MetaNamespace(IRNamespace):
2527
"""Methods to modify and traverse existing expressions."""
@@ -75,6 +77,16 @@ def root_names(self) -> list[str]:
7577
"""Get the root column names."""
7678
return list(_expr_to_leaf_column_names_iter(self._ir))
7779

80+
def as_selector(self) -> expr.Selector:
81+
"""Try to turn this expression into a selector.
82+
83+
Raises if the underlying expressions is not a column or selector.
84+
"""
85+
if not self.is_column_selection():
86+
msg = f"cannot turn `{self._ir!r}` into a selector"
87+
raise InvalidOperationError(msg)
88+
return self._ir.to_selector_ir().to_narwhals()
89+
7890

7991
def _expr_to_leaf_column_names_iter(expr: ir.ExprIR, /) -> Iterator[str]:
8092
for e in _expr_to_leaf_column_exprs_iter(expr):

0 commit comments

Comments
 (0)