|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from contextlib import nullcontext as does_not_raise |
4 | | -from typing import Any |
| 3 | +from typing import Literal |
5 | 4 |
|
6 | 5 | import pytest |
7 | 6 |
|
|
10 | 9 | import narwhals as nw |
11 | 10 | from narwhals.exceptions import ColumnNotFoundError |
12 | 11 | from tests.utils import Constructor |
| 12 | +from tests.utils import ConstructorEager |
13 | 13 | from tests.utils import assert_equal_data |
14 | 14 |
|
15 | 15 | data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} |
|
21 | 21 | [ |
22 | 22 | ("first", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}), |
23 | 23 | ("last", {"a": [3, 2], "b": [4, 6], "z": [8.0, 9.0]}), |
| 24 | + ], |
| 25 | +) |
| 26 | +def test_unique_eager( |
| 27 | + constructor_eager: ConstructorEager, |
| 28 | + subset: str | list[str] | None, |
| 29 | + keep: Literal["first", "last"], |
| 30 | + expected: dict[str, list[float]], |
| 31 | +) -> None: |
| 32 | + df_raw = constructor_eager(data) |
| 33 | + df = nw.from_native(df_raw) |
| 34 | + result = df.unique(subset, keep=keep).sort("z") |
| 35 | + assert_equal_data(result, expected) |
| 36 | + |
| 37 | + |
| 38 | +def test_unique_invalid_subset(constructor: Constructor) -> None: |
| 39 | + df_raw = constructor(data) |
| 40 | + df = nw.from_native(df_raw) |
| 41 | + with pytest.raises(ColumnNotFoundError): |
| 42 | + df.lazy().unique(["fdssfad"]).collect() |
| 43 | + |
| 44 | + |
| 45 | +@pytest.mark.parametrize("subset", ["b", ["b"]]) |
| 46 | +@pytest.mark.parametrize( |
| 47 | + ("keep", "expected"), |
| 48 | + [ |
24 | 49 | ("any", {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}), |
25 | 50 | ("none", {"a": [2], "b": [6], "z": [9]}), |
26 | | - ("foo", {"a": [2], "b": [6], "z": [9]}), |
27 | 51 | ], |
28 | 52 | ) |
29 | 53 | def test_unique( |
30 | 54 | constructor: Constructor, |
31 | 55 | subset: str | list[str] | None, |
32 | | - keep: str, |
| 56 | + keep: Literal["any", "none"], |
33 | 57 | expected: dict[str, list[float]], |
34 | 58 | ) -> None: |
35 | 59 | df_raw = constructor(data) |
36 | 60 | df = nw.from_native(df_raw) |
37 | | - if isinstance(df, nw.LazyFrame) and keep in { |
38 | | - "first", |
39 | | - "last", |
40 | | - }: |
41 | | - context: Any = pytest.raises(ValueError, match="row order") |
42 | | - elif keep == "none" and df.implementation.is_spark_like(): # pragma: no cover |
43 | | - context = pytest.raises( |
44 | | - ValueError, |
45 | | - match="`LazyFrame.unique` with PySpark backend only supports `keep='any'`.", |
46 | | - ) |
47 | | - elif keep == "foo": |
48 | | - context = pytest.raises(ValueError, match=": foo") |
49 | | - else: |
50 | | - context = does_not_raise() |
51 | | - |
52 | | - with context: |
53 | | - result = df.unique(subset, keep=keep).sort("z") # type: ignore[arg-type] |
54 | | - assert_equal_data(result, expected) |
| 61 | + result = df.unique(subset, keep=keep).sort("z") |
| 62 | + assert_equal_data(result, expected) |
55 | 63 |
|
56 | 64 |
|
57 | | -def test_unique_invalid_subset(constructor: Constructor) -> None: |
| 65 | +@pytest.mark.parametrize("subset", [None, ["a", "b"]]) |
| 66 | +@pytest.mark.parametrize( |
| 67 | + ("keep", "expected"), |
| 68 | + [ |
| 69 | + ("any", {"a": [1, 1, 2], "b": [3, 4, 4]}), |
| 70 | + ("none", {"a": [1, 2], "b": [4, 4]}), |
| 71 | + ], |
| 72 | +) |
| 73 | +def test_unique_full_subset( |
| 74 | + constructor: Constructor, |
| 75 | + subset: list[str] | None, |
| 76 | + keep: Literal["any", "none"], |
| 77 | + expected: dict[str, list[float]], |
| 78 | +) -> None: |
| 79 | + data = {"a": [1, 1, 1, 2], "b": [3, 3, 4, 4]} |
58 | 80 | df_raw = constructor(data) |
59 | 81 | df = nw.from_native(df_raw) |
60 | | - with pytest.raises(ColumnNotFoundError): |
61 | | - df.lazy().unique(["fdssfad"]).collect() |
| 82 | + result = df.unique(subset, keep=keep).sort("a", "b") |
| 83 | + assert_equal_data(result, expected) |
| 84 | + |
| 85 | + |
| 86 | +def test_unique_invalid_keep(constructor: Constructor) -> None: |
| 87 | + with pytest.raises(ValueError, match=r"(Got|got): cabbage"): |
| 88 | + nw.from_native(constructor(data)).unique(keep="cabbage") # type: ignore[arg-type] |
62 | 89 |
|
63 | 90 |
|
64 | 91 | @pytest.mark.filterwarnings("ignore:.*backwards-compatibility:UserWarning") |
|
0 commit comments