Skip to content

Commit 7cfa873

Browse files
authored
ci: cuDF fixup (#2739)
1 parent 058ef32 commit 7cfa873

File tree

9 files changed

+26
-59
lines changed

9 files changed

+26
-59
lines changed

narwhals/_duckdb/expr.py

Lines changed: 12 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
103103

104104
def _cum_window_func(
105105
self,
106+
func_name: Literal["sum", "max", "min", "count", "product"],
106107
*,
107108
reverse: bool,
108-
func_name: Literal["sum", "max", "min", "count", "product"],
109109
) -> DuckDBWindowFunction:
110110
def func(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
111111
return [
@@ -125,12 +125,12 @@ def func(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
125125

126126
def _rolling_window_func(
127127
self,
128-
*,
129128
func_name: Literal["sum", "mean", "std", "var"],
130-
center: bool,
131129
window_size: int,
132130
min_samples: int,
133131
ddof: int | None = None,
132+
*,
133+
center: bool,
134134
) -> DuckDBWindowFunction:
135135
supported_funcs = ["sum", "mean", "std", "var"]
136136
if center:
@@ -640,54 +640,36 @@ def func(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
640640

641641
@requires.backend_version((1, 3))
642642
def cum_sum(self, *, reverse: bool) -> Self:
643-
return self._with_window_function(
644-
self._cum_window_func(reverse=reverse, func_name="sum")
645-
)
643+
return self._with_window_function(self._cum_window_func("sum", reverse=reverse))
646644

647645
@requires.backend_version((1, 3))
648646
def cum_max(self, *, reverse: bool) -> Self:
649-
return self._with_window_function(
650-
self._cum_window_func(reverse=reverse, func_name="max")
651-
)
647+
return self._with_window_function(self._cum_window_func("max", reverse=reverse))
652648

653649
@requires.backend_version((1, 3))
654650
def cum_min(self, *, reverse: bool) -> Self:
655-
return self._with_window_function(
656-
self._cum_window_func(reverse=reverse, func_name="min")
657-
)
651+
return self._with_window_function(self._cum_window_func("min", reverse=reverse))
658652

659653
@requires.backend_version((1, 3))
660654
def cum_count(self, *, reverse: bool) -> Self:
661-
return self._with_window_function(
662-
self._cum_window_func(reverse=reverse, func_name="count")
663-
)
655+
return self._with_window_function(self._cum_window_func("count", reverse=reverse))
664656

665657
@requires.backend_version((1, 3))
666658
def cum_prod(self, *, reverse: bool) -> Self:
667659
return self._with_window_function(
668-
self._cum_window_func(reverse=reverse, func_name="product")
660+
self._cum_window_func("product", reverse=reverse)
669661
)
670662

671663
@requires.backend_version((1, 3))
672664
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
673665
return self._with_window_function(
674-
self._rolling_window_func(
675-
func_name="sum",
676-
center=center,
677-
window_size=window_size,
678-
min_samples=min_samples,
679-
)
666+
self._rolling_window_func("sum", window_size, min_samples, center=center)
680667
)
681668

682669
@requires.backend_version((1, 3))
683670
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
684671
return self._with_window_function(
685-
self._rolling_window_func(
686-
func_name="mean",
687-
center=center,
688-
window_size=window_size,
689-
min_samples=min_samples,
690-
)
672+
self._rolling_window_func("mean", window_size, min_samples, center=center)
691673
)
692674

693675
@requires.backend_version((1, 3))
@@ -696,11 +678,7 @@ def rolling_var(
696678
) -> Self:
697679
return self._with_window_function(
698680
self._rolling_window_func(
699-
func_name="var",
700-
center=center,
701-
window_size=window_size,
702-
min_samples=min_samples,
703-
ddof=ddof,
681+
"var", window_size, min_samples, ddof=ddof, center=center
704682
)
705683
)
706684

@@ -710,11 +688,7 @@ def rolling_std(
710688
) -> Self:
711689
return self._with_window_function(
712690
self._rolling_window_func(
713-
func_name="std",
714-
center=center,
715-
window_size=window_size,
716-
min_samples=min_samples,
717-
ddof=ddof,
691+
"std", window_size, min_samples, ddof=ddof, center=center
718692
)
719693
)
720694

narwhals/_pandas_like/namespace.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from narwhals._pandas_like.selectors import PandasSelectorNamespace
1717
from narwhals._pandas_like.series import PandasLikeSeries
1818
from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT
19+
from narwhals._pandas_like.utils import is_non_nullable_boolean
1920

2021
if TYPE_CHECKING:
2122
from collections.abc import Iterable, Sequence
@@ -147,7 +148,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
147148
it = (
148149
(
149150
# NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
150-
s if s.native.dtype == "bool" else s.fill_null(True, None, None) # noqa: FBT003
151+
s if is_non_nullable_boolean(s) else s.fill_null(True, None, None) # noqa: FBT003
151152
for s in series
152153
)
153154
if ignore_nulls
@@ -180,7 +181,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
180181
it = (
181182
(
182183
# NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
183-
s if s.native.dtype == "bool" else s.fill_null(False, None, None) # noqa: FBT003
184+
s if is_non_nullable_boolean(s) else s.fill_null(False, None, None) # noqa: FBT003
184185
for s in series
185186
)
186187
if ignore_nulls

narwhals/_pandas_like/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,15 @@ def check_column_names_are_unique(columns: pd.Index[str]) -> None:
603603
raise DuplicateError(msg)
604604

605605

606+
def is_non_nullable_boolean(s: PandasLikeSeries) -> bool:
607+
# cuDF booleans are nullable but the native dtype is still 'bool'.
608+
return (
609+
s._implementation
610+
in {Implementation.PANDAS, Implementation.MODIN, Implementation.DASK}
611+
and s.native.dtype == "bool"
612+
)
613+
614+
606615
class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]):
607616
@property
608617
def implementation(self) -> Implementation:

tests/expr_and_series/all_horizontal_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def test_allh(constructor: Constructor) -> None:
1818
assert_equal_data(result, expected)
1919

2020

21-
def test_anyh_ignore_nulls(constructor: Constructor) -> None:
21+
def test_all_ignore_nulls(constructor: Constructor) -> None:
2222
if "dask" in str(constructor):
2323
# Dask infers `[True, None, None, None]` as `object` dtype, and then `__or__` fails.
2424
# test it below separately

tests/expr_and_series/rolling_mean_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,6 @@ def test_rolling_mean_expr_lazy_grouped(
123123
pytest.skip()
124124
if any(x in str(constructor) for x in ("dask", "pyarrow_table")):
125125
request.applymarker(pytest.mark.xfail)
126-
if "cudf" in str(constructor) and center:
127-
# center is not implemented for offset-based windows
128-
request.applymarker(pytest.mark.xfail)
129126
if "modin" in str(constructor):
130127
# unreliable
131128
pytest.skip()

tests/expr_and_series/rolling_std_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,6 @@ def test_rolling_std_expr_lazy_grouped(
317317
pytest.skip()
318318
if any(x in str(constructor) for x in ("dask", "pyarrow_table")):
319319
request.applymarker(pytest.mark.xfail)
320-
if "cudf" in str(constructor) and center:
321-
# center is not implemented for offset-based windows
322-
request.applymarker(pytest.mark.xfail)
323320
if "modin" in str(constructor):
324321
# unreliable
325322
pytest.skip()

tests/expr_and_series/rolling_sum_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,6 @@ def test_rolling_sum_expr_lazy_grouped(
131131
pytest.skip()
132132
if any(x in str(constructor) for x in ("dask", "pyarrow_table")):
133133
request.applymarker(pytest.mark.xfail)
134-
if "cudf" in str(constructor) and center:
135-
# center is not implemented for offset-based windows
136-
request.applymarker(pytest.mark.xfail)
137134
if "modin" in str(constructor):
138135
# unreliable
139136
pytest.skip()

tests/expr_and_series/rolling_var_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,6 @@ def test_rolling_var_expr_lazy_grouped(
275275
pytest.skip()
276276
if any(x in str(constructor) for x in ("dask", "pyarrow_table")):
277277
request.applymarker(pytest.mark.xfail)
278-
if "cudf" in str(constructor) and center:
279-
# center is not implemented for offset-based windows
280-
request.applymarker(pytest.mark.xfail)
281278
if "modin" in str(constructor):
282279
# unreliable
283280
pytest.skip()

tests/frame/getitem_test.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,7 @@ def test_slice_with_series(
314314
assert_equal_data(result, expected)
315315

316316

317-
def test_horizontal_slice_with_series(
318-
constructor_eager: ConstructorEager, request: pytest.FixtureRequest
319-
) -> None:
320-
if "cudf" in str(constructor_eager):
321-
# https://github.com/rapidsai/cudf/issues/18556
322-
request.applymarker(pytest.mark.xfail)
317+
def test_horizontal_slice_with_series(constructor_eager: ConstructorEager) -> None:
323318
data = {"a": [1, 2], "c": [0, 2], "d": ["c", "a"]}
324319
nw_df = nw.from_native(constructor_eager(data), eager_only=True)
325320
result = nw_df[nw_df["d"]]

0 commit comments

Comments
 (0)