66from itertools import chain
77from typing import TYPE_CHECKING
88
9- from narwhals ._plan ._guards import is_expr , is_iterable_reject
9+ from narwhals ._plan ._guards import is_expr , is_into_expr_column , is_iterable_reject
1010from narwhals ._plan .exceptions import (
1111 invalid_into_expr_error ,
1212 is_iterable_pandas_error ,
1313 is_iterable_polars_error ,
1414)
1515from narwhals .dependencies import get_polars , is_pandas_dataframe , is_pandas_series
16+ from narwhals .exceptions import InvalidOperationError
1617
1718if TYPE_CHECKING :
1819 from collections .abc import Iterator
2223 from typing_extensions import TypeAlias , TypeIs
2324
2425 from narwhals ._plan .expressions import ExprIR
25- from narwhals ._plan .typing import IntoExpr , IntoExprColumn , OneOrIterable , Seq
26+ from narwhals ._plan .typing import (
27+ IntoExpr ,
28+ IntoExprColumn ,
29+ OneOrIterable ,
30+ PartialSeries ,
31+ Seq ,
32+ )
2633 from narwhals .typing import IntoDType
2734
2835 T = TypeVar ("T" )
8592
8693
8794def parse_into_expr_ir (
88- input : IntoExpr , * , str_as_lit : bool = False , dtype : IntoDType | None = None
95+ input : IntoExpr | list [Any ],
96+ * ,
97+ str_as_lit : bool = False ,
98+ list_as_series : PartialSeries | None = None ,
99+ dtype : IntoDType | None = None ,
89100) -> ExprIR :
90- """Parse a single input into an `ExprIR` node."""
101+ """Parse a single input into an `ExprIR` node.
102+
103+ Arguments:
104+ input: The input to be parsed as an expression.
105+ str_as_lit: Interpret string input as a string literal. If set to `False` (default),
106+ strings are parsed as column names.
107+ list_as_series: Interpret list input as a Series literal, using the provided constructor.
108+ If set to `None` (default), lists will raise when passed to `lit`.
109+ dtype: If the input is expected to resolve to a literal with a known dtype, pass
110+ this to the `lit` constructor.
111+ """
91112 from narwhals ._plan import col , lit
92113
93114 if is_expr (input ):
94115 expr = input
95116 elif isinstance (input , str ) and not str_as_lit :
96117 expr = col (input )
118+ elif isinstance (input , list ):
119+ if list_as_series is None :
120+ raise TypeError (input )
121+ expr = lit (list_as_series (input ))
97122 else :
98123 expr = lit (input , dtype = dtype )
99124 return expr ._ir
@@ -105,50 +130,90 @@ def parse_into_seq_of_expr_ir(
105130 ** named_inputs : IntoExpr ,
106131) -> Seq [ExprIR ]:
107132 """Parse variadic inputs into a flat sequence of `ExprIR` nodes."""
108- return tuple (_parse_into_iter_expr_ir (first_input , * more_inputs , ** named_inputs ))
133+ return tuple (
134+ _parse_into_iter_expr_ir (
135+ first_input , * more_inputs , _list_as_series = None , ** named_inputs
136+ )
137+ )
109138
110139
111140def parse_predicates_constraints_into_expr_ir (
112- first_predicate : OneOrIterable [IntoExprColumn ] = (),
113- * more_predicates : IntoExprColumn | _RaisesInvalidIntoExprError ,
141+ first_predicate : OneOrIterable [IntoExprColumn ] | list [bool ] = (),
142+ * more_predicates : IntoExprColumn | list [bool ] | _RaisesInvalidIntoExprError ,
143+ _list_as_series : PartialSeries | None = None ,
114144 ** constraints : IntoExpr ,
115145) -> ExprIR :
116146 """Parse variadic predicates and constraints into an `ExprIR` node.
117147
118148 The result is an AND-reduction of all inputs.
119149 """
120- all_predicates = _parse_into_iter_expr_ir (first_predicate , * more_predicates )
150+ all_predicates = _parse_into_iter_expr_ir (
151+ first_predicate , * more_predicates , _list_as_series = _list_as_series
152+ )
121153 if constraints :
122154 chained = chain (all_predicates , _parse_constraints (constraints ))
123155 return _combine_predicates (chained )
124156 return _combine_predicates (all_predicates )
125157
126158
159+ def parse_sort_by_into_seq_of_expr_ir (
160+ by : OneOrIterable [IntoExprColumn ] = (), * more_by : IntoExprColumn
161+ ) -> Seq [ExprIR ]:
162+ """Parse `DataFrame.sort` and `Expr.sort_by` keys into a flat sequence of `ExprIR` nodes."""
163+ return tuple (_parse_sort_by_into_iter_expr_ir (by , more_by ))
164+
165+
166+ # TODO @dangotbanned: Review the rejection predicate
167+ # It doesn't cover all length-changing expressions, only aggregations/literals
168+ def _parse_sort_by_into_iter_expr_ir (
169+ by : OneOrIterable [IntoExprColumn ], more_by : Iterable [IntoExprColumn ]
170+ ) -> Iterator [ExprIR ]:
171+ for e in _parse_into_iter_expr_ir (by , * more_by ):
172+ if e .is_scalar :
173+ msg = f"All expressions sort keys must preserve length, but got:\n { e !r} "
174+ raise InvalidOperationError (msg )
175+ yield e
176+
177+
127178def _parse_into_iter_expr_ir (
128- first_input : OneOrIterable [IntoExpr ], * more_inputs : IntoExpr , ** named_inputs : IntoExpr
179+ first_input : OneOrIterable [IntoExpr ],
180+ * more_inputs : IntoExpr | list [Any ],
181+ _list_as_series : PartialSeries | None = None ,
182+ ** named_inputs : IntoExpr ,
129183) -> Iterator [ExprIR ]:
130184 if not _is_empty_sequence (first_input ):
131185 # NOTE: These need to be separated to introduce an intersection type
132186 # Otherwise, `str | bytes` always passes through typing
133187 if _is_iterable (first_input ) and not is_iterable_reject (first_input ):
134- if more_inputs :
188+ if more_inputs and (
189+ _list_as_series is None or not isinstance (first_input , list )
190+ ):
135191 raise invalid_into_expr_error (first_input , more_inputs , named_inputs )
192+ # NOTE: Ensures `first_input = [False, True, True] -> lit(Series([False, True, True]))`
193+ elif (
194+ _list_as_series is not None
195+ and isinstance (first_input , list )
196+ and not is_into_expr_column (first_input [0 ])
197+ ):
198+ yield parse_into_expr_ir (first_input , list_as_series = _list_as_series )
136199 else :
137- yield from _parse_positional_inputs (first_input )
200+ yield from _parse_positional_inputs (first_input , _list_as_series )
138201 else :
139- yield parse_into_expr_ir (first_input )
202+ yield parse_into_expr_ir (first_input , list_as_series = _list_as_series )
140203 else :
141204 # NOTE: Passthrough case for no inputs - but gets skipped when calling next
142205 yield from ()
143206 if more_inputs :
144- yield from _parse_positional_inputs (more_inputs )
207+ yield from _parse_positional_inputs (more_inputs , _list_as_series )
145208 if named_inputs :
146209 yield from _parse_named_inputs (named_inputs )
147210
148211
149- def _parse_positional_inputs (inputs : Iterable [IntoExpr ], / ) -> Iterator [ExprIR ]:
212+ def _parse_positional_inputs (
213+ inputs : Iterable [IntoExpr | list [Any ]], / , list_as_series : PartialSeries | None = None
214+ ) -> Iterator [ExprIR ]:
150215 for into in inputs :
151- yield parse_into_expr_ir (into )
216+ yield parse_into_expr_ir (into , list_as_series = list_as_series )
152217
153218
154219def _parse_named_inputs (named_inputs : dict [str , IntoExpr ], / ) -> Iterator [ExprIR ]:
0 commit comments