Skip to content

Commit c5dd290

Browse files
committed
Enhance set_filter to support multi-column filtering
1 parent 80018f2 commit c5dd290

File tree

2 files changed

+193
-21
lines changed

2 files changed

+193
-21
lines changed

shiny/playwright/controller/_output.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import platform
4-
from typing import Literal, Protocol
4+
from typing import Any, Literal, Protocol, Sequence, cast
55

66
from playwright.sync_api import Locator, Page
77
from playwright.sync_api import expect as playwright_expect
@@ -1211,7 +1211,6 @@ def click_loc(loc: Locator, *, shift: bool = False):
12111211
break
12121212
click_loc(sort_col, shift=shift)
12131213

1214-
# TODO-karan-test: Add support for a list of columns ? If so, all other columns should be reset
12151214
def set_filter(
12161215
self,
12171216
# TODO-barret support array of filters
@@ -1243,39 +1242,97 @@ def set_filter(
12431242
return
12441243

12451244
if isinstance(filter, dict):
1246-
filter = [filter]
1247-
1248-
if not isinstance(filter, list):
1245+
filter_items = [filter]
1246+
elif isinstance(filter, (list, tuple)):
1247+
filter_items = cast(Sequence[ColumnFilter | dict[str, Any]], filter)
1248+
else:
12491249
raise ValueError(
1250-
"Invalid filter value. Must be a ColumnFilter, list[ColumnFilter], or None."
1250+
"Invalid filter value. Must be a ColumnFilter, "
1251+
"list[ColumnFilter], or None."
12511252
)
12521253

1253-
for filterInfo in filter:
1254+
for filterInfo in filter_items:
12541255
if "col" not in filterInfo:
12551256
raise ValueError("Column index (`col`) is required for filtering.")
12561257

12571258
if "value" not in filterInfo:
12581259
raise ValueError("Filter value (`value`) is required for filtering.")
12591260

1260-
filterColumn = self.loc_column_filter.nth(filterInfo["col"])
1261+
raw_cols = filterInfo["col"]
1262+
if isinstance(raw_cols, int):
1263+
column_indices: list[int] = [raw_cols]
1264+
elif isinstance(raw_cols, (list, tuple)):
1265+
cols_iter = cast(Sequence[Any], raw_cols)
1266+
if len(cols_iter) == 0:
1267+
raise ValueError(
1268+
"Column index list (`col`) must contain at least one " "entry."
1269+
)
1270+
column_indices = []
1271+
for col_idx in cols_iter:
1272+
if not isinstance(col_idx, int):
1273+
raise ValueError(
1274+
"Column index (`col`) values must be integers "
1275+
"when specifying multiple columns."
1276+
)
1277+
column_indices.append(col_idx)
1278+
else:
1279+
raise ValueError(
1280+
"Column index (`col`) must be an int or a list/tuple of " "ints."
1281+
)
1282+
1283+
raw_value = filterInfo["value"]
12611284

1262-
if isinstance(filterInfo["value"], str):
1263-
filterColumn.locator("> input").fill(filterInfo["value"])
1264-
elif isinstance(filterInfo["value"], (tuple, list)):
1265-
header_inputs = filterColumn.locator("> div > input")
1266-
if filterInfo["value"][0] is not None:
1267-
header_inputs.nth(0).fill(
1268-
str(filterInfo["value"][0]),
1269-
timeout=timeout,
1285+
if len(column_indices) > 1:
1286+
if not isinstance(raw_value, (list, tuple)):
1287+
raise ValueError(
1288+
"When filtering multiple columns, `value` must be a "
1289+
"list or tuple with one entry per column."
12701290
)
1271-
if filterInfo["value"][1] is not None:
1272-
header_inputs.nth(1).fill(
1273-
str(filterInfo["value"][1]),
1274-
timeout=timeout,
1291+
value_sequence = cast(Sequence[Any], raw_value)
1292+
if len(value_sequence) != len(column_indices):
1293+
raise ValueError(
1294+
"The number of filter values must match the number of "
1295+
"target columns."
12751296
)
1297+
value_iter = list(value_sequence)
12761298
else:
1299+
value_iter = [raw_value]
1300+
1301+
for col_idx, value in zip(column_indices, value_iter):
1302+
filterColumn = self.loc_column_filter.nth(col_idx)
1303+
1304+
if value is None:
1305+
continue
1306+
1307+
if isinstance(value, (str, int, float)):
1308+
filterColumn.locator("> input").fill(str(value), timeout=timeout)
1309+
continue
1310+
1311+
if isinstance(value, (list, tuple)):
1312+
range_values = cast(Sequence[Any], value)
1313+
if len(range_values) != 2:
1314+
raise ValueError(
1315+
"Numeric range filters must provide exactly two "
1316+
"values (min, max)."
1317+
)
1318+
1319+
header_inputs = filterColumn.locator("> div > input")
1320+
lower = range_values[0]
1321+
upper = range_values[1]
1322+
if lower is not None:
1323+
header_inputs.nth(0).fill(str(lower), timeout=timeout)
1324+
else:
1325+
header_inputs.nth(0).fill("", timeout=timeout)
1326+
1327+
if upper is not None:
1328+
header_inputs.nth(1).fill(str(upper), timeout=timeout)
1329+
else:
1330+
header_inputs.nth(1).fill("", timeout=timeout)
1331+
continue
1332+
12771333
raise ValueError(
1278-
"Invalid filter value. Must be a string or a tuple/list of two numbers."
1334+
"Invalid filter value. Must be a string/number, a "
1335+
"tuple/list of two numbers, or None."
12791336
)
12801337

12811338
def set_cell(

tests/playwright/shiny/components/data_frame/filter_reset/test_filter_reset.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,118 @@ def test_filters_are_reset(page: Page, local_app: ShinyAppProc) -> None:
4040
penguin_code.expect_value("()")
4141
for i in range(filter_inputs.count()):
4242
expect(filter_inputs.nth(i)).to_have_value("")
43+
44+
45+
def test_set_filter_accepts_single_dict(page: Page, local_app: ShinyAppProc) -> None:
46+
page.goto(local_app.url)
47+
48+
penguin_df = controller.OutputDataFrame(page, "penguins_df")
49+
penguin_code = controller.OutputCode(page, "penguins_code")
50+
51+
penguin_df.set_filter({"col": 0, "value": "Chinstrap"})
52+
53+
penguin_code.expect_value("({'col': 0, 'value': 'Chinstrap'},)")
54+
expect(penguin_df.loc_column_filter.nth(0).locator("input")).to_have_value(
55+
"Chinstrap"
56+
)
57+
58+
59+
def test_set_filter_accepts_list(page: Page, local_app: ShinyAppProc) -> None:
60+
page.goto(local_app.url)
61+
62+
penguin_df = controller.OutputDataFrame(page, "penguins_df")
63+
penguin_code = controller.OutputCode(page, "penguins_code")
64+
65+
penguin_df.set_filter(
66+
[
67+
{"col": 0, "value": "Gentoo"},
68+
{"col": 2, "value": (45, 50)},
69+
]
70+
)
71+
72+
penguin_code.expect_value(
73+
"({'col': 0, 'value': 'Gentoo'}, {'col': 2, 'value': (45, 50)})"
74+
)
75+
76+
species_filter = penguin_df.loc_column_filter.nth(0).locator("input")
77+
length_filter = penguin_df.loc_column_filter.nth(2).locator("div > input")
78+
79+
expect(species_filter).to_have_value("Gentoo")
80+
expect(length_filter.nth(0)).to_have_value("45")
81+
expect(length_filter.nth(1)).to_have_value("50")
82+
83+
84+
def test_set_filter_accepts_tuple(page: Page, local_app: ShinyAppProc) -> None:
85+
page.goto(local_app.url)
86+
87+
penguin_df = controller.OutputDataFrame(page, "penguins_df")
88+
penguin_code = controller.OutputCode(page, "penguins_code")
89+
90+
penguin_df.set_filter(
91+
(
92+
{"col": 0, "value": "Adelie"},
93+
{"col": 3, "value": (None, 17)},
94+
) # type: ignore[arg-type]
95+
)
96+
97+
penguin_code.expect_value(
98+
"({'col': 0, 'value': 'Adelie'}, {'col': 3, 'value': (None, 17)})"
99+
)
100+
101+
species_filter = penguin_df.loc_column_filter.nth(0).locator("input")
102+
depth_filter = penguin_df.loc_column_filter.nth(3).locator("div > input")
103+
104+
expect(species_filter).to_have_value("Adelie")
105+
expect(depth_filter.nth(0)).to_have_value("")
106+
expect(depth_filter.nth(1)).to_have_value("17")
107+
108+
109+
def test_set_filter_accepts_multi_column_mapping(
110+
page: Page, local_app: ShinyAppProc
111+
) -> None:
112+
page.goto(local_app.url)
113+
114+
penguin_df = controller.OutputDataFrame(page, "penguins_df")
115+
penguin_code = controller.OutputCode(page, "penguins_code")
116+
117+
penguin_df.set_filter(
118+
{
119+
"col": [0, 1],
120+
"value": ["Gentoo", "Biscoe"],
121+
} # type: ignore[arg-type]
122+
)
123+
124+
penguin_code.expect_value(
125+
"({'col': 0, 'value': 'Gentoo'}, {'col': 1, 'value': 'Biscoe'})"
126+
)
127+
128+
species_filter = penguin_df.loc_column_filter.nth(0).locator("input")
129+
island_filter = penguin_df.loc_column_filter.nth(1).locator("input")
130+
131+
expect(species_filter).to_have_value("Gentoo")
132+
expect(island_filter).to_have_value("Biscoe")
133+
134+
135+
def test_set_filter_none_clears_inputs(page: Page, local_app: ShinyAppProc) -> None:
136+
page.goto(local_app.url)
137+
138+
penguin_df = controller.OutputDataFrame(page, "penguins_df")
139+
penguin_code = controller.OutputCode(page, "penguins_code")
140+
141+
penguin_df.set_filter(
142+
[
143+
{"col": 0, "value": "Gentoo"},
144+
{"col": 2, "value": (45, 50)},
145+
]
146+
)
147+
148+
penguin_code.expect_value(
149+
"({'col': 0, 'value': 'Gentoo'}, {'col': 2, 'value': (45, 50)})"
150+
)
151+
penguin_df.set_filter(None)
152+
153+
penguin_code.expect_value("()")
154+
155+
filter_inputs = penguin_df.loc_column_filter.locator("input")
156+
for i in range(filter_inputs.count()):
157+
expect(filter_inputs.nth(i)).to_have_value("")

0 commit comments

Comments
 (0)