Skip to content

Commit bde22ac

Browse files
authored
feat(expr-ir): Acero order_by, hashjoin , DataFrame.{filter,join}, Expr.is_{first,last}_distinct (#3173)
1 parent 2403f1b commit bde22ac

31 files changed

+1411
-140
lines changed

narwhals/_plan/_expr_ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,3 +304,7 @@ def is_column(self, *, allow_aliasing: bool = False) -> bool:
304304

305305
ir = self.expr
306306
return isinstance(ir, Column) and ((self.name == ir.name) or allow_aliasing)
307+
308+
309+
def named_ir(name: str, expr: ExprIRT, /) -> NamedIR[ExprIRT]:
310+
return NamedIR(expr=expr, name=name)

narwhals/_plan/_guards.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from narwhals._plan.compliant.series import CompliantSeries
1616
from narwhals._plan.expr import Expr
1717
from narwhals._plan.series import Series
18-
from narwhals._plan.typing import NativeSeriesT, Seq
18+
from narwhals._plan.typing import IntoExprColumn, NativeSeriesT, Seq
1919
from narwhals.typing import NonNestedLiteral
2020

2121
T = TypeVar("T")
@@ -67,6 +67,10 @@ def is_series(obj: Series[NativeSeriesT] | Any) -> TypeIs[Series[NativeSeriesT]]
6767
return isinstance(obj, _series().Series)
6868

6969

70+
def is_into_expr_column(obj: Any) -> TypeIs[IntoExprColumn]:
71+
return isinstance(obj, (str, _expr().Expr, _series().Series))
72+
73+
7074
def is_compliant_series(
7175
obj: CompliantSeries[NativeSeriesT] | Any,
7276
) -> TypeIs[CompliantSeries[NativeSeriesT]]:

narwhals/_plan/_parse.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
from itertools import chain
77
from typing import TYPE_CHECKING
88

9-
from narwhals._plan._guards import is_expr, is_iterable_reject
9+
from narwhals._plan._guards import is_expr, is_into_expr_column, is_iterable_reject
1010
from narwhals._plan.exceptions import (
1111
invalid_into_expr_error,
1212
is_iterable_pandas_error,
1313
is_iterable_polars_error,
1414
)
1515
from narwhals.dependencies import get_polars, is_pandas_dataframe, is_pandas_series
16+
from narwhals.exceptions import InvalidOperationError
1617

1718
if TYPE_CHECKING:
1819
from collections.abc import Iterator
@@ -22,7 +23,13 @@
2223
from typing_extensions import TypeAlias, TypeIs
2324

2425
from narwhals._plan.expressions import ExprIR
25-
from narwhals._plan.typing import IntoExpr, IntoExprColumn, OneOrIterable, Seq
26+
from narwhals._plan.typing import (
27+
IntoExpr,
28+
IntoExprColumn,
29+
OneOrIterable,
30+
PartialSeries,
31+
Seq,
32+
)
2633
from narwhals.typing import IntoDType
2734

2835
T = TypeVar("T")
@@ -85,15 +92,33 @@
8592

8693

8794
def parse_into_expr_ir(
88-
input: IntoExpr, *, str_as_lit: bool = False, dtype: IntoDType | None = None
95+
input: IntoExpr | list[Any],
96+
*,
97+
str_as_lit: bool = False,
98+
list_as_series: PartialSeries | None = None,
99+
dtype: IntoDType | None = None,
89100
) -> ExprIR:
90-
"""Parse a single input into an `ExprIR` node."""
101+
"""Parse a single input into an `ExprIR` node.
102+
103+
Arguments:
104+
input: The input to be parsed as an expression.
105+
str_as_lit: Interpret string input as a string literal. If set to `False` (default),
106+
strings are parsed as column names.
107+
list_as_series: Interpret list input as a Series literal, using the provided constructor.
108+
If set to `None` (default), lists will raise when passed to `lit`.
109+
dtype: If the input is expected to resolve to a literal with a known dtype, pass
110+
this to the `lit` constructor.
111+
"""
91112
from narwhals._plan import col, lit
92113

93114
if is_expr(input):
94115
expr = input
95116
elif isinstance(input, str) and not str_as_lit:
96117
expr = col(input)
118+
elif isinstance(input, list):
119+
if list_as_series is None:
120+
raise TypeError(input)
121+
expr = lit(list_as_series(input))
97122
else:
98123
expr = lit(input, dtype=dtype)
99124
return expr._ir
@@ -105,50 +130,90 @@ def parse_into_seq_of_expr_ir(
105130
**named_inputs: IntoExpr,
106131
) -> Seq[ExprIR]:
107132
"""Parse variadic inputs into a flat sequence of `ExprIR` nodes."""
108-
return tuple(_parse_into_iter_expr_ir(first_input, *more_inputs, **named_inputs))
133+
return tuple(
134+
_parse_into_iter_expr_ir(
135+
first_input, *more_inputs, _list_as_series=None, **named_inputs
136+
)
137+
)
109138

110139

111140
def parse_predicates_constraints_into_expr_ir(
112-
first_predicate: OneOrIterable[IntoExprColumn] = (),
113-
*more_predicates: IntoExprColumn | _RaisesInvalidIntoExprError,
141+
first_predicate: OneOrIterable[IntoExprColumn] | list[bool] = (),
142+
*more_predicates: IntoExprColumn | list[bool] | _RaisesInvalidIntoExprError,
143+
_list_as_series: PartialSeries | None = None,
114144
**constraints: IntoExpr,
115145
) -> ExprIR:
116146
"""Parse variadic predicates and constraints into an `ExprIR` node.
117147
118148
The result is an AND-reduction of all inputs.
119149
"""
120-
all_predicates = _parse_into_iter_expr_ir(first_predicate, *more_predicates)
150+
all_predicates = _parse_into_iter_expr_ir(
151+
first_predicate, *more_predicates, _list_as_series=_list_as_series
152+
)
121153
if constraints:
122154
chained = chain(all_predicates, _parse_constraints(constraints))
123155
return _combine_predicates(chained)
124156
return _combine_predicates(all_predicates)
125157

126158

159+
def parse_sort_by_into_seq_of_expr_ir(
160+
by: OneOrIterable[IntoExprColumn] = (), *more_by: IntoExprColumn
161+
) -> Seq[ExprIR]:
162+
"""Parse `DataFrame.sort` and `Expr.sort_by` keys into a flat sequence of `ExprIR` nodes."""
163+
return tuple(_parse_sort_by_into_iter_expr_ir(by, more_by))
164+
165+
166+
# TODO @dangotbanned: Review the rejection predicate
167+
# It doesn't cover all length-changing expressions, only aggregations/literals
168+
def _parse_sort_by_into_iter_expr_ir(
169+
by: OneOrIterable[IntoExprColumn], more_by: Iterable[IntoExprColumn]
170+
) -> Iterator[ExprIR]:
171+
for e in _parse_into_iter_expr_ir(by, *more_by):
172+
if e.is_scalar:
173+
msg = f"All expressions sort keys must preserve length, but got:\n{e!r}"
174+
raise InvalidOperationError(msg)
175+
yield e
176+
177+
127178
def _parse_into_iter_expr_ir(
128-
first_input: OneOrIterable[IntoExpr], *more_inputs: IntoExpr, **named_inputs: IntoExpr
179+
first_input: OneOrIterable[IntoExpr],
180+
*more_inputs: IntoExpr | list[Any],
181+
_list_as_series: PartialSeries | None = None,
182+
**named_inputs: IntoExpr,
129183
) -> Iterator[ExprIR]:
130184
if not _is_empty_sequence(first_input):
131185
# NOTE: These need to be separated to introduce an intersection type
132186
# Otherwise, `str | bytes` always passes through typing
133187
if _is_iterable(first_input) and not is_iterable_reject(first_input):
134-
if more_inputs:
188+
if more_inputs and (
189+
_list_as_series is None or not isinstance(first_input, list)
190+
):
135191
raise invalid_into_expr_error(first_input, more_inputs, named_inputs)
192+
# NOTE: Ensures `first_input = [False, True, True] -> lit(Series([False, True, True]))`
193+
elif (
194+
_list_as_series is not None
195+
and isinstance(first_input, list)
196+
and not is_into_expr_column(first_input[0])
197+
):
198+
yield parse_into_expr_ir(first_input, list_as_series=_list_as_series)
136199
else:
137-
yield from _parse_positional_inputs(first_input)
200+
yield from _parse_positional_inputs(first_input, _list_as_series)
138201
else:
139-
yield parse_into_expr_ir(first_input)
202+
yield parse_into_expr_ir(first_input, list_as_series=_list_as_series)
140203
else:
141204
# NOTE: Passthrough case for no inputs - but gets skipped when calling next
142205
yield from ()
143206
if more_inputs:
144-
yield from _parse_positional_inputs(more_inputs)
207+
yield from _parse_positional_inputs(more_inputs, _list_as_series)
145208
if named_inputs:
146209
yield from _parse_named_inputs(named_inputs)
147210

148211

149-
def _parse_positional_inputs(inputs: Iterable[IntoExpr], /) -> Iterator[ExprIR]:
212+
def _parse_positional_inputs(
213+
inputs: Iterable[IntoExpr | list[Any]], /, list_as_series: PartialSeries | None = None
214+
) -> Iterator[ExprIR]:
150215
for into in inputs:
151-
yield parse_into_expr_ir(into)
216+
yield parse_into_expr_ir(into, list_as_series=list_as_series)
152217

153218

154219
def _parse_named_inputs(named_inputs: dict[str, IntoExpr], /) -> Iterator[ExprIR]:

0 commit comments

Comments
 (0)