Skip to content

Commit 9b0bc75

Browse files
authored
enh: support unique(keep='none') for pyspark/sqlframe (#2338)
1 parent 189dff4 commit 9b0bc75

File tree

3 files changed

+75
-41
lines changed

3 files changed

+75
-41
lines changed

narwhals/_duckdb/dataframe.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -350,25 +350,26 @@ def collect_schema(self: Self) -> dict[str, DType]:
350350
def unique(
351351
self: Self, subset: Sequence[str] | None, *, keep: Literal["any", "none"]
352352
) -> Self:
353-
if subset is not None:
354-
rel = self.native
353+
subset_ = subset if keep == "any" else (subset or self.columns)
354+
if subset_:
355355
# Sanitise input
356-
if any(x not in rel.columns for x in subset):
357-
msg = f"Columns {set(subset).difference(rel.columns)} not found in {rel.columns}."
356+
if any(x not in self.columns for x in subset_):
357+
msg = f"Columns {set(subset_).difference(self.columns)} not found in {self.columns}."
358358
raise ColumnNotFoundError(msg)
359-
idx_name = generate_temporary_column_name(8, rel.columns)
360-
count_name = generate_temporary_column_name(8, [*rel.columns, idx_name])
361-
if keep == "none":
362-
keep_condition = col(count_name) == lit(1)
363-
else:
364-
keep_condition = col(idx_name) == lit(1)
365-
partition_by_sql = generate_partition_by_sql(*subset)
359+
idx_name = generate_temporary_column_name(8, self.columns)
360+
count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
361+
partition_by_sql = generate_partition_by_sql(*(subset_))
362+
rel = self.native # noqa: F841
366363
query = f"""
367364
select *,
368365
row_number() over ({partition_by_sql}) as "{idx_name}",
369366
count(*) over ({partition_by_sql}) as "{count_name}"
370367
from rel
371368
""" # noqa: S608
369+
if keep == "none":
370+
keep_condition = col(count_name) == lit(1)
371+
else:
372+
keep_condition = col(idx_name) == lit(1)
372373
return self._with_native(
373374
duckdb.sql(query)
374375
.filter(keep_condition)
@@ -394,8 +395,7 @@ def sort(
394395
return self._with_native(self.native.sort(*it))
395396

396397
def drop_nulls(self: Self, subset: Sequence[str] | None) -> Self:
    """Remove rows with null values in the given columns (all columns if None)."""
    target_columns = self.columns if subset is None else subset
    # AND together per-column IS NOT NULL predicates into one filter condition.
    predicates = (col(name).isnotnull() for name in target_columns)
    not_null_condition = reduce(and_, predicates)
    return self._with_native(self.native.filter(not_null_condition))
401401

narwhals/_spark_like/dataframe.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from narwhals.utils import Implementation
2222
from narwhals.utils import check_column_exists
2323
from narwhals.utils import find_stacklevel
24+
from narwhals.utils import generate_temporary_column_name
2425
from narwhals.utils import import_dtypes_module
2526
from narwhals.utils import is_spark_like_dataframe
2627
from narwhals.utils import not_implemented
@@ -335,11 +336,17 @@ def unique(
335336
*,
336337
keep: Literal["any", "none"],
337338
) -> Self:
338-
if keep != "any":
339-
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
340-
raise ValueError(msg)
341339
check_column_exists(self.columns, subset)
342340
subset = list(subset) if subset else None
341+
if keep == "none":
342+
tmp = generate_temporary_column_name(8, self.columns)
343+
window = self._Window().partitionBy(subset or self.columns)
344+
df = (
345+
self.native.withColumn(tmp, self._F.count("*").over(window))
346+
.filter(self._F.col(tmp) == 1)
347+
.drop(tmp)
348+
)
349+
return self._with_native(df)
343350
return self._with_native(self.native.dropDuplicates(subset=subset))
344351

345352
def join(

tests/frame/unique_test.py

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

3-
from contextlib import nullcontext as does_not_raise
4-
from typing import Any
3+
from typing import Literal
54

65
import pytest
76

@@ -10,6 +9,7 @@
109
import narwhals as nw
1110
from narwhals.exceptions import ColumnNotFoundError
1211
from tests.utils import Constructor
12+
from tests.utils import ConstructorEager
1313
from tests.utils import assert_equal_data
1414

1515
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
@@ -21,44 +21,71 @@
2121
[
2222
("first", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}),
2323
("last", {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}),
24+
],
25+
)
26+
def test_unique_eager(
27+
constructor_eager: ConstructorEager,
28+
subset: str | list[str] | None,
29+
keep: Literal["first", "last"],
30+
expected: dict[str, list[float]],
31+
) -> None:
32+
df_raw = constructor_eager(data)
33+
df = nw.from_native(df_raw)
34+
result = df.unique(subset, keep=keep).sort("z")
35+
assert_equal_data(result, expected)
36+
37+
38+
def test_unique_invalid_subset(constructor: Constructor) -> None:
    """unique() on a column that does not exist raises ColumnNotFoundError."""
    frame = nw.from_native(constructor(data))
    with pytest.raises(ColumnNotFoundError):
        frame.lazy().unique(["fdssfad"]).collect()
43+
44+
45+
@pytest.mark.parametrize("subset", ["b", ["b"]])
@pytest.mark.parametrize(
    ("keep", "expected"),
    [
        ("any", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}),
        ("none", {"a": [2], "b": [6], "z": [9]}),
    ],
)
def test_unique(
    constructor: Constructor,
    subset: str | list[str] | None,
    keep: Literal["any", "none"],
    expected: dict[str, list[float]],
) -> None:
    """unique() with a single-column subset honors the keep strategy on duplicates."""
    frame = nw.from_native(constructor(data))
    # Sort by "z" so row order is deterministic across lazy backends.
    deduplicated = frame.unique(subset, keep=keep).sort("z")
    assert_equal_data(deduplicated, expected)
5563

5664

57-
def test_unique_invalid_subset(constructor: Constructor) -> None:
65+
@pytest.mark.parametrize("subset", [None, ["a", "b"]])
@pytest.mark.parametrize(
    ("keep", "expected"),
    [
        ("any", {"a": [1, 1, 2], "b": [3, 4, 4]}),
        ("none", {"a": [1, 2], "b": [4, 4]}),
    ],
)
def test_unique_full_subset(
    constructor: Constructor,
    subset: list[str] | None,
    keep: Literal["any", "none"],
    expected: dict[str, list[float]],
) -> None:
    """unique() where the subset covers every column, given explicitly or as None."""
    # Local fixture deliberately shadows the module-level `data`.
    data = {"a": [1, 1, 1, 2], "b": [3, 3, 4, 4]}
    frame = nw.from_native(constructor(data))
    deduplicated = frame.unique(subset, keep=keep).sort("a", "b")
    assert_equal_data(deduplicated, expected)
84+
85+
86+
def test_unique_invalid_keep(constructor: Constructor) -> None:
    """An unsupported keep strategy is rejected with a ValueError naming the value."""
    frame = nw.from_native(constructor(data))
    with pytest.raises(ValueError, match=r"(Got|got): cabbage"):
        frame.unique(keep="cabbage")  # type: ignore[arg-type]
6289

6390

6491
@pytest.mark.filterwarnings("ignore:.*backwards-compatibility:UserWarning")

0 commit comments

Comments
 (0)