Skip to content

Commit ee535b4

Browse files
authored
fix: nw.len().over was unnecessarily raising for pandas-like (#2372)
1 parent 7ae5a8c commit ee535b4

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

narwhals/_dask/expr.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,18 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
644644
message=".*`meta` is not specified",
645645
category=UserWarning,
646646
)
647-
res_native = df.native.groupby(partition_by)[
648-
list(output_names)
649-
].transform(dask_function_name, **self._call_kwargs)
647+
grouped = df.native.groupby(partition_by)
648+
if dask_function_name == "size":
649+
if len(output_names) != 1: # pragma: no cover
650+
msg = "Safety check failed, please report a bug."
651+
raise AssertionError(msg)
652+
res_native = grouped.transform(
653+
dask_function_name, **self._call_kwargs
654+
).to_frame(output_names[0])
655+
else:
656+
res_native = grouped[list(output_names)].transform(
657+
dask_function_name, **self._call_kwargs
658+
)
650659
result_frame = df._with_native(
651660
res_native.rename(columns=dict(zip(output_names, aliases)))
652661
).native

narwhals/_pandas_like/expr.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def cum_sum(self: Self, *, reverse: bool) -> Self:
198198
def shift(self: Self, n: int) -> Self:
199199
return self._reuse_series("shift", call_kwargs={"n": n})
200200

201-
def over(
201+
def over( # noqa: PLR0915
202202
self: Self,
203203
partition_by: Sequence[str],
204204
order_by: Sequence[str] | None,
@@ -265,21 +265,25 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
265265
elif reverse:
266266
columns = list(set(partition_by).union(output_names))
267267
df = df[columns][::-1]
268+
grouped = df._native_frame.groupby(partition_by)
268269
if function_name.startswith("rolling"):
269-
rolling = df._native_frame.groupby(partition_by)[
270-
list(output_names)
271-
].rolling(**pandas_kwargs)
270+
rolling = grouped[list(output_names)].rolling(**pandas_kwargs)
272271
assert pandas_function_name is not None # help mypy # noqa: S101
273272
if pandas_function_name in {"std", "var"}:
274273
res_native = getattr(rolling, pandas_function_name)(
275274
ddof=self._call_kwargs["ddof"]
276275
)
277276
else:
278277
res_native = getattr(rolling, pandas_function_name)()
278+
elif function_name == "len":
279+
if len(output_names) != 1: # pragma: no cover
280+
msg = "Safety check failed, please report a bug."
281+
raise AssertionError(msg)
282+
res_native = grouped.transform("size").to_frame(aliases[0])
279283
else:
280-
res_native = df._native_frame.groupby(partition_by)[
281-
list(output_names)
282-
].transform(pandas_function_name, **pandas_kwargs)
284+
res_native = grouped[list(output_names)].transform(
285+
pandas_function_name, **pandas_kwargs
286+
)
283287
result_frame = df._with_native(res_native).rename(
284288
dict(zip(output_names, aliases))
285289
)

tests/expr_and_series/over_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,3 +428,17 @@ def test_over_without_partition_by(
428428
)
429429
expected = {"a": [1, 2, -1], "b": [1, 3, 4], "i": [0, 1, 2]}
430430
assert_equal_data(result, expected)
431+
432+
433+
def test_len_over_2369(constructor: Constructor, request: pytest.FixtureRequest) -> None:
434+
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3):
435+
pytest.skip()
436+
if "pandas" in str(constructor) and PANDAS_VERSION < (1, 5):
437+
pytest.skip()
438+
if any(x in str(constructor) for x in ("modin",)):
439+
# https://github.com/modin-project/modin/issues/7508
440+
request.applymarker(pytest.mark.xfail)
441+
df = nw.from_native(constructor({"a": [1, 2, 4], "b": ["x", "x", "y"]}))
442+
result = df.with_columns(a_len_per_group=nw.len().over("b")).sort("a")
443+
expected = {"a": [1, 2, 4], "b": ["x", "x", "y"], "a_len_per_group": [2, 2, 1]}
444+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)