Skip to content

Commit 4f4f713

Browse files
authored
enh: Support keep={'first', 'last'} in {DataFrame,LazyFrame}.unique (#3118)
1 parent fb7670b commit 4f4f713

File tree

13 files changed

+250
-89
lines changed

13 files changed

+250
-89
lines changed

docs/api-reference/typing.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ Narwhals comes fully statically typed. In addition to `nw.DataFrame`, `nw.Expr`,
3535
- RankMethod
3636
- RollingInterpolationMethod
3737
- UniqueKeepStrategy
38-
- LazyUniqueKeepStrategy
3938
show_source: false
4039
show_bases: false
4140

narwhals/_arrow/dataframe.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,7 @@ def unique(
700700
*,
701701
keep: UniqueKeepStrategy,
702702
maintain_order: bool | None = None,
703+
order_by: Sequence[str] | None,
703704
) -> Self:
704705
# The param `maintain_order` is only here for compatibility with the Polars API
705706
# and has no effect on the output.
@@ -714,14 +715,30 @@ def unique(
714715

715716
agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
716717
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
718+
if order_by and maintain_order:
719+
idx_token = generate_temporary_column_name(
720+
n_bytes=8, columns=[*self.columns, col_token]
721+
)
722+
df = (
723+
self.with_row_index(idx_token, order_by=None)
724+
.sort(*order_by, nulls_last=False, descending=False)
725+
.unique(subset=subset, keep=keep, maintain_order=False, order_by=None)
726+
)
727+
return df.sort(idx_token, descending=False, nulls_last=False).drop(
728+
[idx_token], strict=False
729+
)
730+
if order_by:
731+
native = self.sort(*order_by, nulls_last=False, descending=False).native
732+
else:
733+
native = self.native
717734
keep_idx_native = (
718-
self.native.append_column(col_token, pa.array(np.arange(len(self))))
735+
native.append_column(col_token, pa.array(np.arange(len(self))))
719736
.group_by(subset)
720737
.aggregate([(col_token, agg_func)])
721738
.column(f"{col_token}_{agg_func}")
722739
)
723740
return self._with_native(
724-
self.native.take(keep_idx_native), validate_column_names=False
741+
native.take(keep_idx_native), validate_column_names=False
725742
)
726743

727744
keep_idx = self.simple_select(*subset).is_unique()

narwhals/_compliant/dataframe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
AsofJoinStrategy,
6262
IntoSchema,
6363
JoinStrategy,
64-
LazyUniqueKeepStrategy,
6564
MultiColSelector,
6665
MultiIndexSelector,
6766
PivotAgg,
@@ -155,7 +154,11 @@ def sort(
155154
) -> Self: ...
156155
def tail(self, n: int) -> Self: ...
157156
def unique(
158-
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
157+
self,
158+
subset: Sequence[str] | None,
159+
*,
160+
keep: UniqueKeepStrategy,
161+
order_by: Sequence[str] | None,
159162
) -> Self: ...
160163
def unpivot(
161164
self,
@@ -265,6 +268,7 @@ def unique(
265268
*,
266269
keep: UniqueKeepStrategy,
267270
maintain_order: bool | None = None,
271+
order_by: Sequence[str] | None,
268272
) -> Self: ...
269273
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
270274
@overload

narwhals/_dask/dataframe.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from narwhals.dataframe import LazyFrame
4040
from narwhals.dtypes import DType
4141
from narwhals.exceptions import ColumnNotFoundError
42-
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
42+
from narwhals.typing import AsofJoinStrategy, JoinStrategy, UniqueKeepStrategy
4343

4444
Incomplete: TypeAlias = "Any"
4545
"""Using `_pandas_like` utils with `_dask`.
@@ -237,7 +237,11 @@ def head(self, n: int) -> Self:
237237
return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
238238

239239
def unique(
240-
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
240+
self,
241+
subset: Sequence[str] | None,
242+
*,
243+
keep: UniqueKeepStrategy,
244+
order_by: Sequence[str] | None,
241245
) -> Self:
242246
if subset and (error := self._check_columns_exist(subset)):
243247
raise error
@@ -250,7 +254,11 @@ def unique(
250254
result = self.native.merge(unique, on=subset, how="inner")
251255
else:
252256
mapped_keep = {"any": "first"}.get(keep, keep)
253-
result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
257+
if order_by:
258+
native = self.sort(*order_by, descending=False, nulls_last=False).native
259+
else:
260+
native = self.native
261+
result = native.drop_duplicates(subset=subset, keep=mapped_keep)
254262
return self._with_native(result)
255263

256264
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:

narwhals/_dask/group_by.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
112112
# No aggregation provided
113113
return (
114114
self.compliant.simple_select(*self._keys)
115-
.unique(self._keys, keep="any")
115+
.unique(self._keys, keep="any", order_by=None)
116116
.rename(dict(zip(self._keys, self._output_key_names)))
117117
)
118118

narwhals/_duckdb/dataframe.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from narwhals.dataframe import LazyFrame
5656
from narwhals.dtypes import DType
5757
from narwhals.stable.v1 import DataFrame as DataFrameV1
58-
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
58+
from narwhals.typing import AsofJoinStrategy, JoinStrategy, UniqueKeepStrategy
5959

6060

6161
class DuckDBLazyFrame(
@@ -188,7 +188,7 @@ def select(self, *exprs: DuckDBExpr) -> Self:
188188

189189
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
190190
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
191-
selection = (name for name in self.columns if name not in columns_to_drop)
191+
selection = [col(name) for name in self.columns if name not in columns_to_drop]
192192
return self._with_native(self.native.select(*selection))
193193

194194
def lazy(self, backend: None = None, **_: None) -> Self:
@@ -380,25 +380,43 @@ def collect_schema(self) -> dict[str, DType]:
380380
return self.schema
381381

382382
def unique(
383-
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
383+
self,
384+
subset: Sequence[str] | None,
385+
*,
386+
keep: UniqueKeepStrategy,
387+
order_by: Sequence[str] | None,
384388
) -> Self:
385-
if subset_ := subset if keep == "any" else (subset or self.columns):
386-
# Sanitise input
387-
if error := self._check_columns_exist(subset_):
388-
raise error
389-
idx_name = generate_temporary_column_name(8, self.columns)
390-
count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
391-
name = count_name if keep == "none" else idx_name
392-
idx_expr = window_expression(F("row_number"), subset_).alias(idx_name)
393-
count_expr = window_expression(
394-
F("count", StarExpression()), subset_, ()
395-
).alias(count_name)
396-
return self._with_native(
397-
self.native.select(StarExpression(), idx_expr, count_expr)
398-
.filter(col(name) == lit(1))
399-
.select(StarExpression(exclude=[count_name, idx_name]))
389+
subset_ = subset or self.columns
390+
if error := self._check_columns_exist(subset_):
391+
raise error
392+
tmp_name = generate_temporary_column_name(8, self.columns)
393+
if order_by and keep == "last":
394+
descending = [True] * len(order_by)
395+
nulls_last = [True] * len(order_by)
396+
else:
397+
descending = None
398+
nulls_last = None
399+
if keep == "none":
400+
expr = window_expression(
401+
F("count", StarExpression()),
402+
subset_,
403+
order_by or (),
404+
descending=descending,
405+
nulls_last=nulls_last,
406+
)
407+
else:
408+
expr = window_expression(
409+
F("row_number"),
410+
subset_,
411+
order_by or (),
412+
descending=descending,
413+
nulls_last=nulls_last,
414+
)
415+
return self._with_native(
416+
self.native.select(StarExpression(), expr.alias(tmp_name)).filter(
417+
col(tmp_name) == lit(1)
400418
)
401-
return self._with_native(self.native.unique(join_column_names(*self.columns)))
419+
).drop([tmp_name], strict=False)
402420

403421
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
404422
if isinstance(descending, bool):

narwhals/_ibis/dataframe.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,25 @@
22

33
import operator
44
from io import BytesIO
5-
from typing import TYPE_CHECKING, Any, Literal, cast
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import ibis
88
import ibis.expr.types as ir
99

10-
from narwhals._ibis.utils import evaluate_exprs, native_to_narwhals_dtype
10+
from narwhals._ibis.expr import IbisExpr
11+
from narwhals._ibis.utils import evaluate_exprs, lit, native_to_narwhals_dtype
1112
from narwhals._sql.dataframe import SQLLazyFrame
1213
from narwhals._utils import (
1314
Implementation,
1415
ValidateBackendVersion,
1516
Version,
17+
generate_temporary_column_name,
1618
not_implemented,
1719
parse_columns_to_drop,
1820
to_pyarrow_table,
1921
zip_strict,
2022
)
21-
from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError
23+
from narwhals.exceptions import InvalidOperationError
2224

2325
if TYPE_CHECKING:
2426
from collections.abc import Iterable, Iterator, Mapping, Sequence
@@ -31,7 +33,6 @@
3133
from typing_extensions import Self, TypeAlias, TypeIs
3234

3335
from narwhals._compliant.typing import CompliantDataFrameAny
34-
from narwhals._ibis.expr import IbisExpr
3536
from narwhals._ibis.group_by import IbisGroupBy
3637
from narwhals._ibis.namespace import IbisNamespace
3738
from narwhals._ibis.series import IbisInterchangeSeries
@@ -40,7 +41,7 @@
4041
from narwhals.dataframe import LazyFrame
4142
from narwhals.dtypes import DType
4243
from narwhals.stable.v1 import DataFrame as DataFrameV1
43-
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
44+
from narwhals.typing import AsofJoinStrategy, JoinStrategy, UniqueKeepStrategy
4445

4546
JoinPredicates: TypeAlias = "Sequence[ir.BooleanColumn] | Sequence[str]"
4647

@@ -320,21 +321,33 @@ def collect_schema(self) -> dict[str, DType]:
320321
}
321322

322323
def unique(
323-
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
324+
self,
325+
subset: Sequence[str] | None,
326+
*,
327+
keep: UniqueKeepStrategy,
328+
order_by: Sequence[str] | None,
324329
) -> Self:
325-
if subset_ := subset if keep == "any" else (subset or self.columns):
326-
# Sanitise input
327-
if any(x not in self.columns for x in subset_):
328-
msg = f"Columns {set(subset_).difference(self.columns)} not found in {self.columns}."
329-
raise ColumnNotFoundError(msg)
330-
331-
mapped_keep: dict[str, Literal["first"] | None] = {
332-
"any": "first",
333-
"none": None,
334-
}
335-
to_keep = mapped_keep[keep]
336-
return self._with_native(self.native.distinct(on=subset_, keep=to_keep))
337-
return self._with_native(self.native.distinct(on=subset))
330+
subset_ = subset or self.columns
331+
if error := self._check_columns_exist(subset_):
332+
raise error
333+
tmp_name = generate_temporary_column_name(8, self.columns)
334+
if order_by and keep == "last":
335+
order_by_ = IbisExpr._sort(*order_by, descending=True, nulls_last=True)
336+
elif order_by:
337+
order_by_ = IbisExpr._sort(*order_by, descending=False, nulls_last=False)
338+
else:
339+
order_by_ = lit(1)
340+
window = ibis.window(group_by=subset_, order_by=order_by_)
341+
if keep == "none":
342+
expr = self.native.count().over(window)
343+
else:
344+
expr = ibis.row_number().over(window) + lit(1)
345+
df = (
346+
self.native.mutate(**{tmp_name: expr})
347+
.filter(ibis._[tmp_name] == lit(1))
348+
.drop(tmp_name)
349+
)
350+
return self._with_native(df)
338351

339352
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
340353
from narwhals._ibis.expr import IbisExpr

narwhals/_pandas_like/dataframe.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -761,16 +761,29 @@ def unique(
761761
*,
762762
keep: UniqueKeepStrategy,
763763
maintain_order: bool | None = None,
764+
order_by: Sequence[str] | None,
764765
) -> Self:
765766
# The param `maintain_order` is only here for compatibility with the Polars API
766767
# and has no effect on the output.
767768
mapped_keep = {"none": False, "any": "first"}.get(keep, keep)
768769
if subset and (error := self._check_columns_exist(subset)):
769770
raise error
770-
return self._with_native(
771-
self.native.drop_duplicates(subset=subset, keep=mapped_keep),
772-
validate_column_names=False,
773-
)
771+
if order_by and maintain_order:
772+
token = generate_temporary_column_name(8, self.columns)
773+
res = (
774+
self.with_row_index(token, order_by=None)
775+
.sort(*order_by, nulls_last=False, descending=False)
776+
.native.drop_duplicates(subset or self.columns, keep=mapped_keep)
777+
.sort_values(token)
778+
)
779+
res.drop(columns=token, inplace=True) # noqa: PD002
780+
elif order_by:
781+
res = self.sort(
782+
*order_by, nulls_last=False, descending=False
783+
).native.drop_duplicates(subset or self.columns, keep=mapped_keep)
784+
else:
785+
res = self.native.drop_duplicates(subset or self.columns, keep=mapped_keep)
786+
return self._with_native(res, validate_column_names=False)
774787

775788
# --- lazy-only ---
776789
def lazy(

narwhals/_polars/dataframe.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
Implementation,
1717
_into_arrow_table,
1818
convert_str_slice_to_int_slice,
19+
generate_temporary_column_name,
1920
is_compliant_series,
2021
is_index_selector,
2122
is_range,
@@ -53,6 +54,7 @@
5354
MultiIndexSelector,
5455
PivotAgg,
5556
SingleIndexSelector,
57+
UniqueKeepStrategy,
5658
_2DArray,
5759
)
5860

@@ -89,7 +91,6 @@
8991
"tail",
9092
"to_arrow",
9193
"to_pandas",
92-
"unique",
9394
"with_columns",
9495
"write_csv",
9596
"write_parquet",
@@ -110,7 +111,6 @@ class PolarsBaseFrame(Generic[NativePolarsFrame]):
110111
select: Method[Self]
111112
sort: Method[Self]
112113
tail: Method[Self]
113-
unique: Method[Self]
114114
with_columns: Method[Self]
115115

116116
_native_frame: NativePolarsFrame
@@ -175,6 +175,31 @@ def simple_select(self, *column_names: str) -> Self:
175175
def aggregate(self, *exprs: Any) -> Self:
176176
return self.select(*exprs)
177177

178+
def unique(
179+
self,
180+
subset: Sequence[str] | None,
181+
*,
182+
keep: UniqueKeepStrategy,
183+
maintain_order: bool | None = None,
184+
order_by: Sequence[str] | None = None,
185+
) -> Self:
186+
if order_by and maintain_order:
187+
token = generate_temporary_column_name(8, self.columns)
188+
res = (
189+
self.native.with_row_index(token)
190+
.sort(order_by, nulls_last=False)
191+
.unique(subset or self.columns, keep=keep)
192+
.sort(token)
193+
.drop(token)
194+
)
195+
elif order_by:
196+
res = self.native.sort(order_by).unique(subset, keep=keep)
197+
else:
198+
res = self.native.unique(
199+
subset, keep=keep, maintain_order=maintain_order or False
200+
)
201+
return self._with_native(res)
202+
178203
@property
179204
def schema(self) -> dict[str, DType]:
180205
return self.collect_schema()

0 commit comments

Comments
 (0)