|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized |
6 | 5 | from itertools import chain |
7 | 6 | from typing import TYPE_CHECKING, Any, Literal, Protocol, overload |
8 | 7 |
|
9 | 8 | from narwhals._plan._expansion import prepare_projection |
10 | 9 | from narwhals._plan._parse import parse_into_seq_of_expr_ir |
11 | | -from narwhals._plan.common import flatten_hash_safe, replace, temp |
| 10 | +from narwhals._plan.common import replace, temp |
| 11 | +from narwhals._plan.compliant.column import EagerBroadcast, SupportsBroadcast |
12 | 12 | from narwhals._plan.compliant.typing import ( |
13 | 13 | ColumnT_co, |
14 | 14 | DataFrameT, |
|
34 | 34 | from narwhals.exceptions import ComputeError |
35 | 35 |
|
36 | 36 | if TYPE_CHECKING: |
| 37 | + from collections.abc import Iterable, Iterator, Mapping, Sequence |
| 38 | + |
37 | 39 | from typing_extensions import Self |
38 | 40 |
|
39 | 41 | from narwhals._plan import expressions as ir |
|
56 | 58 | from narwhals.typing import IntoDType, IntoSchema, PythonLiteral |
57 | 59 |
|
58 | 60 |
|
59 | | -class SupportsBroadcast(Protocol[SeriesT, LengthT]): |
60 | | - """Minimal broadcasting for `Expr` results.""" |
61 | | - |
62 | | - @classmethod |
63 | | - def from_series(cls, series: SeriesT, /) -> Self: ... |
64 | | - def to_series(self) -> SeriesT: ... |
65 | | - def broadcast(self, length: LengthT, /) -> SeriesT: ... |
66 | | - def _length(self) -> LengthT: |
67 | | - """Return the length of the current expression.""" |
68 | | - ... |
69 | | - |
70 | | - @classmethod |
71 | | - def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT: |
72 | | - """Return the maximum length among `exprs`.""" |
73 | | - ... |
74 | | - |
75 | | - @classmethod |
76 | | - def _length_required( |
77 | | - cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / |
78 | | - ) -> LengthT | None: |
79 | | - """Return the broadcast length, if all lengths do not equal the maximum.""" |
80 | | - |
81 | | - @classmethod |
82 | | - def _length_all( |
83 | | - cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], / |
84 | | - ) -> Sequence[LengthT]: |
85 | | - return [e._length() for e in exprs] |
86 | | - |
87 | | - @classmethod |
88 | | - def align( |
89 | | - cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]] |
90 | | - ) -> Iterator[SeriesT]: |
91 | | - exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs)) |
92 | | - length = cls._length_required(exprs) |
93 | | - if length is None: |
94 | | - for e in exprs: |
95 | | - yield e.to_series() |
96 | | - else: |
97 | | - for e in exprs: |
98 | | - yield e.broadcast(length) |
99 | | - |
100 | | - |
101 | | -class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]): |
102 | | - """Determines expression length via the size of the container.""" |
103 | | - |
104 | | - def _length(self) -> int: |
105 | | - return len(self) |
106 | | - |
107 | | - @classmethod |
108 | | - def _length_max(cls, lengths: Sequence[int], /) -> int: |
109 | | - return max(lengths) |
110 | | - |
111 | | - @classmethod |
112 | | - def _length_required( |
113 | | - cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], / |
114 | | - ) -> int | None: |
115 | | - lengths = cls._length_all(exprs) |
116 | | - max_length = cls._length_max(lengths) |
117 | | - required = any(len_ != max_length for len_ in lengths) |
118 | | - return max_length if required else None |
119 | | - |
120 | | - |
121 | 61 | class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]): |
122 | 62 | @classmethod |
123 | 63 | def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co: |
|
0 commit comments