Skip to content

Commit 562e375

Browse files
authored
Revert "chore: simplify parsing expressions in group-by" (#3109)
Revert "chore: simplify parsing expressions in group-by (#3106)" This reverts commit 538973e.
1 parent 2ee5296 commit 562e375

File tree

8 files changed

+99
-86
lines changed

8 files changed

+99
-86
lines changed

narwhals/_compliant/series.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from typing_extensions import NotRequired, Self, TypedDict
3838

3939
from narwhals._compliant.dataframe import CompliantDataFrame
40-
from narwhals._compliant.expr import CompliantExpr, EagerExpr
4140
from narwhals._compliant.namespace import EagerNamespace
4241
from narwhals._utils import Implementation, Version, _LimitedContext
4342
from narwhals.dtypes import DType
@@ -97,8 +96,6 @@ def to_narwhals(self) -> Series[NativeSeriesT]:
9796
def _with_native(self, series: Any) -> Self: ...
9897
def _with_version(self, version: Version) -> Self: ...
9998

100-
def _to_expr(self) -> CompliantExpr[Any, Self]: ...
101-
10299
# NOTE: `polars`
103100
@property
104101
def dtype(self) -> DType: ...
@@ -246,9 +243,6 @@ def __narwhals_namespace__(
246243
self,
247244
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...
248245

249-
def _to_expr(self) -> EagerExpr[Any, Any]:
250-
return self.__narwhals_namespace__()._expr._from_series(self) # type: ignore[no-any-return]
251-
252246
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
253247
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
254248
def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:

narwhals/_expression_parsing.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from __future__ import annotations
66

77
from enum import Enum, auto
8+
from itertools import chain
89
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
910

1011
from narwhals._utils import is_compliant_expr, zip_strict
11-
from narwhals.dependencies import is_narwhals_series, is_numpy_array
12+
from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d
1213
from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError
1314

1415
if TYPE_CHECKING:
@@ -45,6 +46,13 @@ def is_series(obj: Any) -> TypeIs[Series[Any]]:
4546
return isinstance(obj, Series)
4647

4748

49+
def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]:
50+
from narwhals.expr import Expr
51+
from narwhals.series import Series
52+
53+
return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj)
54+
55+
4856
def combine_evaluate_output_names(
4957
*exprs: CompliantExpr[CompliantFrameT, Any],
5058
) -> EvalNames[CompliantFrameT]:
@@ -576,6 +584,14 @@ def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> No
576584
raise InvalidOperationError(msg)
577585

578586

587+
def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
588+
# Raise if any argument in `args` isn't an aggregation or literal.
589+
# For Series input, we don't raise (yet), we let such checks happen later,
590+
# as this function works lazily and so can't evaluate lengths.
591+
exprs = chain(args, kwargs.values())
592+
return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs)
593+
594+
579595
def apply_n_ary_operation(
580596
plx: CompliantNamespaceAny,
581597
n_ary_function: Callable[..., CompliantExprAny],

narwhals/_polars/series.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import polars as pl
66

7-
from narwhals._polars.expr import PolarsExpr
87
from narwhals._polars.utils import (
98
BACKEND_VERSION,
109
SERIES_ACCEPTS_PD_INDEX,
@@ -151,10 +150,6 @@ def __init__(self, series: pl.Series, *, version: Version) -> None:
151150
self._native_series = series
152151
self._version = version
153152

154-
def _to_expr(self) -> PolarsExpr:
155-
# Polars can treat Series as Expr, so just pass down `self.native`.
156-
return PolarsExpr(self.native, version=self._version) # type: ignore[arg-type]
157-
158153
@property
159154
def _backend_version(self) -> tuple[int, ...]:
160155
return self._implementation._backend_version()

narwhals/dataframe.py

Lines changed: 50 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
from narwhals._expression_parsing import (
2020
ExprKind,
2121
check_expressions_preserve_length,
22-
is_expr,
22+
is_into_expr_eager,
2323
is_scalar_like,
24-
is_series,
2524
)
2625
from narwhals._typing import Arrow, Pandas, _LazyAllowedImpl, _LazyFrameCollectImpl
2726
from narwhals._utils import (
@@ -44,14 +43,14 @@
4443
supports_arrow_c_stream,
4544
zip_strict,
4645
)
47-
from narwhals.dependencies import is_numpy_array_1d, is_numpy_array_2d, is_pyarrow_table
46+
from narwhals.dependencies import is_numpy_array_2d, is_pyarrow_table
4847
from narwhals.exceptions import (
4948
ColumnNotFoundError,
5049
InvalidIntoExprError,
5150
InvalidOperationError,
5251
PerformanceWarning,
5352
)
54-
from narwhals.functions import _from_dict_no_backend, _is_into_schema, col, new_series
53+
from narwhals.functions import _from_dict_no_backend, _is_into_schema
5554
from narwhals.schema import Schema
5655
from narwhals.series import Series
5756
from narwhals.translate import to_native
@@ -68,10 +67,9 @@
6867
from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias
6968

7069
from narwhals._compliant import CompliantDataFrame, CompliantLazyFrame
71-
from narwhals._compliant.typing import CompliantExprAny
70+
from narwhals._compliant.typing import CompliantExprAny, EagerNamespaceAny
7271
from narwhals._translate import IntoArrowTable
7372
from narwhals._typing import EagerAllowed, IntoBackend, LazyAllowed, Polars
74-
from narwhals.expr import Expr
7573
from narwhals.group_by import GroupBy, LazyGroupBy
7674
from narwhals.typing import (
7775
AsofJoinStrategy,
@@ -89,7 +87,6 @@
8987
SingleIndexSelector,
9088
SizeUnit,
9189
UniqueKeepStrategy,
92-
_1DArray,
9390
_2DArray,
9491
)
9592

@@ -151,21 +148,18 @@ def _flatten_and_extract(
151148
# NOTE: Strings are interpreted as column names.
152149
out_exprs = []
153150
out_kinds = []
154-
ns = self.__narwhals_namespace__()
155-
all_exprs = chain(
156-
(self._parse_into_expr(x) for x in flatten(exprs)),
157-
(
158-
self._parse_into_expr(expr).alias(alias)
159-
for alias, expr in named_exprs.items()
160-
),
161-
)
162-
for expr in all_exprs:
163-
out_exprs.append(expr._to_compliant_expr(ns))
164-
out_kinds.append(ExprKind.from_expr(expr))
151+
for expr in flatten(exprs):
152+
compliant_expr = self._extract_compliant(expr)
153+
out_exprs.append(compliant_expr)
154+
out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False))
155+
for alias, expr in named_exprs.items():
156+
compliant_expr = self._extract_compliant(expr).alias(alias)
157+
out_exprs.append(compliant_expr)
158+
out_kinds.append(ExprKind.from_into_expr(expr, str_as_lit=False))
165159
return out_exprs, out_kinds
166160

167161
@abstractmethod
168-
def _parse_into_expr(self, arg: Any) -> Expr:
162+
def _extract_compliant(self, arg: Any) -> Any:
169163
raise NotImplementedError
170164

171165
def _extract_compliant_frame(self, other: Self | Any, /) -> Any:
@@ -482,15 +476,10 @@ class DataFrame(BaseFrame[DataFrameT]):
482476
def _compliant(self) -> CompliantDataFrame[Any, Any, DataFrameT, Self]:
483477
return self._compliant_frame
484478

485-
def _parse_into_expr(self, arg: Expr | Series[Any] | _1DArray | str) -> Expr:
486-
if isinstance(arg, str):
487-
return col(arg)
488-
if is_numpy_array_1d(arg):
489-
return new_series("", arg, backend=self.implementation)._to_expr()
490-
if is_series(arg):
491-
return arg._to_expr()
492-
if is_expr(arg):
493-
return arg
479+
def _extract_compliant(self, arg: Any) -> Any:
480+
if is_into_expr_eager(arg):
481+
plx: EagerNamespaceAny = self.__narwhals_namespace__()
482+
return plx.parse_into_expr(arg, str_as_lit=False)
494483
raise InvalidIntoExprError.from_invalid_type(type(arg))
495484

496485
@property
@@ -2298,34 +2287,39 @@ class LazyFrame(BaseFrame[LazyFrameT]):
22982287
def _compliant(self) -> CompliantLazyFrame[Any, LazyFrameT, Self]:
22992288
return self._compliant_frame
23002289

2301-
def _parse_into_expr(self, arg: Expr | str) -> Expr:
2302-
if isinstance(arg, str):
2303-
return col(arg)
2304-
if is_expr(arg):
2305-
if arg._metadata.n_orderable_ops:
2306-
msg = (
2307-
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
2308-
"Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n"
2309-
"For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n"
2310-
"`'date'` which orders your data, then replace:\n\n"
2311-
" nw.col('price').cum_sum()\n\n"
2312-
" with:\n\n"
2313-
" nw.col('price').cum_sum().over(order_by='date')\n"
2314-
" ^^^^^^^^^^^^^^^^^^^^^^\n\n"
2315-
"See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/."
2316-
)
2317-
raise InvalidOperationError(msg)
2318-
if arg._metadata.is_filtration:
2319-
msg = (
2320-
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
2321-
"followed by an aggregation.\n\n"
2322-
"Hints:\n"
2323-
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
2324-
"- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n"
2325-
" use `lf.select(nw.col('a').drop_nulls().sum())\n"
2326-
)
2327-
raise InvalidOperationError(msg)
2328-
return arg
2290+
def _extract_compliant(self, arg: Any) -> Any:
2291+
from narwhals.expr import Expr
2292+
from narwhals.series import Series
2293+
2294+
if isinstance(arg, Series): # pragma: no cover
2295+
msg = "Binary operations between Series and LazyFrame are not supported."
2296+
raise TypeError(msg)
2297+
if isinstance(arg, (Expr, str)):
2298+
if isinstance(arg, Expr):
2299+
if arg._metadata.n_orderable_ops:
2300+
msg = (
2301+
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
2302+
"Hint: To make the expression valid, use `.over` with `order_by` specified.\n\n"
2303+
"For example, if you wrote `nw.col('price').cum_sum()` and you have a column\n"
2304+
"`'date'` which orders your data, then replace:\n\n"
2305+
" nw.col('price').cum_sum()\n\n"
2306+
" with:\n\n"
2307+
" nw.col('price').cum_sum().over(order_by='date')\n"
2308+
" ^^^^^^^^^^^^^^^^^^^^^^\n\n"
2309+
"See https://narwhals-dev.github.io/narwhals/concepts/order_dependence/."
2310+
)
2311+
raise InvalidOperationError(msg)
2312+
if arg._metadata.is_filtration:
2313+
msg = (
2314+
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
2315+
"followed by an aggregation.\n\n"
2316+
"Hints:\n"
2317+
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
2318+
"- Instead of `lf.select(nw.col('a').drop_nulls()).select(nw.sum('a'))`,\n"
2319+
" use `lf.select(nw.col('a').drop_nulls().sum())\n"
2320+
)
2321+
raise InvalidOperationError(msg)
2322+
return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False)
23292323
raise InvalidIntoExprError.from_invalid_type(type(arg))
23302324

23312325
@property

narwhals/group_by.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from typing import TYPE_CHECKING, Any, Generic, TypeVar
44

5-
from narwhals._utils import tupleify
5+
from narwhals._expression_parsing import all_exprs_are_scalar_like
6+
from narwhals._utils import flatten, tupleify
67
from narwhals.exceptions import InvalidOperationError
78
from narwhals.typing import DataFrameT
89

@@ -71,15 +72,23 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT:
7172
2 b 3 2
7273
3 c 3 1
7374
"""
74-
compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs)
75-
if not all(x.is_scalar_like for x in kinds):
75+
flat_aggs = tuple(flatten(aggs))
76+
if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
7677
msg = (
7778
"Found expression which does not aggregate.\n\n"
7879
"All expressions passed to GroupBy.agg must aggregate.\n"
7980
"For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
8081
"but `df.group_by('a').agg(nw.col('b'))` is not."
8182
)
8283
raise InvalidOperationError(msg)
84+
plx = self._df.__narwhals_namespace__()
85+
compliant_aggs = (
86+
*(x._to_compliant_expr(plx) for x in flat_aggs),
87+
*(
88+
value.alias(key)._to_compliant_expr(plx)
89+
for key, value in named_aggs.items()
90+
),
91+
)
8392
return self._df._with_compliant(self._grouped.agg(*compliant_aggs))
8493

8594
def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
@@ -157,13 +166,21 @@ def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT:
157166
|└─────┴─────┴─────┘|
158167
└───────────────────┘
159168
"""
160-
compliant_aggs, kinds = self._df._flatten_and_extract(*aggs, **named_aggs)
161-
if not all(x.is_scalar_like for x in kinds):
169+
flat_aggs = tuple(flatten(aggs))
170+
if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
162171
msg = (
163172
"Found expression which does not aggregate.\n\n"
164173
"All expressions passed to GroupBy.agg must aggregate.\n"
165174
"For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
166175
"but `df.group_by('a').agg(nw.col('b'))` is not."
167176
)
168177
raise InvalidOperationError(msg)
178+
plx = self._df.__narwhals_namespace__()
179+
compliant_aggs = (
180+
*(x._to_compliant_expr(plx) for x in flat_aggs),
181+
*(
182+
value.alias(key)._to_compliant_expr(plx)
183+
for key, value in named_aggs.items()
184+
),
185+
)
169186
return self._df._with_compliant(self._grouped.agg(*compliant_aggs))

narwhals/series.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from collections.abc import Iterable, Iterator, Mapping, Sequence
55
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, overload
66

7-
from narwhals._expression_parsing import ExprMetadata
87
from narwhals._utils import (
98
Implementation,
109
Version,
@@ -21,7 +20,6 @@
2120
from narwhals.dependencies import is_numpy_array, is_numpy_array_1d, is_numpy_scalar
2221
from narwhals.dtypes import _validate_dtype, _validate_into_dtype
2322
from narwhals.exceptions import ComputeError, InvalidOperationError
24-
from narwhals.expr import Expr
2523
from narwhals.series_cat import SeriesCatNamespace
2624
from narwhals.series_dt import SeriesDateTimeNamespace
2725
from narwhals.series_list import SeriesListNamespace
@@ -91,10 +89,6 @@ def _dataframe(self) -> type[DataFrame[Any]]:
9189

9290
return DataFrame
9391

94-
def _to_expr(self) -> Expr:
95-
md = ExprMetadata.selector_single()
96-
return Expr(lambda _plx: self._compliant._to_expr(), md)
97-
9892
def __init__(
9993
self, series: Any, *, level: Literal["full", "lazy", "interchange"]
10094
) -> None:

narwhals/stable/v1/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import narwhals as nw
77
from narwhals import exceptions, functions as nw_f
88
from narwhals._exceptions import issue_warning
9-
from narwhals._expression_parsing import is_expr
109
from narwhals._typing_compat import TypeVar, assert_never
1110
from narwhals._utils import (
1211
Implementation,
@@ -234,13 +233,17 @@ def __init__(self, df: Any, *, level: Literal["full", "lazy", "interchange"]) ->
234233
def _dataframe(self) -> type[DataFrame[Any]]:
235234
return DataFrame
236235

237-
def _parse_into_expr(self, arg: Expr | str) -> Expr: # type: ignore[override]
236+
def _extract_compliant(self, arg: Any) -> Any:
238237
# After v1, we raise when passing order-dependent, length-changing,
239238
# or filtration expressions to LazyFrame
240-
if isinstance(arg, str):
241-
return col(arg)
242-
if is_expr(arg):
243-
return arg
239+
from narwhals.expr import Expr
240+
from narwhals.series import Series
241+
242+
if isinstance(arg, Series): # pragma: no cover
243+
msg = "Mixing Series with LazyFrame is not supported."
244+
raise TypeError(msg)
245+
if isinstance(arg, (Expr, str)):
246+
return self.__narwhals_namespace__().parse_into_expr(arg, str_as_lit=False)
244247
raise InvalidIntoExprError.from_invalid_type(type(arg))
245248

246249
def collect(

tests/frame/group_by_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def test_group_by_shift_raises(constructor: Constructor) -> None:
364364
df_native = {"a": [1, 2, 3], "b": [1, 1, 2]}
365365
df = nw.from_native(constructor(df_native))
366366
with pytest.raises(InvalidOperationError, match="does not aggregate"):
367-
df.group_by("b").agg(nw.col("a").abs())
367+
df.group_by("b").agg(nw.col("a").shift(1))
368368

369369

370370
def test_double_same_aggregation(

0 commit comments

Comments
 (0)