Skip to content

Commit afa1c32

Browse files
committed
feat: Start adding IntoExpr parsing
- `lit` deals with the `Series` case - Very basic tests
1 parent ebd0542 commit afa1c32

File tree

5 files changed

+272
-0
lines changed

5 files changed

+272
-0
lines changed

narwhals/_plan/common.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import datetime as dt
4+
from decimal import Decimal
35
from typing import TYPE_CHECKING
46
from typing import TypeVar
57

@@ -12,12 +14,15 @@
1214
from typing_extensions import Never
1315
from typing_extensions import Self
1416
from typing_extensions import TypeAlias
17+
from typing_extensions import TypeIs
1518
from typing_extensions import dataclass_transform
1619

1720
from narwhals._plan.dummy import DummyCompliantExpr
1821
from narwhals._plan.dummy import DummyExpr
22+
from narwhals._plan.dummy import DummySeries
1923
from narwhals._plan.expr import FunctionExpr
2024
from narwhals._plan.options import FunctionOptions
25+
from narwhals.typing import NonNestedLiteral
2126

2227
else:
2328
# NOTE: This isn't important to the proposal, just wanted IDE support
@@ -58,6 +63,9 @@ def decorator(cls_or_fn: T) -> T:
5863
Udf: TypeAlias = "Callable[[Any], Any]"
5964
"""Placeholder for `map_batches(function=...)`."""
6065

66+
IntoExprColumn: TypeAlias = "DummyExpr | DummySeries | str"
67+
IntoExpr: TypeAlias = "NonNestedLiteral | IntoExprColumn"
68+
6169

6270
@dataclass_transform(kw_only_default=True, frozen_default=True)
6371
class Immutable:
@@ -162,3 +170,37 @@ def to_function_expr(self, *inputs: ExprIR) -> FunctionExpr[Self]:
162170
# Feel like it should be the union of `input` & `function`
163171
PLACEHOLDER = FunctionOptions.default() # noqa: N806
164172
return FunctionExpr(input=inputs, function=self, options=PLACEHOLDER)
173+
174+
175+
_NON_NESTED_LITERAL_TPS = (
176+
int,
177+
float,
178+
str,
179+
dt.date,
180+
dt.time,
181+
dt.timedelta,
182+
bytes,
183+
Decimal,
184+
)
185+
186+
187+
def is_non_nested_literal(obj: Any) -> TypeIs[NonNestedLiteral]:
188+
return obj is None or isinstance(obj, _NON_NESTED_LITERAL_TPS)
189+
190+
191+
def is_expr(obj: Any) -> TypeIs[DummyExpr]:
192+
from narwhals._plan.dummy import DummyExpr
193+
194+
return isinstance(obj, DummyExpr)
195+
196+
197+
def is_series(obj: Any) -> TypeIs[DummySeries]:
198+
from narwhals._plan.dummy import DummySeries
199+
200+
return isinstance(obj, DummySeries)
201+
202+
203+
def is_iterable_reject(obj: Any) -> TypeIs[str | bytes | DummySeries]:
204+
from narwhals._plan.dummy import DummySeries
205+
206+
return isinstance(obj, (str, bytes, DummySeries))

narwhals/_plan/demo.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
from narwhals._plan import aggregation as agg
77
from narwhals._plan import boolean
8+
from narwhals._plan import expr_parsing as parse
89
from narwhals._plan import functions as F # noqa: N812
10+
from narwhals._plan.common import ExprIR
11+
from narwhals._plan.common import IntoExpr
12+
from narwhals._plan.common import is_non_nested_literal
913
from narwhals._plan.dummy import DummySeries
1014
from narwhals._plan.expr import All
1115
from narwhals._plan.expr import Column
@@ -57,6 +61,9 @@ def lit(
5761
return SeriesLiteral(value=value).to_literal().to_narwhals()
5862
if dtype is None or not isinstance(dtype, DType):
5963
dtype = Version.MAIN.dtypes.Unknown()
64+
if not is_non_nested_literal(value):
65+
msg = f"{type(value).__name__!r} is not supported in `nw.lit`, got: {value!r}."
66+
raise TypeError(msg)
6067
return ScalarLiteral(value=value, dtype=dtype).to_literal().to_narwhals()
6168

6269

@@ -170,3 +177,9 @@ def ensure_orderable_rules(*exprs: DummyExpr) -> tuple[DummyExpr, ...]:
170177
if not _is_order_enforcing_previous(previous):
171178
raise _order_dependent_error(node)
172179
return exprs
180+
181+
182+
def select_context(
183+
*exprs: IntoExpr | t.Iterable[IntoExpr], **named_exprs: IntoExpr
184+
) -> tuple[ExprIR, ...]:
185+
return parse.parse_into_seq_of_expr_ir(*exprs, **named_exprs)

narwhals/_plan/expr_parsing.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
from __future__ import annotations
2+
3+
# ruff: noqa: A002
4+
from typing import TYPE_CHECKING
5+
from typing import Iterable
6+
from typing import Sequence
7+
from typing import TypeVar
8+
9+
from narwhals._plan.common import is_expr
10+
from narwhals._plan.common import is_iterable_reject
11+
from narwhals.dependencies import get_polars
12+
from narwhals.dependencies import is_pandas_dataframe
13+
from narwhals.dependencies import is_pandas_series
14+
from narwhals.exceptions import InvalidIntoExprError
15+
16+
if TYPE_CHECKING:
17+
from typing import Any
18+
from typing import Iterator
19+
20+
from typing_extensions import TypeAlias
21+
from typing_extensions import TypeIs
22+
23+
from narwhals._plan.common import ExprIR
24+
from narwhals._plan.common import IntoExpr
25+
from narwhals._plan.common import Seq
26+
from narwhals.dtypes import DType
27+
28+
T = TypeVar("T")
29+
30+
_RaisesInvalidIntoExprError: TypeAlias = "Any"
31+
"""
32+
Placeholder for multiple `Iterable[IntoExpr]`.
33+
34+
We only support cases `a`, `b`, but the typing for most contexts is more permissive:
35+
36+
>>> import polars as pl
37+
>>> df = pl.DataFrame({"one": ["A", "B", "A"], "two": [1, 2, 3], "three": [4, 5, 6]})
38+
>>> a = ("one", "two")
39+
>>> b = (["one", "two"],)
40+
>>>
41+
>>> c = ("one", ["two"])
42+
>>> d = (["one"], "two")
43+
>>> [df.select(*into) for into in (a, b, c, d)]
44+
[shape: (3, 2)
45+
┌─────┬─────┐
46+
│ one ┆ two │
47+
│ --- ┆ --- │
48+
│ str ┆ i64 │
49+
╞═════╪═════╡
50+
│ A ┆ 1 │
51+
│ B ┆ 2 │
52+
│ A ┆ 3 │
53+
└─────┴─────┘,
54+
shape: (3, 2)
55+
┌─────┬─────┐
56+
│ one ┆ two │
57+
│ --- ┆ --- │
58+
│ str ┆ i64 │
59+
╞═════╪═════╡
60+
│ A ┆ 1 │
61+
│ B ┆ 2 │
62+
│ A ┆ 3 │
63+
└─────┴─────┘,
64+
shape: (3, 2)
65+
┌─────┬───────────┐
66+
│ one ┆ literal │
67+
│ --- ┆ --- │
68+
│ str ┆ list[str] │
69+
╞═════╪═══════════╡
70+
│ A ┆ ["two"] │
71+
│ B ┆ ["two"] │
72+
│ A ┆ ["two"] │
73+
└─────┴───────────┘,
74+
shape: (3, 2)
75+
┌───────────┬─────┐
76+
│ literal ┆ two │
77+
│ --- ┆ --- │
78+
│ list[str] ┆ i64 │
79+
╞═══════════╪═════╡
80+
│ ["one"] ┆ 1 │
81+
│ ["one"] ┆ 2 │
82+
│ ["one"] ┆ 3 │
83+
└───────────┴─────┘]
84+
"""
85+
86+
87+
def parse_into_expr_ir(
88+
input: IntoExpr, *, str_as_lit: bool = False, dtype: DType | None = None
89+
) -> ExprIR:
90+
"""Parse a single input into an `ExprIR` node."""
91+
from narwhals._plan import demo as nwd
92+
93+
if is_expr(input):
94+
expr = input
95+
elif isinstance(input, str) and not str_as_lit:
96+
expr = nwd.col(input)
97+
else:
98+
expr = nwd.lit(input, dtype=dtype)
99+
return expr._ir
100+
101+
102+
def parse_into_seq_of_expr_ir(
103+
first_input: IntoExpr | Iterable[IntoExpr] = (),
104+
*more_inputs: IntoExpr | _RaisesInvalidIntoExprError,
105+
**named_inputs: IntoExpr,
106+
) -> Seq[ExprIR]:
107+
"""Parse variadic inputs into a flat sequence of `ExprIR` nodes."""
108+
return tuple(_parse_into_iter_expr_ir(first_input, *more_inputs, **named_inputs))
109+
110+
111+
def _parse_into_iter_expr_ir(
112+
first_input: IntoExpr | Iterable[IntoExpr],
113+
*more_inputs: IntoExpr,
114+
**named_inputs: IntoExpr,
115+
) -> Iterator[ExprIR]:
116+
if not _is_empty_sequence(first_input):
117+
# NOTE: These need to be separated to introduce an intersection type
118+
# Otherwise, `str | bytes` always passes through typing
119+
if _is_iterable(first_input) and not is_iterable_reject(first_input):
120+
if more_inputs:
121+
raise _invalid_into_expr_error(first_input, more_inputs, named_inputs)
122+
else:
123+
yield from _parse_positional_inputs(first_input)
124+
else:
125+
yield parse_into_expr_ir(first_input)
126+
else:
127+
# NOTE: Passthrough case for no inputs - but gets skipped when calling next
128+
yield from ()
129+
if more_inputs:
130+
yield from _parse_positional_inputs(more_inputs)
131+
if named_inputs:
132+
yield from _parse_named_inputs(named_inputs)
133+
134+
135+
def _parse_positional_inputs(inputs: Iterable[IntoExpr], /) -> Iterator[ExprIR]:
136+
for into in inputs:
137+
yield parse_into_expr_ir(into)
138+
139+
140+
def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR]:
141+
from narwhals._plan.expr import Alias
142+
143+
for name, input in named_inputs.items():
144+
yield Alias(expr=parse_into_expr_ir(input), name=name)
145+
146+
147+
def _is_iterable(obj: Iterable[T] | Any) -> TypeIs[Iterable[T]]:
148+
if is_pandas_dataframe(obj) or is_pandas_series(obj):
149+
msg = f"Expected Narwhals class or scalar, got: {type(obj)}. Perhaps you forgot a `nw.from_native` somewhere?"
150+
raise TypeError(msg)
151+
if _is_polars(obj):
152+
msg = (
153+
f"Expected Narwhals class or scalar, got: {type(obj)}.\n\n"
154+
"Hint: Perhaps you\n"
155+
"- forgot a `nw.from_native` somewhere?\n"
156+
"- used `pl.col` instead of `nw.col`?"
157+
)
158+
raise TypeError(msg)
159+
return isinstance(obj, Iterable)
160+
161+
162+
def _is_empty_sequence(obj: Any) -> bool:
163+
return isinstance(obj, Sequence) and not obj
164+
165+
166+
def _is_polars(obj: Any) -> bool:
167+
return (pl := get_polars()) is not None and isinstance(
168+
obj, (pl.Series, pl.Expr, pl.DataFrame, pl.LazyFrame)
169+
)
170+
171+
172+
def _invalid_into_expr_error(
173+
first_input: Any, more_inputs: Any, named_inputs: Any
174+
) -> InvalidIntoExprError:
175+
msg = (
176+
f"Passing both iterable and positional inputs is not supported.\n"
177+
f"Hint:\nInstead try collecting all arguments into a {type(first_input).__name__!r}\n"
178+
f"{first_input!r}\n{more_inputs!r}\n{named_inputs!r}"
179+
)
180+
return InvalidIntoExprError(msg)

tests/plan/__init__.py

Whitespace-only changes.

tests/plan/expr_parsing_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from typing import Iterable
5+
6+
import pytest
7+
8+
import narwhals as nw
9+
import narwhals._plan.demo as nwd
10+
from narwhals._plan.common import ExprIR
11+
12+
if TYPE_CHECKING:
13+
from narwhals._plan.common import IntoExpr
14+
from narwhals._plan.common import Seq
15+
16+
17+
@pytest.mark.parametrize(
18+
("exprs", "named_exprs"),
19+
[
20+
([nwd.col("a")], {}),
21+
(["a"], {}),
22+
([], {"a": "b"}),
23+
([], {"a": nwd.col("b")}),
24+
(["a", "b", nwd.col("c", "d", "e")], {"g": nwd.lit(1)}),
25+
([["a", "b", "c"]], {"q": nwd.lit(5, nw.Int8())}),
26+
(
27+
[[nwd.nth(1), nwd.nth(2, 3, 4)]],
28+
{"n": nwd.col("p").count(), "other n": nwd.len()},
29+
),
30+
],
31+
)
32+
def test_parsing(
33+
exprs: Seq[IntoExpr | Iterable[IntoExpr]], named_exprs: dict[str, IntoExpr]
34+
) -> None:
35+
assert all(
36+
isinstance(node, ExprIR) for node in nwd.select_context(*exprs, **named_exprs)
37+
)

0 commit comments

Comments
 (0)