|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import platform |
4 | | -from typing import Literal, Protocol |
| 4 | +from typing import Any, Literal, Protocol, Sequence, cast |
5 | 5 |
|
6 | 6 | from playwright.sync_api import Locator, Page |
7 | 7 | from playwright.sync_api import expect as playwright_expect |
@@ -1211,7 +1211,6 @@ def click_loc(loc: Locator, *, shift: bool = False): |
1211 | 1211 | break |
1212 | 1212 | click_loc(sort_col, shift=shift) |
1213 | 1213 |
|
1214 | | - # TODO-karan-test: Add support for a list of columns ? If so, all other columns should be reset |
1215 | 1214 | def set_filter( |
1216 | 1215 | self, |
1217 | 1216 | # TODO-barret support array of filters |
@@ -1243,39 +1242,97 @@ def set_filter( |
1243 | 1242 | return |
1244 | 1243 |
|
1245 | 1244 | if isinstance(filter, dict): |
1246 | | - filter = [filter] |
1247 | | - |
1248 | | - if not isinstance(filter, list): |
| 1245 | + filter_items = [filter] |
| 1246 | + elif isinstance(filter, (list, tuple)): |
| 1247 | + filter_items = cast(Sequence[ColumnFilter | dict[str, Any]], filter) |
| 1248 | + else: |
1249 | 1249 | raise ValueError( |
1250 | | - "Invalid filter value. Must be a ColumnFilter, list[ColumnFilter], or None." |
| 1250 | + "Invalid filter value. Must be a ColumnFilter, " |
| 1251 | + "list[ColumnFilter], or None." |
1251 | 1252 | ) |
1252 | 1253 |
|
1253 | | - for filterInfo in filter: |
| 1254 | + for filterInfo in filter_items: |
1254 | 1255 | if "col" not in filterInfo: |
1255 | 1256 | raise ValueError("Column index (`col`) is required for filtering.") |
1256 | 1257 |
|
1257 | 1258 | if "value" not in filterInfo: |
1258 | 1259 | raise ValueError("Filter value (`value`) is required for filtering.") |
1259 | 1260 |
|
1260 | | - filterColumn = self.loc_column_filter.nth(filterInfo["col"]) |
| 1261 | + raw_cols = filterInfo["col"] |
| 1262 | + if isinstance(raw_cols, int): |
| 1263 | + column_indices: list[int] = [raw_cols] |
| 1264 | + elif isinstance(raw_cols, (list, tuple)): |
| 1265 | + cols_iter = cast(Sequence[Any], raw_cols) |
| 1266 | + if len(cols_iter) == 0: |
| 1267 | + raise ValueError( |
| 1268 | + "Column index list (`col`) must contain at least one " "entry." |
| 1269 | + ) |
| 1270 | + column_indices = [] |
| 1271 | + for col_idx in cols_iter: |
| 1272 | + if not isinstance(col_idx, int): |
| 1273 | + raise ValueError( |
| 1274 | + "Column index (`col`) values must be integers " |
| 1275 | + "when specifying multiple columns." |
| 1276 | + ) |
| 1277 | + column_indices.append(col_idx) |
| 1278 | + else: |
| 1279 | + raise ValueError( |
| 1280 | + "Column index (`col`) must be an int or a list/tuple of " "ints." |
| 1281 | + ) |
| 1282 | + |
| 1283 | + raw_value = filterInfo["value"] |
1261 | 1284 |
|
1262 | | - if isinstance(filterInfo["value"], str): |
1263 | | - filterColumn.locator("> input").fill(filterInfo["value"]) |
1264 | | - elif isinstance(filterInfo["value"], (tuple, list)): |
1265 | | - header_inputs = filterColumn.locator("> div > input") |
1266 | | - if filterInfo["value"][0] is not None: |
1267 | | - header_inputs.nth(0).fill( |
1268 | | - str(filterInfo["value"][0]), |
1269 | | - timeout=timeout, |
| 1285 | + if len(column_indices) > 1: |
| 1286 | + if not isinstance(raw_value, (list, tuple)): |
| 1287 | + raise ValueError( |
| 1288 | + "When filtering multiple columns, `value` must be a " |
| 1289 | + "list or tuple with one entry per column." |
1270 | 1290 | ) |
1271 | | - if filterInfo["value"][1] is not None: |
1272 | | - header_inputs.nth(1).fill( |
1273 | | - str(filterInfo["value"][1]), |
1274 | | - timeout=timeout, |
| 1291 | + value_sequence = cast(Sequence[Any], raw_value) |
| 1292 | + if len(value_sequence) != len(column_indices): |
| 1293 | + raise ValueError( |
| 1294 | + "The number of filter values must match the number of " |
| 1295 | + "target columns." |
1275 | 1296 | ) |
| 1297 | + value_iter = list(value_sequence) |
1276 | 1298 | else: |
| 1299 | + value_iter = [raw_value] |
| 1300 | + |
| 1301 | + for col_idx, value in zip(column_indices, value_iter): |
| 1302 | + filterColumn = self.loc_column_filter.nth(col_idx) |
| 1303 | + |
| 1304 | + if value is None: |
| 1305 | + continue |
| 1306 | + |
| 1307 | + if isinstance(value, (str, int, float)): |
| 1308 | + filterColumn.locator("> input").fill(str(value), timeout=timeout) |
| 1309 | + continue |
| 1310 | + |
| 1311 | + if isinstance(value, (list, tuple)): |
| 1312 | + range_values = cast(Sequence[Any], value) |
| 1313 | + if len(range_values) != 2: |
| 1314 | + raise ValueError( |
| 1315 | + "Numeric range filters must provide exactly two " |
| 1316 | + "values (min, max)." |
| 1317 | + ) |
| 1318 | + |
| 1319 | + header_inputs = filterColumn.locator("> div > input") |
| 1320 | + lower = range_values[0] |
| 1321 | + upper = range_values[1] |
| 1322 | + if lower is not None: |
| 1323 | + header_inputs.nth(0).fill(str(lower), timeout=timeout) |
| 1324 | + else: |
| 1325 | + header_inputs.nth(0).fill("", timeout=timeout) |
| 1326 | + |
| 1327 | + if upper is not None: |
| 1328 | + header_inputs.nth(1).fill(str(upper), timeout=timeout) |
| 1329 | + else: |
| 1330 | + header_inputs.nth(1).fill("", timeout=timeout) |
| 1331 | + continue |
| 1332 | + |
1277 | 1333 | raise ValueError( |
1278 | | - "Invalid filter value. Must be a string or a tuple/list of two numbers." |
| 1334 | + "Invalid filter value. Must be a string/number, a " |
| 1335 | + "tuple/list of two numbers, or None." |
1279 | 1336 | ) |
1280 | 1337 |
|
1281 | 1338 | def set_cell( |
|
0 commit comments