|
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | import narwhals as nw |
8 | | -from narwhals._plan import selectors as ncs |
| 8 | +from narwhals._plan import Selector, selectors as ncs |
9 | 9 | from narwhals._utils import zip_strict |
10 | 10 | from tests.plan.utils import assert_equal_data, dataframe |
11 | 11 |
|
12 | 12 | if TYPE_CHECKING: |
13 | | - from narwhals._plan.typing import ColumnNameOrSelector |
| 13 | + from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable |
14 | 14 | from tests.conftest import Data |
15 | 15 |
|
16 | 16 |
|
@@ -53,3 +53,57 @@ def test_partition_by_single( |
53 | 53 | results = df.partition_by(by, include_key=include_key) |
54 | 54 | for df, expect in zip_strict(results, expected): |
55 | 55 | assert_equal_data(df, expect) |
| 56 | + |
| 57 | + |
| 58 | +@pytest.mark.parametrize( |
| 59 | + ("include_key", "expected"), |
| 60 | + [ |
| 61 | + ( |
| 62 | + True, |
| 63 | + [ |
| 64 | + {"a": ["a", "a"], "b": [1, 1], "c": [5, 3]}, |
| 65 | + {"a": ["b"], "b": [2], "c": [4]}, |
| 66 | + {"a": ["b"], "b": [3], "c": [2]}, |
| 67 | + {"a": ["c"], "b": [3], "c": [1]}, |
| 68 | + ], |
| 69 | + ), |
| 70 | + (False, [{"c": [5, 3]}, {"c": [4]}, {"c": [2]}, {"c": [1]}]), |
| 71 | + ], |
| 72 | + ids=["include_key", "exclude_key"], |
| 73 | +) |
| 74 | +@pytest.mark.parametrize( |
| 75 | + ("by", "more_by"), |
| 76 | + [ |
| 77 | + ("a", "b"), |
| 78 | + (["a", "b"], ()), |
| 79 | + (ncs.matches("a|b"), ()), |
| 80 | + (ncs.string(), "b"), |
| 81 | + (ncs.by_name("a", "b"), ()), |
| 82 | + (ncs.by_name("b"), ncs.by_name("a")), |
| 83 | + (ncs.by_dtype(nw.String) | (ncs.numeric() - ncs.by_name("c")), []), |
| 84 | + ], |
| 85 | + ids=[ |
| 86 | + "str-variadic", |
| 87 | + "str-list", |
| 88 | + "ncs.matches", |
| 89 | + "ncs.string-str", |
| 90 | + "ncs.by_name", |
| 91 | + "2x-selector", |
| 92 | + "BinarySelector", |
| 93 | + ], |
| 94 | +) |
| 95 | +def test_partition_by_multiple( |
| 96 | + data: Data, |
| 97 | + by: ColumnNameOrSelector, |
| 98 | + more_by: OneOrIterable[ColumnNameOrSelector], |
| 99 | + *, |
| 100 | + include_key: bool, |
| 101 | + expected: Any, |
| 102 | +) -> None: |
| 103 | + df = dataframe(data) |
| 104 | + if isinstance(more_by, (str, Selector)): |
| 105 | + results = df.partition_by(by, more_by, include_key=include_key) |
| 106 | + else: |
| 107 | + results = df.partition_by(by, *more_by, include_key=include_key) |
| 108 | + for df, expect in zip_strict(results, expected): |
| 109 | + assert_equal_data(df, expect) |
0 commit comments