Skip to content

Commit 63c85c8

Browse files
authored
Merge branch 'main' into oh-nodes
2 parents bf544d4 + 849326f commit 63c85c8

21 files changed

+367
-204
lines changed

.pre-commit-config.yaml

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ repos:
2222
args: [--ignore-words-list=ser, --ignore-words-list=RightT]
2323
exclude: ^docs/api-completeness.md$
2424
- repo: https://github.com/pycqa/flake8
25-
rev: '7.3.0' # todo: remove once https://github.com/astral-sh/ruff/issues/458 is addressed
25+
rev: '7.3.0'
2626
hooks:
2727
- id: flake8
28-
additional_dependencies: [darglint==1.8.1, Flake8-pyproject]
28+
# TODO(unassigned): replace with ruff once https://github.com/astral-sh/ruff/issues/458 is addressed.
29+
name: darglint
30+
alias: darglint
2931
entry: flake8 --select DAR --ignore DAR101,DAR201,DAR401,DAR402
3032
exclude: |
3133
(?x)^(
@@ -39,6 +41,15 @@ repos:
3941
# Soft-deprecated, so less crucial to document so carefully
4042
narwhals/stable/v1/.*
4143
)$
44+
# Keep in sync with `flake8-typing-imports` so the same venv is reused.
45+
additional_dependencies: [darglint==1.8.1, flake8-typing-imports==1.17.0]
46+
- id: flake8
47+
# TODO(unassigned): replace with ruff once https://github.com/astral-sh/ruff/issues/2302 is addressed.
48+
name: flake8-typing-imports
49+
alias: flake8-typing-imports
50+
entry: flake8 --select TYP --min-python-version=3.9.0
51+
# Keep in sync with `darglint` so the same venv is reused.
52+
additional_dependencies: [darglint==1.8.1, flake8-typing-imports==1.17.0]
4253
- repo: local
4354
hooks:
4455
- id: check-api-reference
@@ -52,6 +63,11 @@ repos:
5263
pass_filenames: false
5364
entry: python -m utils.sort_api_reference
5465
language: python
66+
- id: check-slotted-classes
67+
name: check-slotted-classes
68+
pass_filenames: false
69+
entry: python -m utils.check_slotted_classes
70+
language: python
5571
- id: imports-are-banned
5672
name: import are banned (use `get_pandas` instead of `import pandas`)
5773
entry: python utils/import_check.py

narwhals/_arrow/namespace.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
if TYPE_CHECKING:
2424
from collections.abc import Iterator, Sequence
2525

26-
from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
26+
from narwhals._arrow.typing import (
27+
ArrayOrScalar,
28+
ChunkedArrayAny,
29+
Incomplete,
30+
ScalarAny,
31+
)
2732
from narwhals._compliant.typing import ScalarKwargs
2833
from narwhals._utils import Version
2934
from narwhals.typing import IntoDType, NonNestedLiteral
@@ -49,6 +54,11 @@ def _series(self) -> type[ArrowSeries]:
4954
def __init__(self, *, version: Version) -> None:
5055
self._version = version
5156

57+
def extract_native(
58+
self, *series: ArrowSeries
59+
) -> Iterator[ChunkedArrayAny | ScalarAny]:
60+
return (s.native[0] if s._broadcast else s.native for s in series)
61+
5262
def len(self) -> ArrowExpr:
5363
# coverage bug? this is definitely hit
5464
return self._expr( # pragma: no cover
@@ -83,10 +93,9 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
8393
def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
8494
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
8595
series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
86-
align = self._series._align_full_broadcast
8796
if ignore_nulls:
8897
series = (s.fill_null(True, None, None) for s in series)
89-
return [reduce(operator.and_, align(*series))]
98+
return [reduce(operator.and_, series)]
9099

91100
return self._expr._from_callable(
92101
func=func,
@@ -100,10 +109,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
100109
def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
101110
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
102111
series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
103-
align = self._series._align_full_broadcast
104112
if ignore_nulls:
105113
series = (s.fill_null(False, None, None) for s in series)
106-
return [reduce(operator.or_, align(*series))]
114+
return [reduce(operator.or_, series)]
107115

108116
return self._expr._from_callable(
109117
func=func,
@@ -118,8 +126,7 @@ def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
118126
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
119127
it = chain.from_iterable(expr(df) for expr in exprs)
120128
series = (s.fill_null(0, strategy=None, limit=None) for s in it)
121-
align = self._series._align_full_broadcast
122-
return [reduce(operator.add, align(*series))]
129+
return [reduce(operator.add, series)]
123130

124131
return self._expr._from_callable(
125132
func=func,
@@ -135,11 +142,8 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
135142

136143
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
137144
expr_results = tuple(chain.from_iterable(expr(df) for expr in exprs))
138-
align = self._series._align_full_broadcast
139-
series = align(
140-
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
141-
)
142-
non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
145+
series = [s.fill_null(0, strategy=None, limit=None) for s in expr_results]
146+
non_na = [1 - s.is_null().cast(int_64) for s in expr_results]
143147
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
144148

145149
return self._expr._from_callable(
@@ -153,9 +157,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
153157

154158
def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
155159
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
156-
align = self._series._align_full_broadcast
157160
init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
158-
init_series, *series = align(init_series, *series)
159161
native_series = reduce(
160162
pc.min_element_wise, [s.native for s in series], init_series.native
161163
)
@@ -174,9 +176,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
174176

175177
def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
176178
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
177-
align = self._series._align_full_broadcast
178179
init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
179-
init_series, *series = align(init_series, *series)
180180
native_series = reduce(
181181
pc.max_element_wise, [s.native for s in series], init_series.native
182182
)
@@ -227,16 +227,13 @@ def concat_str(
227227
self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
228228
) -> ArrowExpr:
229229
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
230-
align = self._series._align_full_broadcast
231-
compliant_series_list = align(
232-
*(chain.from_iterable(expr(df) for expr in exprs))
233-
)
234-
name = compliant_series_list[0].name
230+
series = list(chain.from_iterable(expr(df) for expr in exprs))
231+
name = series[0].name
235232
null_handling: Literal["skip", "emit_null"] = (
236233
"skip" if ignore_nulls else "emit_null"
237234
)
238235
it, separator_scalar = cast_to_comparable_string_types(
239-
*(s.native for s in compliant_series_list), separator=separator
236+
*self.extract_native(*series), separator=separator
240237
)
241238
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
242239
# Reality: `str` is fine

narwhals/_arrow/series.py

Lines changed: 40 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Literal, cast, overload
3+
from typing import TYPE_CHECKING, Any, Callable, Literal, cast, overload
44

55
import pyarrow as pa
66
import pyarrow.compute as pc
@@ -148,6 +148,16 @@ def _with_native(
148148
result._broadcast = self._broadcast
149149
return result
150150

151+
def _with_binary(self, op: Callable[..., ArrayOrScalar], other: Any) -> Self:
152+
ser, other_native = extract_native(self, other)
153+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
154+
return self._with_native(
155+
op(ser, other_native), preserve_broadcast=preserve_broadcast
156+
).alias(self.name)
157+
158+
def _with_binary_right(self, op: Callable[..., ArrayOrScalar], other: Any) -> Self:
159+
return self._with_binary(lambda x, y: op(y, x), other).alias(self.name)
160+
151161
@classmethod
152162
def from_iterable(
153163
cls,
@@ -214,106 +224,89 @@ def __narwhals_namespace__(self) -> ArrowNamespace:
214224
return ArrowNamespace(version=self._version)
215225

216226
def __eq__(self, other: object) -> Self: # type: ignore[override]
217-
other = cast("PythonLiteral | ArrowSeries | None", other)
218-
ser, rhs = extract_native(self, other)
219-
return self._with_native(pc.equal(ser, rhs))
227+
return self._with_binary(pc.equal, other)
220228

221229
def __ne__(self, other: object) -> Self: # type: ignore[override]
222-
other = cast("PythonLiteral | ArrowSeries | None", other)
223-
ser, rhs = extract_native(self, other)
224-
return self._with_native(pc.not_equal(ser, rhs))
230+
return self._with_binary(pc.not_equal, other)
225231

226232
def __ge__(self, other: Any) -> Self:
227-
ser, other = extract_native(self, other)
228-
return self._with_native(pc.greater_equal(ser, other))
233+
return self._with_binary(pc.greater_equal, other)
229234

230235
def __gt__(self, other: Any) -> Self:
231-
ser, other = extract_native(self, other)
232-
return self._with_native(pc.greater(ser, other))
236+
return self._with_binary(pc.greater, other)
233237

234238
def __le__(self, other: Any) -> Self:
235-
ser, other = extract_native(self, other)
236-
return self._with_native(pc.less_equal(ser, other))
239+
return self._with_binary(pc.less_equal, other)
237240

238241
def __lt__(self, other: Any) -> Self:
239-
ser, other = extract_native(self, other)
240-
return self._with_native(pc.less(ser, other))
242+
return self._with_binary(pc.less, other)
241243

242244
def __and__(self, other: Any) -> Self:
243-
ser, other = extract_native(self, other)
244-
return self._with_native(pc.and_kleene(ser, other)) # type: ignore[arg-type]
245+
return self._with_binary(pc.and_kleene, other)
245246

246247
def __rand__(self, other: Any) -> Self:
247-
ser, other = extract_native(self, other)
248-
return self._with_native(pc.and_kleene(other, ser)) # type: ignore[arg-type]
248+
return self._with_binary_right(pc.and_kleene, other)
249249

250250
def __or__(self, other: Any) -> Self:
251-
ser, other = extract_native(self, other)
252-
return self._with_native(pc.or_kleene(ser, other)) # type: ignore[arg-type]
251+
return self._with_binary_right(pc.or_kleene, other)
253252

254253
def __ror__(self, other: Any) -> Self:
255-
ser, other = extract_native(self, other)
256-
return self._with_native(pc.or_kleene(other, ser)) # type: ignore[arg-type]
254+
return self._with_binary_right(pc.or_kleene, other)
257255

258256
def __add__(self, other: Any) -> Self:
259-
ser, other = extract_native(self, other)
260-
return self._with_native(pc.add(ser, other))
257+
return self._with_binary(pc.add, other)
261258

262259
def __radd__(self, other: Any) -> Self:
263-
return self + other
260+
return self._with_binary_right(pc.add, other)
264261

265262
def __sub__(self, other: Any) -> Self:
266-
ser, other = extract_native(self, other)
267-
return self._with_native(pc.subtract(ser, other))
263+
return self._with_binary(pc.subtract, other)
268264

269265
def __rsub__(self, other: Any) -> Self:
270-
return (self - other) * (-1)
266+
return self._with_binary_right(pc.subtract, other)
271267

272268
def __mul__(self, other: Any) -> Self:
273-
ser, other = extract_native(self, other)
274-
return self._with_native(pc.multiply(ser, other))
269+
return self._with_binary(pc.multiply, other)
275270

276271
def __rmul__(self, other: Any) -> Self:
277-
return self * other
272+
return self._with_binary_right(pc.multiply, other)
278273

279274
def __pow__(self, other: Any) -> Self:
280-
ser, other = extract_native(self, other)
281-
return self._with_native(pc.power(ser, other))
275+
return self._with_binary(pc.power, other)
282276

283277
def __rpow__(self, other: Any) -> Self:
284-
ser, other = extract_native(self, other)
285-
return self._with_native(pc.power(other, ser))
278+
return self._with_binary_right(pc.power, other)
286279

287280
def __floordiv__(self, other: Any) -> Self:
288-
ser, other = extract_native(self, other)
289-
return self._with_native(floordiv_compat(ser, other))
281+
return self._with_binary(floordiv_compat, other)
290282

291283
def __rfloordiv__(self, other: Any) -> Self:
292-
ser, other = extract_native(self, other)
293-
return self._with_native(floordiv_compat(other, ser))
284+
return self._with_binary_right(floordiv_compat, other)
294285

295286
def __truediv__(self, other: Any) -> Self:
296-
ser, other = extract_native(self, other)
297-
return self._with_native(pc.divide(*cast_for_truediv(ser, other))) # type: ignore[type-var]
287+
return self._with_binary(lambda x, y: pc.divide(*cast_for_truediv(x, y)), other)
298288

299289
def __rtruediv__(self, other: Any) -> Self:
300-
ser, other = extract_native(self, other)
301-
return self._with_native(pc.divide(*cast_for_truediv(other, ser))) # type: ignore[type-var]
290+
return self._with_binary_right(
291+
lambda x, y: pc.divide(*cast_for_truediv(x, y)), other
292+
)
302293

303294
def __mod__(self, other: Any) -> Self:
295+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
304296
floor_div = (self // other).native
305297
ser, other = extract_native(self, other)
306298
res = pc.subtract(ser, pc.multiply(floor_div, other))
307-
return self._with_native(res)
299+
return self._with_native(res, preserve_broadcast=preserve_broadcast)
308300

309301
def __rmod__(self, other: Any) -> Self:
302+
preserve_broadcast = self._broadcast and getattr(other, "_broadcast", True)
310303
floor_div = (other // self).native
311304
ser, other = extract_native(self, other)
312305
res = pc.subtract(other, pc.multiply(floor_div, ser))
313-
return self._with_native(res)
306+
return self._with_native(res, preserve_broadcast=preserve_broadcast)
314307

315308
def __invert__(self) -> Self:
316-
return self._with_native(pc.invert(self.native))
309+
return self._with_native(pc.invert(self.native), preserve_broadcast=True)
317310

318311
@property
319312
def _type(self) -> pa.DataType:

narwhals/_arrow/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,8 +445,8 @@ def pad_series(
445445

446446

447447
def cast_to_comparable_string_types(
448-
*chunked_arrays: ChunkedArrayAny, separator: str
449-
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
448+
*chunked_arrays: ChunkedArrayAny | ScalarAny, separator: str
449+
) -> tuple[Iterator[ChunkedArrayAny | ScalarAny], ScalarAny]:
450450
# Ensure `chunked_arrays` are either all `string` or all `large_string`.
451451
dtype = (
452452
pa.string() # (PyArrow default)

0 commit comments

Comments
 (0)