Skip to content

Commit 9607e5f

Browse files
authored
feat: add cum_count and cum_prod to PySpark and DuckDB (#2286)
1 parent 2bcc6bb commit 9607e5f

File tree

4 files changed

+131
-6
lines changed

4 files changed

+131
-6
lines changed

narwhals/_duckdb/expr.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,10 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
8080
)
8181

8282
def _cum_window_func(
83-
self, *, reverse: bool, func_name: Literal["sum", "max", "min"]
83+
self,
84+
*,
85+
reverse: bool,
86+
func_name: Literal["sum", "max", "min", "count", "product"],
8487
) -> WindowFunction:
8588
def func(window_inputs: WindowInputs) -> duckdb.Expression:
8689
order_by_sql = generate_order_by_sql(
@@ -516,6 +519,16 @@ def cum_min(self, *, reverse: bool) -> Self:
516519
self._cum_window_func(reverse=reverse, func_name="min")
517520
)
518521

522+
def cum_count(self, *, reverse: bool) -> Self:
523+
return self._with_window_function(
524+
self._cum_window_func(reverse=reverse, func_name="count")
525+
)
526+
527+
def cum_prod(self, *, reverse: bool) -> Self:
528+
return self._with_window_function(
529+
self._cum_window_func(reverse=reverse, func_name="product")
530+
)
531+
519532
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
520533
if center:
521534
half = (window_size - 1) // 2
@@ -580,5 +593,3 @@ def struct(self: Self) -> DuckDBExprStructNamespace:
580593
drop_nulls = not_implemented()
581594
unique = not_implemented()
582595
is_unique = not_implemented()
583-
cum_count = not_implemented()
584-
cum_prod = not_implemented()

narwhals/_spark_like/expr.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,10 @@ def _with_window_function(
135135
return result
136136

137137
def _cum_window_func(
138-
self: Self, *, reverse: bool, func_name: Literal["sum", "max", "min"]
138+
self: Self,
139+
*,
140+
reverse: bool,
141+
func_name: Literal["sum", "max", "min", "count", "product"],
139142
) -> WindowFunction:
140143
def func(window_inputs: WindowInputs) -> Column:
141144
if reverse:
@@ -594,6 +597,16 @@ def cum_min(self, *, reverse: bool) -> Self:
594597
self._cum_window_func(reverse=reverse, func_name="min")
595598
)
596599

600+
def cum_count(self, *, reverse: bool) -> Self:
601+
return self._with_window_function(
602+
self._cum_window_func(reverse=reverse, func_name="count")
603+
)
604+
605+
def cum_prod(self, *, reverse: bool) -> Self:
606+
return self._with_window_function(
607+
self._cum_window_func(reverse=reverse, func_name="product")
608+
)
609+
597610
def fill_null(
598611
self,
599612
value: Any | None,
@@ -657,6 +670,4 @@ def struct(self: Self) -> SparkLikeExprStructNamespace:
657670

658671
drop_nulls = not_implemented()
659672
unique = not_implemented()
660-
cum_count = not_implemented()
661-
cum_prod = not_implemented()
662673
quantile = not_implemented()

tests/expr_and_series/cum_count_test.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33
import pytest
44

55
import narwhals.stable.v1 as nw
6+
from tests.utils import DUCKDB_VERSION
7+
from tests.utils import POLARS_VERSION
8+
from tests.utils import Constructor
69
from tests.utils import ConstructorEager
710
from tests.utils import assert_equal_data
811

@@ -36,3 +39,50 @@ def test_cum_count_series(constructor_eager: ConstructorEager) -> None:
3639
"reverse_cum_count": [3, 2, 1, 1],
3740
}
3841
assert_equal_data(result, expected)
42+
43+
44+
@pytest.mark.parametrize(
45+
("reverse", "expected_a"),
46+
[
47+
(False, [1, 1, 2]),
48+
(True, [1, 2, 1]),
49+
],
50+
)
51+
def test_lazy_cum_count_grouped(
52+
constructor: Constructor,
53+
request: pytest.FixtureRequest,
54+
*,
55+
reverse: bool,
56+
expected_a: list[int],
57+
) -> None:
58+
if "pyarrow_table" in str(constructor):
59+
# grouped window functions not yet supported
60+
request.applymarker(pytest.mark.xfail)
61+
if "modin" in str(constructor):
62+
pytest.skip(reason="probably bugged")
63+
if "dask" in str(constructor):
64+
# https://github.com/dask/dask/issues/11806
65+
request.applymarker(pytest.mark.xfail)
66+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
67+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
68+
):
69+
pytest.skip(reason="too old version")
70+
if "cudf" in str(constructor):
71+
# https://github.com/rapidsai/cudf/issues/18159
72+
request.applymarker(pytest.mark.xfail)
73+
74+
df = nw.from_native(
75+
constructor(
76+
{
77+
"arg entina": [None, 2, 3],
78+
"ban gkock": [1, 0, 2],
79+
"i ran": [0, 1, 2],
80+
"g": [1, 1, 1],
81+
}
82+
)
83+
)
84+
result = df.with_columns(
85+
nw.col("arg entina").cum_count(reverse=reverse).over("g", order_by="ban gkock")
86+
).sort("i ran")
87+
expected = {"arg entina": expected_a, "ban gkock": [1, 0, 2], "i ran": [0, 1, 2]}
88+
assert_equal_data(result, expected)

tests/expr_and_series/cum_prod_test.py

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,11 @@
33
import pytest
44

55
import narwhals.stable.v1 as nw
6+
from tests.utils import DUCKDB_VERSION
67
from tests.utils import PANDAS_VERSION
8+
from tests.utils import POLARS_VERSION
79
from tests.utils import PYARROW_VERSION
10+
from tests.utils import Constructor
811
from tests.utils import ConstructorEager
912
from tests.utils import assert_equal_data
1013

@@ -54,3 +57,53 @@ def test_cum_prod_series(
5457
reverse_cum_prod=df["a"].cum_prod(reverse=True),
5558
)
5659
assert_equal_data(result, expected)
60+
61+
62+
@pytest.mark.parametrize(
63+
("reverse", "expected_a"),
64+
[
65+
(False, [2, 2, 6]),
66+
(True, [3, 6, 3]),
67+
],
68+
)
69+
def test_lazy_cum_prod_grouped(
70+
constructor: Constructor,
71+
request: pytest.FixtureRequest,
72+
*,
73+
reverse: bool,
74+
expected_a: list[int],
75+
) -> None:
76+
if "pyarrow_table" in str(constructor):
77+
# grouped window functions not yet supported
78+
request.applymarker(pytest.mark.xfail)
79+
if "modin" in str(constructor):
80+
pytest.skip(reason="probably bugged")
81+
if "dask" in str(constructor):
82+
# https://github.com/dask/dask/issues/11806
83+
request.applymarker(pytest.mark.xfail)
84+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
85+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
86+
):
87+
pytest.skip(reason="too old version")
88+
if "cudf" in str(constructor):
89+
# https://github.com/rapidsai/cudf/issues/18159
90+
request.applymarker(pytest.mark.xfail)
91+
if "sqlframe" in str(constructor):
92+
# https://github.com/eakmanrq/sqlframe/issues/348
93+
request.applymarker(pytest.mark.xfail)
94+
95+
df = nw.from_native(
96+
constructor(
97+
{
98+
"arg entina": [1, 2, 3],
99+
"ban gkock": [1, 0, 2],
100+
"i ran": [0, 1, 2],
101+
"g": [1, 1, 1],
102+
}
103+
)
104+
)
105+
result = df.with_columns(
106+
nw.col("arg entina").cum_prod(reverse=reverse).over("g", order_by="ban gkock")
107+
).sort("i ran")
108+
expected = {"arg entina": expected_a, "ban gkock": [1, 0, 2], "i ran": [0, 1, 2]}
109+
assert_equal_data(result, expected)

0 commit comments

Comments
 (0)