|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from collections.abc import Iterable, Sequence |
| 3 | +import operator |
| 4 | +from collections import deque |
| 5 | +from collections.abc import Collection, Iterable, Sequence |
4 | 6 |
|
5 | 7 | # ruff: noqa: A002 |
| 8 | +from functools import reduce |
6 | 9 | from itertools import chain |
7 | 10 | from typing import TYPE_CHECKING |
8 | 11 |
|
|
14 | 17 | is_iterable_reject, |
15 | 18 | is_selector, |
16 | 19 | ) |
| 20 | +from narwhals._plan.common import flatten_hash_safe |
17 | 21 | from narwhals._plan.exceptions import invalid_into_expr_error, is_iterable_error |
18 | 22 | from narwhals._utils import qualified_type_name |
19 | 23 | from narwhals.dependencies import get_polars |
|
27 | 31 |
|
28 | 32 | from narwhals._plan.expr import Expr |
29 | 33 | from narwhals._plan.expressions import ExprIR, SelectorIR |
| 34 | + from narwhals._plan.selectors import Selector |
30 | 35 | from narwhals._plan.typing import ( |
31 | 36 | ColumnNameOrSelector, |
32 | 37 | IntoExpr, |
@@ -129,19 +134,55 @@ def parse_into_expr_ir( |
129 | 134 | return expr._ir |
130 | 135 |
|
131 | 136 |
|
132 | | -def parse_into_selector_ir(input: ColumnNameOrSelector | Expr, /) -> SelectorIR: |
| 137 | +def parse_into_selector_ir( |
| 138 | + input: ColumnNameOrSelector | Expr, /, *, require_all: bool = True |
| 139 | +) -> SelectorIR: |
| 140 | + return _parse_into_selector(input, require_all=require_all)._ir |
| 141 | + |
| 142 | + |
| 143 | +def _parse_into_selector( |
| 144 | + input: ColumnNameOrSelector | Expr, /, *, require_all: bool = True |
| 145 | +) -> Selector: |
133 | 146 | if is_selector(input): |
134 | 147 | selector = input |
135 | 148 | elif isinstance(input, str): |
136 | | - from narwhals._plan import selectors as cs |
| 149 | + import narwhals._plan.selectors as cs |
137 | 150 |
|
138 | | - selector = cs.by_name(input) |
| 151 | + selector = cs.by_name(input, require_all=require_all) |
139 | 152 | elif is_expr(input): |
140 | 153 | selector = input.meta.as_selector() |
141 | 154 | else: |
142 | 155 | msg = f"cannot turn {qualified_type_name(input)!r} into a selector" |
143 | 156 | raise TypeError(msg) |
144 | | - return selector._ir |
| 157 | + return selector |
| 158 | + |
| 159 | + |
| 160 | +def parse_into_combined_selector_ir( |
| 161 | + *inputs: OneOrIterable[ColumnNameOrSelector], require_all: bool = True |
| 162 | +) -> SelectorIR: |
| 163 | + import narwhals._plan.selectors as cs |
| 164 | + |
| 165 | + flat = tuple(flatten_hash_safe(inputs)) |
| 166 | + selectors = deque["Selector"]() |
| 167 | + if names := tuple(el for el in flat if isinstance(el, str)): |
| 168 | + selector = cs.by_name(names, require_all=require_all) |
| 169 | + if len(names) == len(flat): |
| 170 | + return selector._ir |
| 171 | + selectors.append(selector) |
| 172 | + selectors.extend(_parse_into_selector(el) for el in flat if not isinstance(el, str)) |
| 173 | + return _any_of(selectors)._ir |
| 174 | + |
| 175 | + |
| 176 | +def _any_of(selectors: Iterable[Selector], /) -> Selector: |
| 177 | + import narwhals._plan.selectors as cs |
| 178 | + |
| 179 | + if isinstance(selectors, Collection): |
| 180 | + if not selectors: |
| 181 | + return cs.empty() |
| 182 | + if len(selectors) == 1: |
| 183 | + return next(iter(selectors)) # type: ignore[no-any-return] |
| 184 | + s: Selector = reduce(operator.or_, selectors) |
| 185 | + return s |
145 | 186 |
|
146 | 187 |
|
147 | 188 | def parse_into_seq_of_expr_ir( |
|
0 commit comments