Skip to content

Commit 2bffdaa

Browse files
committed
test: Add test_partition_by_multiple
1 parent e0d1a00 commit 2bffdaa

File tree

1 file changed

+56
-2
lines changed

1 file changed

+56
-2
lines changed

tests/plan/frame_partition_by_test.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import pytest
66

77
import narwhals as nw
8-
from narwhals._plan import selectors as ncs
8+
from narwhals._plan import Selector, selectors as ncs
99
from narwhals._utils import zip_strict
1010
from tests.plan.utils import assert_equal_data, dataframe
1111

1212
if TYPE_CHECKING:
13-
from narwhals._plan.typing import ColumnNameOrSelector
13+
from narwhals._plan.typing import ColumnNameOrSelector, OneOrIterable
1414
from tests.conftest import Data
1515

1616

@@ -53,3 +53,57 @@ def test_partition_by_single(
5353
results = df.partition_by(by, include_key=include_key)
5454
for df, expect in zip_strict(results, expected):
5555
assert_equal_data(df, expect)
56+
57+
58+
@pytest.mark.parametrize(
59+
("include_key", "expected"),
60+
[
61+
(
62+
True,
63+
[
64+
{"a": ["a", "a"], "b": [1, 1], "c": [5, 3]},
65+
{"a": ["b"], "b": [2], "c": [4]},
66+
{"a": ["b"], "b": [3], "c": [2]},
67+
{"a": ["c"], "b": [3], "c": [1]},
68+
],
69+
),
70+
(False, [{"c": [5, 3]}, {"c": [4]}, {"c": [2]}, {"c": [1]}]),
71+
],
72+
ids=["include_key", "exclude_key"],
73+
)
74+
@pytest.mark.parametrize(
75+
("by", "more_by"),
76+
[
77+
("a", "b"),
78+
(["a", "b"], ()),
79+
(ncs.matches("a|b"), ()),
80+
(ncs.string(), "b"),
81+
(ncs.by_name("a", "b"), ()),
82+
(ncs.by_name("b"), ncs.by_name("a")),
83+
(ncs.by_dtype(nw.String) | (ncs.numeric() - ncs.by_name("c")), []),
84+
],
85+
ids=[
86+
"str-variadic",
87+
"str-list",
88+
"ncs.matches",
89+
"ncs.string-str",
90+
"ncs.by_name",
91+
"2x-selector",
92+
"BinarySelector",
93+
],
94+
)
95+
def test_partition_by_multiple(
96+
data: Data,
97+
by: ColumnNameOrSelector,
98+
more_by: OneOrIterable[ColumnNameOrSelector],
99+
*,
100+
include_key: bool,
101+
expected: Any,
102+
) -> None:
103+
df = dataframe(data)
104+
if isinstance(more_by, (str, Selector)):
105+
results = df.partition_by(by, more_by, include_key=include_key)
106+
else:
107+
results = df.partition_by(by, *more_by, include_key=include_key)
108+
for df, expect in zip_strict(results, expected):
109+
assert_equal_data(df, expect)

0 commit comments

Comments
 (0)