Skip to content

Commit 9447959

Browse files
committed
feat: Support drop(*columns: OneOrIterable[ColumnNameOrSelector])
This also simplifies the compliant level, since resolving names is now part of expansion. Heavily based on what `polars` does.
1 parent 616a5d5 commit 9447959

File tree

7 files changed

+77
-16
lines changed

7 files changed

+77
-16
lines changed

narwhals/_plan/_parse.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterable, Sequence
3+
import operator
4+
from collections import deque
5+
from collections.abc import Collection, Iterable, Sequence
46

57
# ruff: noqa: A002
8+
from functools import reduce
69
from itertools import chain
710
from typing import TYPE_CHECKING
811

@@ -14,6 +17,7 @@
1417
is_iterable_reject,
1518
is_selector,
1619
)
20+
from narwhals._plan.common import flatten_hash_safe
1721
from narwhals._plan.exceptions import invalid_into_expr_error, is_iterable_error
1822
from narwhals._utils import qualified_type_name
1923
from narwhals.dependencies import get_polars
@@ -27,6 +31,7 @@
2731

2832
from narwhals._plan.expr import Expr
2933
from narwhals._plan.expressions import ExprIR, SelectorIR
34+
from narwhals._plan.selectors import Selector
3035
from narwhals._plan.typing import (
3136
ColumnNameOrSelector,
3237
IntoExpr,
@@ -129,19 +134,55 @@ def parse_into_expr_ir(
129134
return expr._ir
130135

131136

132-
def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR:
137+
def parse_into_selector_ir(
138+
input: ColumnNameOrSelector | Expr, /, *, require_all: bool = True
139+
) -> SelectorIR:
140+
return _parse_into_selector(input, require_all=require_all)._ir
141+
142+
143+
def _parse_into_selector(
144+
input: ColumnNameOrSelector | Expr, /, *, require_all: bool = True
145+
) -> Selector:
133146
if is_selector(input):
134147
selector = input
135148
elif isinstance(input, str):
136-
from narwhals._plan import selectors as cs
149+
import narwhals._plan.selectors as cs
137150

138-
selector = cs.by_name(input)
151+
selector = cs.by_name(input, require_all=require_all)
139152
elif is_expr(input):
140153
selector = input.meta.as_selector()
141154
else:
142155
msg = f"cannot turn {qualified_type_name(input)!r} into a selector"
143156
raise TypeError(msg)
144-
return selector._ir
157+
return selector
158+
159+
160+
def parse_into_combined_selector_ir(
    *inputs: OneOrIterable[ColumnNameOrSelector], require_all: bool = True
) -> SelectorIR:
    """Parse any mix of column names and selectors into one `SelectorIR`.

    Arguments:
        *inputs: Column names and/or selectors, optionally nested in iterables.
            Nesting is unwrapped with `flatten_hash_safe` before parsing.
        require_all: Forwarded to `cs.by_name` for the column-name inputs.

    Returns:
        The IR of a single selector. When every input is a column name this is
        the `cs.by_name` selector itself; otherwise the by-name selector (if any)
        comes first, followed by the remaining inputs in their original order,
        all merged through `_any_of`.
    """
    import narwhals._plan.selectors as cs

    flat = tuple(flatten_hash_safe(inputs))
    names = tuple(el for el in flat if isinstance(el, str))
    combined: deque[Selector] = deque()
    if names:
        # All names are grouped into one selector instead of one selector per name.
        by_name = cs.by_name(names, require_all=require_all)
        if len(names) == len(flat):
            # Nothing but names was passed, so there is nothing left to merge.
            return by_name._ir
        combined.append(by_name)
    for el in flat:
        if not isinstance(el, str):
            combined.append(_parse_into_selector(el))
    return _any_of(combined)._ir
174+
175+
176+
def _any_of(selectors: Iterable[Selector], /) -> Selector:
    """Combine selectors into one by OR-ing them together, left to right.

    Arguments:
        selectors: Any iterable of selectors, including one-shot iterators
            and generators. It is consumed in a single pass.

    Returns:
        `cs.empty()` when `selectors` yields nothing, the sole element when it
        yields exactly one, otherwise the left fold of `|` over all elements.
    """
    it = iter(selectors)
    first = next(it, None)
    if first is None:
        # Imported here so the module is only needed on the empty path.
        import narwhals._plan.selectors as cs

        return cs.empty()
    # Seeding `reduce` with the first element covers the single-element case and
    # avoids the `TypeError` that `reduce` raises for an empty iterable when no
    # initial value is given.
    return reduce(operator.or_, it, first)
145186

146187

147188
def parse_into_seq_of_expr_ir(

narwhals/_plan/arrow/dataframe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from narwhals._plan.compliant.typing import namespace
1818
from narwhals._plan.expressions import NamedIR
1919
from narwhals._plan.typing import Seq
20-
from narwhals._utils import Implementation, Version, parse_columns_to_drop
20+
from narwhals._utils import Implementation, Version
2121
from narwhals.schema import Schema
2222

2323
if TYPE_CHECKING:
@@ -106,9 +106,8 @@ def get_column(self, name: str) -> Series:
106106
chunked = self.native.column(name)
107107
return Series.from_native(chunked, name, version=self.version)
108108

109-
def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self:
110-
to_drop = parse_columns_to_drop(self, columns, strict=strict)
111-
return self._with_native(self.native.drop(to_drop))
109+
def drop(self, columns: Sequence[str]) -> Self:
110+
return self._with_native(self.native.drop(list(columns)))
112111

113112
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
114113
if subset is None:

narwhals/_plan/common.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222

2323
from narwhals._plan.compliant.series import CompliantSeries
2424
from narwhals._plan.series import Series
25-
from narwhals._plan.typing import DTypeT, NonNestedDTypeT, OneOrIterable, Seq
25+
from narwhals._plan.typing import (
26+
ColumnNameOrSelector,
27+
DTypeT,
28+
NonNestedDTypeT,
29+
OneOrIterable,
30+
Seq,
31+
)
2632
from narwhals._utils import _StoresColumns
2733
from narwhals.typing import NonNestedDType, NonNestedLiteral
2834

@@ -85,8 +91,12 @@ def flatten_hash_safe(
8591
iterable: Iterable[OneOrIterable[CompliantSeries]], /
8692
) -> Iterator[CompliantSeries]: ...
8793
@overload
94+
def flatten_hash_safe(
95+
iterable: Iterable[OneOrIterable[ColumnNameOrSelector]], /
96+
) -> Iterator[ColumnNameOrSelector]: ...
97+
@overload
8898
def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]: ...
89-
def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]:
99+
def flatten_hash_safe(iterable: Iterable[OneOrIterable[Any]], /) -> Iterator[Any]:
90100
"""Fully unwrap all levels of nesting.
91101
92102
Aiming to reduce the chances of passing an unhashable argument.
@@ -95,7 +105,7 @@ def flatten_hash_safe(iterable: Iterable[OneOrIterable[T]], /) -> Iterator[T]:
95105
if isinstance(element, Iterable) and not is_iterable_reject(element):
96106
yield from flatten_hash_safe(element)
97107
else:
98-
yield element # type: ignore[misc]
108+
yield element
99109

100110

101111
def _not_one_or_iterable_str_error(obj: Any, /) -> TypeError: # pragma: no cover

narwhals/_plan/compliant/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def native(self) -> NativeFrameT_co: ...
5555
def to_narwhals(self) -> BaseFrame[NativeFrameT_co]: ...
5656
@property
5757
def columns(self) -> list[str]: ...
58-
def drop(self, columns: Sequence[str], *, strict: bool = True) -> Self: ...
58+
def drop(self, columns: Sequence[str]) -> Self: ...
5959
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
6060
# Shouldn't *need* to be `NamedIR`, but current impl depends on a name being passed around
6161
def filter(self, predicate: NamedIR, /) -> Self: ...

narwhals/_plan/dataframe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ def sort(
110110
named_irs, _ = prepare_projection(sort, schema=self)
111111
return self._with_compliant(self._compliant.sort(named_irs, opts))
112112

113-
def drop(self, *columns: str, strict: bool = True) -> Self:
114-
return self._with_compliant(self._compliant.drop(columns, strict=strict))
113+
def drop(
114+
self, *columns: OneOrIterable[ColumnNameOrSelector], strict: bool = True
115+
) -> Self:
116+
s_ir = _parse.parse_into_combined_selector_ir(*columns, require_all=strict)
117+
names = expand_selector_irs_names((s_ir,), schema=self)
118+
return self._with_compliant(self._compliant.drop(names))
115119

116120
def drop_nulls(
117121
self, subset: OneOrIterable[ColumnNameOrSelector] | None = None

narwhals/_plan/expressions/selectors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ class ByDType(DTypeSelector, dtype=DType):
257257
dtypes: frozenset[DType | type[DType]]
258258

259259
def __repr__(self) -> str:
    """Render as the `ncs` call that would build this selector.

    The dtype reprs are sorted so the output does not depend on the
    iteration order of the underlying `frozenset`.
    """
    if self.dtypes:
        rendered = ", ".join(sorted(map(repr, self.dtypes)))
        return f"ncs.by_dtype([{rendered}])"
    # A selector with no dtypes at all is displayed as the empty selector.
    return "ncs.empty()"
261263

262264
def _matches(self, dtype: DType | type[DType]) -> bool:

narwhals/_plan/selectors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
"categorical",
3535
"datetime",
3636
"duration",
37+
"empty",
3738
"enum",
3839
"first",
3940
"float",
@@ -180,7 +181,7 @@ def by_dtype(*dtypes: OneOrIterable[DType | type[DType]]) -> Selector:
180181
it = iter(selectors)
181182
if first := next(it, None):
182183
return reduce(operator.or_, it, first)
183-
return s_ir.ByDType.empty().to_selector_ir().to_narwhals()
184+
return empty()
184185

185186

186187
def by_index(*indices: OneOrIterable[int], require_all: bool = True) -> Selector:
@@ -222,6 +223,10 @@ def duration(time_unit: OneOrIterable[TimeUnit] | None = None) -> Selector:
222223
return s_ir.Duration.from_time_unit(time_unit).to_selector_ir().to_narwhals()
223224

224225

226+
def empty() -> Selector:
    """Return the selector built from `ByDType.empty()`.

    NOTE(review): presumably this matches no columns, since it carries no
    dtypes - confirm against `ByDType.empty`.
    """
    ir = s_ir.ByDType.empty()
    return ir.to_selector_ir().to_narwhals()
228+
229+
225230
def enum() -> Selector:
226231
return s_ir.Enum().to_selector_ir().to_narwhals()
227232

0 commit comments

Comments
 (0)