Skip to content

Commit 2791431

Browse files
committed
refactor: Move *Broadcast to column.py`
1 parent 87db910 commit 2791431

File tree

2 files changed

+79
-64
lines changed

2 files changed

+79
-64
lines changed

narwhals/_plan/compliant/column.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,76 @@
11
from __future__ import annotations
2+
3+
from collections.abc import Sequence, Sized
4+
from typing import TYPE_CHECKING, Protocol
5+
6+
from narwhals._plan.common import flatten_hash_safe
7+
from narwhals._plan.compliant.typing import LengthT, SeriesT
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Iterator, Sequence
11+
12+
from typing_extensions import Self
13+
14+
from narwhals._plan.typing import OneOrIterable
15+
16+
17+
class SupportsBroadcast(Protocol[SeriesT, LengthT]):
18+
"""Minimal broadcasting for `Expr` results."""
19+
20+
@classmethod
21+
def from_series(cls, series: SeriesT, /) -> Self: ...
22+
def to_series(self) -> SeriesT: ...
23+
def broadcast(self, length: LengthT, /) -> SeriesT: ...
24+
def _length(self) -> LengthT:
25+
"""Return the length of the current expression."""
26+
...
27+
28+
@classmethod
29+
def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT:
30+
"""Return the maximum length among `exprs`."""
31+
...
32+
33+
@classmethod
34+
def _length_required(
35+
cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], /
36+
) -> LengthT | None:
37+
"""Return the broadcast length, if all lengths do not equal the maximum."""
38+
39+
@classmethod
40+
def _length_all(
41+
cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], /
42+
) -> Sequence[LengthT]:
43+
return [e._length() for e in exprs]
44+
45+
@classmethod
46+
def align(
47+
cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]]
48+
) -> Iterator[SeriesT]:
49+
exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs))
50+
length = cls._length_required(exprs)
51+
if length is None:
52+
for e in exprs:
53+
yield e.to_series()
54+
else:
55+
for e in exprs:
56+
yield e.broadcast(length)
57+
58+
59+
class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]):
60+
"""Determines expression length via the size of the container."""
61+
62+
def _length(self) -> int:
63+
return len(self)
64+
65+
@classmethod
66+
def _length_max(cls, lengths: Sequence[int], /) -> int:
67+
return max(lengths)
68+
69+
@classmethod
70+
def _length_required(
71+
cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], /
72+
) -> int | None:
73+
lengths = cls._length_all(exprs)
74+
max_length = cls._length_max(lengths)
75+
required = any(len_ != max_length for len_ in lengths)
76+
return max_length if required else None

narwhals/_plan/protocols.py

Lines changed: 4 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized
65
from itertools import chain
76
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload
87

98
from narwhals._plan._expansion import prepare_projection
109
from narwhals._plan._parse import parse_into_seq_of_expr_ir
11-
from narwhals._plan.common import flatten_hash_safe, replace, temp
10+
from narwhals._plan.common import replace, temp
11+
from narwhals._plan.compliant.column import EagerBroadcast, SupportsBroadcast
1212
from narwhals._plan.compliant.typing import (
1313
ColumnT_co,
1414
DataFrameT,
@@ -34,6 +34,8 @@
3434
from narwhals.exceptions import ComputeError
3535

3636
if TYPE_CHECKING:
37+
from collections.abc import Iterable, Iterator, Mapping, Sequence
38+
3739
from typing_extensions import Self
3840

3941
from narwhals._plan import expressions as ir
@@ -56,68 +58,6 @@
5658
from narwhals.typing import IntoDType, IntoSchema, PythonLiteral
5759

5860

59-
class SupportsBroadcast(Protocol[SeriesT, LengthT]):
60-
"""Minimal broadcasting for `Expr` results."""
61-
62-
@classmethod
63-
def from_series(cls, series: SeriesT, /) -> Self: ...
64-
def to_series(self) -> SeriesT: ...
65-
def broadcast(self, length: LengthT, /) -> SeriesT: ...
66-
def _length(self) -> LengthT:
67-
"""Return the length of the current expression."""
68-
...
69-
70-
@classmethod
71-
def _length_max(cls, lengths: Sequence[LengthT], /) -> LengthT:
72-
"""Return the maximum length among `exprs`."""
73-
...
74-
75-
@classmethod
76-
def _length_required(
77-
cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], /
78-
) -> LengthT | None:
79-
"""Return the broadcast length, if all lengths do not equal the maximum."""
80-
81-
@classmethod
82-
def _length_all(
83-
cls, exprs: Sequence[SupportsBroadcast[SeriesT, LengthT]], /
84-
) -> Sequence[LengthT]:
85-
return [e._length() for e in exprs]
86-
87-
@classmethod
88-
def align(
89-
cls, *exprs: OneOrIterable[SupportsBroadcast[SeriesT, LengthT]]
90-
) -> Iterator[SeriesT]:
91-
exprs = tuple[SupportsBroadcast[SeriesT, LengthT], ...](flatten_hash_safe(exprs))
92-
length = cls._length_required(exprs)
93-
if length is None:
94-
for e in exprs:
95-
yield e.to_series()
96-
else:
97-
for e in exprs:
98-
yield e.broadcast(length)
99-
100-
101-
class EagerBroadcast(Sized, SupportsBroadcast[SeriesT, int], Protocol[SeriesT]):
102-
"""Determines expression length via the size of the container."""
103-
104-
def _length(self) -> int:
105-
return len(self)
106-
107-
@classmethod
108-
def _length_max(cls, lengths: Sequence[int], /) -> int:
109-
return max(lengths)
110-
111-
@classmethod
112-
def _length_required(
113-
cls, exprs: Sequence[SupportsBroadcast[SeriesT, int]], /
114-
) -> int | None:
115-
lengths = cls._length_all(exprs)
116-
max_length = cls._length_max(lengths)
117-
required = any(len_ != max_length for len_ in lengths)
118-
return max_length if required else None
119-
120-
12161
class ExprDispatch(StoresVersion, Protocol[FrameT_contra, R_co, NamespaceT_co]):
12262
@classmethod
12363
def from_ir(cls, node: ir.ExprIR, frame: FrameT_contra, name: str) -> R_co:

0 commit comments

Comments
 (0)