Skip to content

Commit 003972b

Browse files
committed
test: Fully cover selectors in drop
1 parent 9447959 commit 003972b

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

narwhals/_plan/_parse.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,15 @@ def parse_into_combined_selector_ir(
173173
return _any_of(selectors)._ir
174174

175175

176-
def _any_of(selectors: Iterable[Selector], /) -> Selector:
176+
def _any_of(selectors: Collection[Selector], /) -> Selector:
177177
import narwhals._plan.selectors as cs
178178

179-
if isinstance(selectors, Collection):
180-
if not selectors:
181-
return cs.empty()
182-
if len(selectors) == 1:
183-
return next(iter(selectors)) # type: ignore[no-any-return]
184-
s: Selector = reduce(operator.or_, selectors)
179+
if not selectors:
180+
s: Selector = cs.empty()
181+
elif len(selectors) == 1:
182+
s = next(iter(selectors))
183+
else:
184+
s = reduce(operator.or_, selectors)
185185
return s
186186

187187

narwhals/_plan/dataframe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def drop_nulls(
128128
def rename(self, mapping: Mapping[str, str]) -> Self:
129129
return self._with_compliant(self._compliant.rename(mapping))
130130

131+
def collect_schema(self) -> Schema:
132+
return self.schema
133+
131134

132135
class DataFrame(
133136
BaseFrame[NativeDataFrameT_co], Generic[NativeDataFrameT_co, NativeSeriesT]

tests/plan/compliant_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import re
4+
from collections.abc import Iterable
35
from typing import TYPE_CHECKING, Any
46

57
import pytest
@@ -528,6 +530,42 @@ def test_row_is_py_literal(
528530
assert result == polars_result
529531

530532

533+
@pytest.mark.parametrize(
534+
("columns", "expected"),
535+
[
536+
("a", ["b", "c"]),
537+
(["a"], ["b", "c"]),
538+
(ncs.first(), ["b", "c"]),
539+
([ncs.first()], ["b", "c"]),
540+
(["a", "b"], ["c"]),
541+
(~ncs.last(), ["c"]),
542+
([ncs.integer() | ncs.enum()], ["c"]),
543+
([ncs.first(), "b"], ["c"]),
544+
(ncs.all(), []),
545+
([], ["a", "b", "c"]),
546+
(ncs.struct(), ["a", "b", "c"]),
547+
],
548+
)
549+
def test_drop(columns: OneOrIterable[ColumnNameOrSelector], expected: list[str]) -> None:
550+
data = {"a": [1, 3, 2], "b": [4, 4, 6], "c": [7.0, 8.0, 9.0]}
551+
df = dataframe(data)
552+
if isinstance(columns, (str, nwp.Selector, list)):
553+
assert df.drop(columns).collect_schema().names() == expected
554+
if not isinstance(columns, str) and isinstance(columns, Iterable):
555+
assert df.drop(*columns).collect_schema().names() == expected
556+
557+
558+
def test_drop_strict() -> None:
559+
data = {"a": [1, 3, 2], "b": [4, 4, 6]}
560+
df = dataframe(data)
561+
with pytest.raises(ColumnNotFoundError):
562+
df.drop("z")
563+
with pytest.raises(ColumnNotFoundError, match=re.escape("not found: ['z']")):
564+
df.drop(ncs.last(), "z")
565+
assert df.drop("z", strict=False).collect_schema().names() == ["a", "b"]
566+
assert df.drop(ncs.last(), "z", strict=False).collect_schema().names() == ["a"]
567+
568+
531569
def test_drop_nulls(data_small_dh: Data) -> None:
532570
df = dataframe(data_small_dh)
533571
expected: Data = {"d": [], "e": [], "f": [], "g": [], "h": []}

0 commit comments

Comments
 (0)