Skip to content

Commit 3ce5e29

Browse files
authored
perf: avoid full broadcast in horizontal functions (#3199)
1 parent fcce5ca commit 3ce5e29

File tree

8 files changed: 89 additions (+89) and 111 deletions (-111).

narwhals/_arrow/namespace.py

Lines changed: 19 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,12 @@
2323
if TYPE_CHECKING:
2424
from collections.abc import Iterator, Sequence
2525

26-
from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
26+
from narwhals._arrow.typing import (
27+
ArrayOrScalar,
28+
ChunkedArrayAny,
29+
Incomplete,
30+
ScalarAny,
31+
)
2732
from narwhals._compliant.typing import ScalarKwargs
2833
from narwhals._utils import Version
2934
from narwhals.typing import IntoDType, NonNestedLiteral
@@ -49,6 +54,11 @@ def _series(self) -> type[ArrowSeries]:
4954
def __init__(self, *, version: Version) -> None:
5055
self._version = version
5156

57+
def extract_native(
58+
self, *series: ArrowSeries
59+
) -> Iterator[ChunkedArrayAny | ScalarAny]:
60+
return (s.native[0] if s._broadcast else s.native for s in series)
61+
5262
def len(self) -> ArrowExpr:
5363
# coverage bug? this is definitely hit
5464
return self._expr( # pragma: no cover
@@ -83,10 +93,9 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
8393
def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
8494
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
8595
series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
86-
align = self._series._align_full_broadcast
8796
if ignore_nulls:
8897
series = (s.fill_null(True, None, None) for s in series)
89-
return [reduce(operator.and_, align(*series))]
98+
return [reduce(operator.and_, series)]
9099

91100
return self._expr._from_callable(
92101
func=func,
@@ -100,10 +109,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
100109
def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
101110
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
102111
series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
103-
align = self._series._align_full_broadcast
104112
if ignore_nulls:
105113
series = (s.fill_null(False, None, None) for s in series)
106-
return [reduce(operator.or_, align(*series))]
114+
return [reduce(operator.or_, series)]
107115

108116
return self._expr._from_callable(
109117
func=func,
@@ -118,8 +126,7 @@ def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
118126
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
119127
it = chain.from_iterable(expr(df) for expr in exprs)
120128
series = (s.fill_null(0, strategy=None, limit=None) for s in it)
121-
align = self._series._align_full_broadcast
122-
return [reduce(operator.add, align(*series))]
129+
return [reduce(operator.add, series)]
123130

124131
return self._expr._from_callable(
125132
func=func,
@@ -135,11 +142,8 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
135142

136143
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
137144
expr_results = tuple(chain.from_iterable(expr(df) for expr in exprs))
138-
align = self._series._align_full_broadcast
139-
series = align(
140-
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
141-
)
142-
non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
145+
series = [s.fill_null(0, strategy=None, limit=None) for s in expr_results]
146+
non_na = [1 - s.is_null().cast(int_64) for s in expr_results]
143147
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
144148

145149
return self._expr._from_callable(
@@ -153,9 +157,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
153157

154158
def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
155159
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
156-
align = self._series._align_full_broadcast
157160
init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
158-
init_series, *series = align(init_series, *series)
159161
native_series = reduce(
160162
pc.min_element_wise, [s.native for s in series], init_series.native
161163
)
@@ -174,9 +176,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
174176

175177
def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
176178
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
177-
align = self._series._align_full_broadcast
178179
init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
179-
init_series, *series = align(init_series, *series)
180180
native_series = reduce(
181181
pc.max_element_wise, [s.native for s in series], init_series.native
182182
)
@@ -227,16 +227,13 @@ def concat_str(
227227
self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
228228
) -> ArrowExpr:
229229
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
230-
align = self._series._align_full_broadcast
231-
compliant_series_list = align(
232-
*(chain.from_iterable(expr(df) for expr in exprs))
233-
)
234-
name = compliant_series_list[0].name
230+
series = list(chain.from_iterable(expr(df) for expr in exprs))
231+
name = series[0].name
235232
null_handling: Literal["skip", "emit_null"] = (
236233
"skip" if ignore_nulls else "emit_null"
237234
)
238235
it, separator_scalar = cast_to_comparable_string_types(
239-
*(s.native for s in compliant_series_list), separator=separator
236+
*self.extract_native(*series), separator=separator
240237
)
241238
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
242239
# Reality: `str` is fine

narwhals/_arrow/series.py

Lines changed: 40 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Literal, cast, overload
3+
from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload
44

55
import pyarrow as pa
66
import pyarrow.compute as pc
@@ -148,6 +148,16 @@ def _with_native(
148148
result._broadcast = self._broadcast
149149
return result
150150

151+
def _with_binary(self, op: Callable[..., ArrayOrScalar], other: Any) -> Self:
152+
ser, other_native = extract_native(self, other)
153+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
154+
return self._with_native(
155+
op(ser, other_native), preserve_broadcast=preserve_broadcast
156+
).alias(self.name)
157+
158+
def _with_binary_right(self, op: Callable[..., ArrayOrScalar], other: Any) -> Self:
159+
return self._with_binary(lambda x, y: op(y, x), other).alias(self.name)
160+
151161
@classmethod
152162
def from_iterable(
153163
cls,
@@ -214,106 +224,89 @@ def __narwhals_namespace__(self) -> ArrowNamespace:
214224
return ArrowNamespace(version=self._version)
215225

216226
def __eq__(self, other: object) -> Self: # type: ignore[override]
217-
other = cast("PythonLiteral | ArrowSeries | None", other)
218-
ser, rhs = extract_native(self, other)
219-
return self._with_native(pc.equal(ser, rhs))
227+
return self._with_binary(pc.equal, other)
220228

221229
def __ne__(self, other: object) -> Self: # type: ignore[override]
222-
other = cast("PythonLiteral | ArrowSeries | None", other)
223-
ser, rhs = extract_native(self, other)
224-
return self._with_native(pc.not_equal(ser, rhs))
230+
return self._with_binary(pc.not_equal, other)
225231

226232
def __ge__(self, other: Any) -> Self:
227-
ser, other = extract_native(self, other)
228-
return self._with_native(pc.greater_equal(ser, other))
233+
return self._with_binary(pc.greater_equal, other)
229234

230235
def __gt__(self, other: Any) -> Self:
231-
ser, other = extract_native(self, other)
232-
return self._with_native(pc.greater(ser, other))
236+
return self._with_binary(pc.greater, other)
233237

234238
def __le__(self, other: Any) -> Self:
235-
ser, other = extract_native(self, other)
236-
return self._with_native(pc.less_equal(ser, other))
239+
return self._with_binary(pc.less_equal, other)
237240

238241
def __lt__(self, other: Any) -> Self:
239-
ser, other = extract_native(self, other)
240-
return self._with_native(pc.less(ser, other))
242+
return self._with_binary(pc.less, other)
241243

242244
def __and__(self, other: Any) -> Self:
243-
ser, other = extract_native(self, other)
244-
return self._with_native(pc.and_kleene(ser, other)) # type: ignore[arg-type]
245+
return self._with_binary(pc.and_kleene, other)
245246

246247
def __rand__(self, other: Any) -> Self:
247-
ser, other = extract_native(self, other)
248-
return self._with_native(pc.and_kleene(other, ser)) # type: ignore[arg-type]
248+
return self._with_binary_right(pc.and_kleene, other)
249249

250250
def __or__(self, other: Any) -> Self:
251-
ser, other = extract_native(self, other)
252-
return self._with_native(pc.or_kleene(ser, other)) # type: ignore[arg-type]
251+
return self._with_binary_right(pc.or_kleene, other)
253252

254253
def __ror__(self, other: Any) -> Self:
255-
ser, other = extract_native(self, other)
256-
return self._with_native(pc.or_kleene(other, ser)) # type: ignore[arg-type]
254+
return self._with_binary_right(pc.or_kleene, other)
257255

258256
def __add__(self, other: Any) -> Self:
259-
ser, other = extract_native(self, other)
260-
return self._with_native(pc.add(ser, other))
257+
return self._with_binary(pc.add, other)
261258

262259
def __radd__(self, other: Any) -> Self:
263-
return self + other
260+
return self._with_binary_right(pc.add, other)
264261

265262
def __sub__(self, other: Any) -> Self:
266-
ser, other = extract_native(self, other)
267-
return self._with_native(pc.subtract(ser, other))
263+
return self._with_binary(pc.subtract, other)
268264

269265
def __rsub__(self, other: Any) -> Self:
270-
return (self - other) * (-1)
266+
return self._with_binary_right(pc.subtract, other)
271267

272268
def __mul__(self, other: Any) -> Self:
273-
ser, other = extract_native(self, other)
274-
return self._with_native(pc.multiply(ser, other))
269+
return self._with_binary(pc.multiply, other)
275270

276271
def __rmul__(self, other: Any) -> Self:
277-
return self * other
272+
return self._with_binary_right(pc.multiply, other)
278273

279274
def __pow__(self, other: Any) -> Self:
280-
ser, other = extract_native(self, other)
281-
return self._with_native(pc.power(ser, other))
275+
return self._with_binary(pc.power, other)
282276

283277
def __rpow__(self, other: Any) -> Self:
284-
ser, other = extract_native(self, other)
285-
return self._with_native(pc.power(other, ser))
278+
return self._with_binary_right(pc.power, other)
286279

287280
def __floordiv__(self, other: Any) -> Self:
288-
ser, other = extract_native(self, other)
289-
return self._with_native(floordiv_compat(ser, other))
281+
return self._with_binary(floordiv_compat, other)
290282

291283
def __rfloordiv__(self, other: Any) -> Self:
292-
ser, other = extract_native(self, other)
293-
return self._with_native(floordiv_compat(other, ser))
284+
return self._with_binary_right(floordiv_compat, other)
294285

295286
def __truediv__(self, other: Any) -> Self:
296-
ser, other = extract_native(self, other)
297-
return self._with_native(pc.divide(*cast_for_truediv(ser, other))) # type: ignore[type-var]
287+
return self._with_binary(lambda x, y: pc.divide(*cast_for_truediv(x, y)), other)
298288

299289
def __rtruediv__(self, other: Any) -> Self:
300-
ser, other = extract_native(self, other)
301-
return self._with_native(pc.divide(*cast_for_truediv(other, ser))) # type: ignore[type-var]
290+
return self._with_binary_right(
291+
lambda x, y: pc.divide(*cast_for_truediv(x, y)), other
292+
)
302293

303294
def __mod__(self, other: Any) -> Self:
295+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
304296
floor_div = (self // other).native
305297
ser, other = extract_native(self, other)
306298
res = pc.subtract(ser, pc.multiply(floor_div, other))
307-
return self._with_native(res)
299+
return self._with_native(res, preserve_broadcast=preserve_broadcast)
308300

309301
def __rmod__(self, other: Any) -> Self:
302+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
310303
floor_div = (other // self).native
311304
ser, other = extract_native(self, other)
312305
res = pc.subtract(other, pc.multiply(floor_div, ser))
313-
return self._with_native(res)
306+
return self._with_native(res, preserve_broadcast=preserve_broadcast)
314307

315308
def __invert__(self) -> Self:
316-
return self._with_native(pc.invert(self.native))
309+
return self._with_native(pc.invert(self.native), preserve_broadcast=True)
317310

318311
@property
319312
def _type(self) -> pa.DataType:

narwhals/_arrow/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,8 @@ def pad_series(
445445

446446

447447
def cast_to_comparable_string_types(
448-
*chunked_arrays: ChunkedArrayAny, separator: str
449-
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
448+
*chunked_arrays: ChunkedArrayAny | ScalarAny, separator: str
449+
) -> tuple[Iterator[ChunkedArrayAny | ScalarAny], ScalarAny]:
450450
# Ensure `chunked_arrays` are either all `string` or all `large_string`.
451451
dtype = (
452452
pa.string() # (PyArrow default)

narwhals/_pandas_like/namespace.py

Lines changed: 8 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -72,8 +72,7 @@ def __init__(self, implementation: Implementation, version: Version) -> None:
7272

7373
def coalesce(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
7474
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
75-
align = self._series._align_full_broadcast
76-
series = align(*(s for _expr in exprs for s in _expr(df)))
75+
series = (s for _expr in exprs for s in _expr(df))
7776
return [
7877
reduce(lambda x, y: x.fill_null(y, strategy=None, limit=None), series)
7978
]
@@ -127,10 +126,8 @@ def len(self) -> PandasLikeExpr:
127126
# --- horizontal ---
128127
def sum_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
129128
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
130-
align = self._series._align_full_broadcast
131129
it = chain.from_iterable(expr(df) for expr in exprs)
132-
series = align(*it)
133-
native_series = (s.fill_null(0, None, None) for s in series)
130+
native_series = (s.fill_null(0, None, None) for s in it)
134131
return [reduce(operator.add, native_series)]
135132

136133
return self._expr._from_callable(
@@ -146,7 +143,6 @@ def all_horizontal(
146143
self, *exprs: PandasLikeExpr, ignore_nulls: bool
147144
) -> PandasLikeExpr:
148145
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
149-
align = self._series._align_full_broadcast
150146
series = [s for _expr in exprs for s in _expr(df)]
151147
if not ignore_nulls and any(
152148
s.native.dtype == "object" and s.is_null().any() for s in series
@@ -164,7 +160,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
164160
if ignore_nulls
165161
else iter(series)
166162
)
167-
return [reduce(operator.and_, align(*it))]
163+
return [reduce(operator.and_, it)]
168164

169165
return self._expr._from_callable(
170166
func=func,
@@ -179,7 +175,6 @@ def any_horizontal(
179175
self, *exprs: PandasLikeExpr, ignore_nulls: bool
180176
) -> PandasLikeExpr:
181177
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
182-
align = self._series._align_full_broadcast
183178
series = [s for _expr in exprs for s in _expr(df)]
184179
if not ignore_nulls and any(
185180
s.native.dtype == "object" and s.is_null().any() for s in series
@@ -197,7 +192,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
197192
if ignore_nulls
198193
else iter(series)
199194
)
200-
return [reduce(operator.or_, align(*it))]
195+
return [reduce(operator.or_, it)]
201196

202197
return self._expr._from_callable(
203198
func=func,
@@ -211,11 +206,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
211206
def mean_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
212207
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
213208
expr_results = [s for _expr in exprs for s in _expr(df)]
214-
align = self._series._align_full_broadcast
215-
series = align(
216-
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
217-
)
218-
non_na = align(*(1 - s.is_null() for s in expr_results))
209+
series = (s.fill_null(0, strategy=None, limit=None) for s in expr_results)
210+
non_na = (1 - s.is_null() for s in expr_results)
219211
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
220212

221213
return self._expr._from_callable(
@@ -229,10 +221,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
229221

230222
def min_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
231223
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
232-
it = chain.from_iterable(expr(df) for expr in exprs)
233-
align = self._series._align_full_broadcast
234-
series = align(*it)
235-
224+
series = list(chain.from_iterable(expr(df) for expr in exprs))
236225
return [
237226
PandasLikeSeries(
238227
self.concat(
@@ -254,10 +243,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
254243

255244
def max_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
256245
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
257-
it = chain.from_iterable(expr(df) for expr in exprs)
258-
align = self._series._align_full_broadcast
259-
series = align(*it)
260-
246+
series = list(chain.from_iterable(expr(df) for expr in exprs))
261247
return [
262248
PandasLikeSeries(
263249
self.concat(

0 commit comments

Comments
 (0)