Skip to content

Commit 2bcc6bb

Browse files
authored
feat: Add cum_max and cum_min for DuckDB (#2278)
1 parent daaeaf9 commit 2bcc6bb

File tree

3 files changed

+54
-45
lines changed

3 files changed

+54
-45
lines changed

narwhals/_duckdb/expr.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
7979
backend_version=self._backend_version, version=self._version
8080
)
8181

82+
def _cum_window_func(
83+
self, *, reverse: bool, func_name: Literal["sum", "max", "min"]
84+
) -> WindowFunction:
85+
def func(window_inputs: WindowInputs) -> duckdb.Expression:
86+
order_by_sql = generate_order_by_sql(
87+
*window_inputs.order_by, ascending=not reverse
88+
)
89+
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
90+
sql = (
91+
f"{func_name} ({window_inputs.expr}) over ({partition_by_sql} {order_by_sql} "
92+
"rows between unbounded preceding and current row)"
93+
)
94+
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
95+
96+
return func
97+
8298
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
8399
if kind is ExprKind.LITERAL:
84100
return self
@@ -486,18 +502,19 @@ def func(window_inputs: WindowInputs) -> duckdb.Expression:
486502
return self._with_window_function(func)
487503

488504
def cum_sum(self, *, reverse: bool) -> Self:
489-
def func(window_inputs: WindowInputs) -> duckdb.Expression:
490-
order_by_sql = generate_order_by_sql(
491-
*window_inputs.order_by, ascending=not reverse
492-
)
493-
partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
494-
sql = (
495-
f"sum ({window_inputs.expr}) over ({partition_by_sql} {order_by_sql} "
496-
"rows between unbounded preceding and current row)"
497-
)
498-
return SQLExpression(sql) # type: ignore[no-any-return, unused-ignore]
505+
return self._with_window_function(
506+
self._cum_window_func(reverse=reverse, func_name="sum")
507+
)
499508

500-
return self._with_window_function(func)
509+
def cum_max(self, *, reverse: bool) -> Self:
510+
return self._with_window_function(
511+
self._cum_window_func(reverse=reverse, func_name="max")
512+
)
513+
514+
def cum_min(self, *, reverse: bool) -> Self:
515+
return self._with_window_function(
516+
self._cum_window_func(reverse=reverse, func_name="min")
517+
)
501518

502519
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
503520
if center:
@@ -564,6 +581,4 @@ def struct(self: Self) -> DuckDBExprStructNamespace:
564581
unique = not_implemented()
565582
is_unique = not_implemented()
566583
cum_count = not_implemented()
567-
cum_min = not_implemented()
568-
cum_max = not_implemented()
569584
cum_prod = not_implemented()

tests/expr_and_series/cum_max_test.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
import narwhals.stable.v1 as nw
6+
from tests.utils import DUCKDB_VERSION
67
from tests.utils import PANDAS_VERSION
78
from tests.utils import POLARS_VERSION
89
from tests.utils import PYARROW_VERSION
@@ -53,9 +54,6 @@ def test_lazy_cum_max_grouped(
5354
reverse: bool,
5455
expected_a: list[int],
5556
) -> None:
56-
if "duckdb" in str(constructor):
57-
# no window function support yet in duckdb
58-
request.applymarker(pytest.mark.xfail)
5957
if "pyarrow_table" in str(constructor):
6058
# grouped window functions not yet supported
6159
request.applymarker(pytest.mark.xfail)
@@ -64,7 +62,9 @@ def test_lazy_cum_max_grouped(
6462
if "dask" in str(constructor):
6563
# https://github.com/dask/dask/issues/11806
6664
request.applymarker(pytest.mark.xfail)
67-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
65+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
66+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
67+
):
6868
pytest.skip(reason="too old version")
6969
if "cudf" in str(constructor):
7070
# https://github.com/rapidsai/cudf/issues/18159
@@ -101,9 +101,6 @@ def test_lazy_cum_max_ordered_by_nulls(
101101
reverse: bool,
102102
expected_a: list[int],
103103
) -> None:
104-
if "duckdb" in str(constructor):
105-
# no window function support yet in duckdb
106-
request.applymarker(pytest.mark.xfail)
107104
if "pyarrow_table" in str(constructor):
108105
# grouped window functions not yet supported
109106
request.applymarker(pytest.mark.xfail)
@@ -112,7 +109,9 @@ def test_lazy_cum_max_ordered_by_nulls(
112109
if "dask" in str(constructor):
113110
# https://github.com/dask/dask/issues/11806
114111
request.applymarker(pytest.mark.xfail)
115-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
112+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
113+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
114+
):
116115
pytest.skip(reason="too old version")
117116
if "cudf" in str(constructor):
118117
# https://github.com/rapidsai/cudf/issues/18159
@@ -153,15 +152,14 @@ def test_lazy_cum_max_ungrouped(
153152
reverse: bool,
154153
expected_a: list[int],
155154
) -> None:
156-
if "duckdb" in str(constructor):
157-
# no window function support yet in duckdb
158-
request.applymarker(pytest.mark.xfail)
159155
if "dask" in str(constructor) and reverse:
160156
# https://github.com/dask/dask/issues/11802
161157
request.applymarker(pytest.mark.xfail)
162158
if "modin" in str(constructor):
163159
pytest.skip(reason="probably bugged")
164-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
160+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
161+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
162+
):
165163
pytest.skip(reason="too old version")
166164
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
167165
request.applymarker(pytest.mark.xfail)
@@ -196,15 +194,14 @@ def test_lazy_cum_max_ungrouped_ordered_by_nulls(
196194
reverse: bool,
197195
expected_a: list[int],
198196
) -> None:
199-
if "duckdb" in str(constructor):
200-
# no window function support yet in duckdb
201-
request.applymarker(pytest.mark.xfail)
202197
if "dask" in str(constructor):
203198
# https://github.com/dask/dask/issues/11806
204199
request.applymarker(pytest.mark.xfail)
205200
if "modin" in str(constructor):
206201
pytest.skip(reason="probably bugged")
207-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
202+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
203+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
204+
):
208205
pytest.skip(reason="too old version")
209206
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
210207
request.applymarker(pytest.mark.xfail)

tests/expr_and_series/cum_min_test.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
import narwhals.stable.v1 as nw
6+
from tests.utils import DUCKDB_VERSION
67
from tests.utils import PANDAS_VERSION
78
from tests.utils import POLARS_VERSION
89
from tests.utils import PYARROW_VERSION
@@ -53,9 +54,6 @@ def test_lazy_cum_min_grouped(
5354
reverse: bool,
5455
expected_a: list[int],
5556
) -> None:
56-
if "duckdb" in str(constructor):
57-
# no window function support yet in duckdb
58-
request.applymarker(pytest.mark.xfail)
5957
if "pyarrow_table" in str(constructor):
6058
# grouped window functions not yet supported
6159
request.applymarker(pytest.mark.xfail)
@@ -64,7 +62,9 @@ def test_lazy_cum_min_grouped(
6462
if "dask" in str(constructor):
6563
# https://github.com/dask/dask/issues/11806
6664
request.applymarker(pytest.mark.xfail)
67-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
65+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
66+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
67+
):
6868
pytest.skip(reason="too old version")
6969
if "cudf" in str(constructor):
7070
# https://github.com/rapidsai/cudf/issues/18159
@@ -101,9 +101,6 @@ def test_lazy_cum_min_ordered_by_nulls(
101101
reverse: bool,
102102
expected_a: list[int],
103103
) -> None:
104-
if "duckdb" in str(constructor):
105-
# no window function support yet in duckdb
106-
request.applymarker(pytest.mark.xfail)
107104
if "pyarrow_table" in str(constructor):
108105
# grouped window functions not yet supported
109106
request.applymarker(pytest.mark.xfail)
@@ -112,7 +109,9 @@ def test_lazy_cum_min_ordered_by_nulls(
112109
if "dask" in str(constructor):
113110
# https://github.com/dask/dask/issues/11806
114111
request.applymarker(pytest.mark.xfail)
115-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
112+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
113+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
114+
):
116115
pytest.skip(reason="too old version")
117116
if "cudf" in str(constructor):
118117
# https://github.com/rapidsai/cudf/issues/18159
@@ -153,15 +152,14 @@ def test_lazy_cum_min_ungrouped(
153152
reverse: bool,
154153
expected_a: list[int],
155154
) -> None:
156-
if "duckdb" in str(constructor):
157-
# no window function support yet in duckdb
158-
request.applymarker(pytest.mark.xfail)
159155
if "dask" in str(constructor) and reverse:
160156
# https://github.com/dask/dask/issues/11802
161157
request.applymarker(pytest.mark.xfail)
162158
if "modin" in str(constructor):
163159
pytest.skip(reason="probably bugged")
164-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
160+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
161+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
162+
):
165163
pytest.skip(reason="too old version")
166164
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
167165
request.applymarker(pytest.mark.xfail)
@@ -196,15 +194,14 @@ def test_lazy_cum_min_ungrouped_ordered_by_nulls(
196194
reverse: bool,
197195
expected_a: list[int],
198196
) -> None:
199-
if "duckdb" in str(constructor):
200-
# no window function support yet in duckdb
201-
request.applymarker(pytest.mark.xfail)
202197
if "dask" in str(constructor):
203198
# https://github.com/dask/dask/issues/11806
204199
request.applymarker(pytest.mark.xfail)
205200
if "modin" in str(constructor):
206201
pytest.skip(reason="probably bugged")
207-
if "polars" in str(constructor) and POLARS_VERSION < (1, 9):
202+
if ("polars" in str(constructor) and POLARS_VERSION < (1, 9)) or (
203+
"duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3)
204+
):
208205
pytest.skip(reason="too old version")
209206
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
210207
request.applymarker(pytest.mark.xfail)

0 commit comments

Comments
 (0)