66from playwright .sync_api import Locator , Page
77from playwright .sync_api import expect as playwright_expect
88
9- from ...render ._data_frame import ColumnFilter , ColumnSort
9+ from shiny .types import ListOrTuple
10+
11+ from ...render ._data_frame import ColumnFilter , ColumnSort , assert_column_filters
1012from .._types import AttrValue , ListPatternOrStr , PatternOrStr , StyleValue , Timeout
1113from ..expect import expect_not_to_have_class , expect_to_have_class
1214from ..expect ._internal import (
@@ -1213,7 +1215,7 @@ def click_loc(loc: Locator, *, shift: bool = False):
12131215
12141216 def set_filter (
12151217 self ,
1216- filter : ColumnFilter | list [ColumnFilter ] | None ,
1218+ filter : ColumnFilter | ListOrTuple [ColumnFilter ] | None ,
12171219 * ,
12181220 timeout : Timeout = None ,
12191221 ):
@@ -1228,6 +1230,7 @@ def set_filter(
12281230 * `None`: Resets all filters.
12291231 * `ColumnFilterStr`: A dictionary specifying a string filter with 'col' and 'value' keys.
12301232 * `ColumnFilterNumber`: A dictionary specifying a numeric range filter with 'col' and 'value' keys.
1233+ * A sequence of `ColumnFilterStr` or `ColumnFilterNumber` dictionaries, for multiple filters.
12311234 timeout
12321235 The maximum time to wait for the action to complete. Defaults to `None`.
12331236 """
@@ -1240,100 +1243,52 @@ def set_filter(
12401243 if filter is None :
12411244 return
12421245
1243- filter_items : Sequence [ Union [ ColumnFilter , dict [ str , Any ]] ]
1246+ filter_items : ListOrTuple [ ColumnFilter ]
12441247 if isinstance (filter , dict ):
12451248 filter_items = [filter ]
12461249 elif isinstance (filter , (list , tuple )):
1247- filter_items = cast ( Sequence [ Union [ ColumnFilter , dict [ str , Any ]]], filter )
1250+ filter_items = filter
12481251 else :
12491252 raise ValueError (
12501253 "Invalid filter value. Must be a ColumnFilter, "
12511254 "list[ColumnFilter], or None."
12521255 )
12531256
1254- for filterInfo in filter_items :
1255- if "col" not in filterInfo :
1256- raise ValueError ("Column index (`col`) is required for filtering." )
1257+ assert_column_filters (filter_items , self .loc_column_label .count ())
12571258
1258- if "value" not in filterInfo :
1259- raise ValueError ("Filter value (`value`) is required for filtering." )
1259+ for filter_item in filter_items :
1260+ col_idx = filter_item ["col" ]
1261+ value = filter_item .get ("value" , None )
12601262
1261- raw_cols = filterInfo ["col" ]
1262- if isinstance (raw_cols , int ):
1263- column_indices : list [int ] = [raw_cols ]
1264- elif isinstance (raw_cols , (list , tuple )):
1265- cols_iter = cast (Sequence [Any ], raw_cols )
1266- if len (cols_iter ) == 0 :
1267- raise ValueError (
1268- "Column index list (`col`) must contain at least one " "entry."
1269- )
1270- column_indices = []
1271- for col_idx in cols_iter :
1272- if not isinstance (col_idx , int ):
1273- raise ValueError (
1274- "Column index (`col`) values must be integers "
1275- "when specifying multiple columns."
1276- )
1277- column_indices .append (col_idx )
1278- else :
1279- raise ValueError (
1280- "Column index (`col`) must be an int or a list/tuple of " "ints."
1281- )
1263+ filterColumn = self .loc_column_filter .nth (col_idx )
12821264
1283- raw_value = filterInfo ["value" ]
1265+ if isinstance (value , (str , int , float )):
1266+ filterColumn .locator ("> input" ).fill (str (value ), timeout = timeout )
1267+ continue
12841268
1285- if len (column_indices ) > 1 :
1286- if not isinstance (raw_value , (list , tuple )):
1287- raise ValueError (
1288- "When filtering multiple columns, `value` must be a "
1289- "list or tuple with one entry per column."
1290- )
1291- value_sequence = cast (Sequence [Any ], raw_value )
1292- if len (value_sequence ) != len (column_indices ):
1269+ if isinstance (value , (list , tuple )):
1270+ range_values = cast (Sequence [Any ], value )
1271+ if len (range_values ) != 2 :
12931272 raise ValueError (
1294- "The number of filter values must match the number of "
1295- "target columns ."
1273+ "Numeric range filters must provide exactly two "
1274+ "values (min, max) ."
12961275 )
1297- value_iter = list (value_sequence )
1298- else :
1299- value_iter = [raw_value ]
1300-
1301- for col_idx , value in zip (column_indices , value_iter ):
1302- filterColumn = self .loc_column_filter .nth (col_idx )
1303-
1304- if value is None :
1305- continue
1306-
1307- if isinstance (value , (str , int , float )):
1308- filterColumn .locator ("> input" ).fill (str (value ), timeout = timeout )
1309- continue
1310-
1311- if isinstance (value , (list , tuple )):
1312- range_values = cast (Sequence [Any ], value )
1313- if len (range_values ) != 2 :
1314- raise ValueError (
1315- "Numeric range filters must provide exactly two "
1316- "values (min, max)."
1317- )
1318-
1319- header_inputs = filterColumn .locator ("> div > input" )
1320- lower = range_values [0 ]
1321- upper = range_values [1 ]
1322- if lower is not None :
1323- header_inputs .nth (0 ).fill (str (lower ), timeout = timeout )
1324- else :
1325- header_inputs .nth (0 ).fill ("" , timeout = timeout )
1326-
1327- if upper is not None :
1328- header_inputs .nth (1 ).fill (str (upper ), timeout = timeout )
1329- else :
1330- header_inputs .nth (1 ).fill ("" , timeout = timeout )
1331- continue
1332-
1333- raise ValueError (
1334- "Invalid filter value. Must be a string/number, a "
1335- "tuple/list of two numbers, or None."
1336- )
1276+
1277+ header_inputs = filterColumn .locator ("> div > input" )
1278+ lower = range_values [0 ]
1279+ upper = range_values [1 ]
1280+ if lower is not None :
1281+ header_inputs .nth (0 ).fill (str (lower ), timeout = timeout )
1282+ else :
1283+ header_inputs .nth (0 ).fill ("" , timeout = timeout )
1284+
1285+ if upper is not None :
1286+ header_inputs .nth (1 ).fill (str (upper ), timeout = timeout )
1287+ else :
1288+ header_inputs .nth (1 ).fill ("" , timeout = timeout )
1289+ continue
1290+
1291+ raise ValueError ("Invalid filter value." )
13371292
13381293 def set_cell (
13391294 self ,
0 commit comments