
Commit e3fd995

perf: Prefer Iterator > tuple > list, use native pyarrow.repeat, simplify nw.concat_str for DuckDB backend (#3190)
* A general preference to use iterators over tuples over lists
* Use native `pyarrow.repeat`
* Minimizing loops as well
* Simplify (and optimize) `concat_str` for DuckDB backend

---------

Co-authored-by: dangotbanned <[email protected]>
1 parent ebb2a40 commit e3fd995

File tree

22 files changed: +163, -138 lines


narwhals/_arrow/dataframe.py

Lines changed: 13 additions & 30 deletions
@@ -1,14 +1,13 @@
 from __future__ import annotations

 from collections.abc import Collection, Iterator, Mapping, Sequence
-from functools import partial
 from typing import TYPE_CHECKING, Any, Literal, cast, overload

 import pyarrow as pa
 import pyarrow.compute as pc

 from narwhals._arrow.series import ArrowSeries
-from narwhals._arrow.utils import native_to_narwhals_dtype
+from narwhals._arrow.utils import concat_tables, native_to_narwhals_dtype, repeat
 from narwhals._compliant import EagerDataFrame
 from narwhals._expression_parsing import ExprKind
 from narwhals._utils import (
@@ -72,7 +71,6 @@
     "right outer",
     "full outer",
 ]
-PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]


 class ArrowDataFrame(
@@ -790,34 +788,19 @@ def unpivot(
         variable_name: str,
         value_name: str,
     ) -> Self:
-        n_rows = len(self)
-        index_ = [] if index is None else index
-        on_ = [c for c in self.columns if c not in index_] if on is None else on
-        concat = (
-            partial(pa.concat_tables, promote_options="permissive")
-            if self._backend_version >= (14, 0, 0)
-            else pa.concat_tables
-        )
-        names = [*index_, variable_name, value_name]
-        return self._with_native(
-            concat(
-                [
-                    pa.Table.from_arrays(
-                        [
-                            *(self.native.column(idx_col) for idx_col in index_),
-                            cast(
-                                "ChunkedArrayAny",
-                                pa.array([on_col] * n_rows, pa.string()),
-                            ),
-                            self.native.column(on_col),
-                        ],
-                        names=names,
-                    )
-                    for on_col in on_
-                ]
-            )
-        )
         # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
         # upcast numeric to non-numeric (e.g. string) datatypes
+        n = len(self)
+        index = [] if index is None else list(index)
+        on_ = (c for c in self.columns if c not in index) if on is None else iter(on)
+        index_cols = self.native.select(index)
+        column = self.native.column
+        tables = (
+            index_cols.append_column(variable_name, repeat(name, n)).append_column(
+                value_name, column(name)
+            )
+            for name in on_
+        )
+        return self._with_native(concat_tables(tables, "permissive"))

     pivot = not_implemented()
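The rewritten `unpivot` builds one narrow table per value column (the index columns, a variable column filled natively via `pyarrow.repeat`, and the value column) and concatenates them once at the end. A minimal standalone sketch of the same strategy, using plain pyarrow on toy data rather than narwhals internals:

import pyarrow as pa

# Toy stand-ins for self.native / index / on in the diff above.
tbl = pa.table({"id": [1, 2], "a": [3, 4], "b": [5, 6]})
index, on = ["id"], ["a", "b"]

index_cols = tbl.select(index)
tables = (
    index_cols.append_column("variable", pa.repeat(name, len(tbl))).append_column(
        "value", tbl.column(name)
    )
    for name in on
)
melted = pa.concat_tables(tables, promote_options="permissive")
print(melted.to_pydict())
# {'id': [1, 2, 1, 2], 'variable': ['a', 'a', 'b', 'b'], 'value': [3, 4, 5, 6]}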

narwhals/_arrow/namespace.py

Lines changed: 4 additions & 4 deletions
@@ -134,7 +134,7 @@ def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
         int_64 = self._version.dtypes.Int64()

         def func(df: ArrowDataFrame) -> list[ArrowSeries]:
-            expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
+            expr_results = tuple(chain.from_iterable(expr(df) for expr in exprs))
             align = self._series._align_full_broadcast
             series = align(
                 *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
@@ -154,7 +154,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
     def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
         def func(df: ArrowDataFrame) -> list[ArrowSeries]:
             align = self._series._align_full_broadcast
-            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
+            init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
             init_series, *series = align(init_series, *series)
             native_series = reduce(
                 pc.min_element_wise, [s.native for s in series], init_series.native
@@ -175,7 +175,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
     def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
         def func(df: ArrowDataFrame) -> list[ArrowSeries]:
             align = self._series._align_full_broadcast
-            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
+            init_series, *series = tuple(chain.from_iterable(expr(df) for expr in exprs))
             init_series, *series = align(init_series, *series)
             native_series = reduce(
                 pc.max_element_wise, [s.native for s in series], init_series.native
@@ -200,7 +200,7 @@ def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:

     def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
         names = list(chain.from_iterable(df.column_names for df in dfs))
-        arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
+        arrays = tuple(chain.from_iterable(df.itercolumns() for df in dfs))
         return pa.Table.from_arrays(arrays, names=names)

     def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
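The recurring `list(...)` to `tuple(...)` swaps follow the commit's stated preference (iterator over tuple over list): when a flattened collection is built once and never mutated, a tuple is the cheaper container, since CPython sizes it exactly instead of over-allocating for future appends. A trivial illustration:

from itertools import chain

# Flatten once, never mutate: the tuple is sized exactly and signals immutability.
nested = [[1, 2], [3], [4, 5]]
flat = tuple(chain.from_iterable(nested))  # (1, 2, 3, 4, 5)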

narwhals/_arrow/typing.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
     "microsecond",
     "nanosecond",
 ]
+PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]

 ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
 ArrayAny: TypeAlias = pa.Array[Any]

narwhals/_arrow/utils.py

Lines changed: 42 additions & 5 deletions
@@ -7,7 +7,7 @@
 import pyarrow.compute as pc

 from narwhals._compliant import EagerSeriesNamespace
-from narwhals._utils import Version, isinstance_or_issubclass
+from narwhals._utils import Implementation, Version, isinstance_or_issubclass

 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator, Mapping
@@ -21,7 +21,9 @@
         ArrayOrScalarT1,
         ArrayOrScalarT2,
         ChunkedArrayAny,
+        Incomplete,
         NativeIntervalUnit,
+        PromoteOptions,
         ScalarAny,
     )
     from narwhals._duration import IntervalUnit
@@ -57,6 +59,9 @@ def extract_regex(
     is_timestamp,
 )

+BACKEND_VERSION = Implementation.PYARROW._backend_version()
+"""Static backend version for `pyarrow`."""
+
 UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
     "y": "year",
     "q": "quarter",
@@ -103,6 +108,17 @@ def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
     return pa.nulls(n, series.native.type)


+def repeat(
+    value: PythonLiteral | ScalarAny, n: int, /, dtype: pa.DataType | None = None
+) -> ArrayAny:
+    """Create an Array instance whose slots are the given scalar.
+
+    *Optionally*, casting to `dtype` **before** repeating `n` times.
+    """
+    lit_: Incomplete = lit
+    return pa.repeat(lit_(value, type=dtype), n)
+
+
 def zeros(n: int, /) -> pa.Int64Array:
     return pa.repeat(0, n)
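The new helper defers to `pa.repeat`, which materializes the repetitions in native code instead of going through a Python list the way `pa.array([value] * n)` does. Roughly, with `pa.scalar` standing in for narwhals' internal `lit`:

import pyarrow as pa

# Cast once via the scalar, then repeat natively n times.
arr = pa.repeat(pa.scalar(0, type=pa.int8()), 4)
print(arr.to_pylist(), arr.type)  # [0, 0, 0, 0] int8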

@@ -423,10 +439,9 @@ def pad_series(
     offset_left = window_size // 2
     # subtract one if window_size is even
     offset_right = offset_left - (window_size % 2 == 0)
-    pad_left = pa.array([None] * offset_left, type=series._type)
-    pad_right = pa.array([None] * offset_right, type=series._type)
-    concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
-    return series._with_native(concat), offset_left + offset_right
+    chunks = series.native.chunks
+    arrays = nulls_like(offset_left, series), *chunks, nulls_like(offset_right, series)
+    return series._with_native(pa.concat_arrays(arrays)), offset_left + offset_right


 def cast_to_comparable_string_types(
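The same idea drives the `pad_series` change: `nulls_like` wraps `pa.nulls`, which allocates an all-null pad directly rather than first building a Python list of `None`s for `pa.array`. For example:

import pyarrow as pa

# An all-null float64 array of length 3, no intermediate [None, None, None] list.
pad = pa.nulls(3, pa.float64())
print(pad.null_count, pad.type)  # 3 double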
@@ -441,4 +456,26 @@ def cast_to_comparable_string_types(
     return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)


+if BACKEND_VERSION >= (14,):
+    # https://arrow.apache.org/docs/14.0/python/generated/pyarrow.concat_tables.html
+    _PROMOTE: Mapping[PromoteOptions, Mapping[str, Any]] = {
+        "default": {"promote_options": "default"},
+        "permissive": {"promote_options": "permissive"},
+        "none": {"promote_options": "none"},
+    }
+else:  # pragma: no cover
+    # https://arrow.apache.org/docs/13.0/python/generated/pyarrow.concat_tables.html
+    _PROMOTE = {
+        "default": {"promote": True},
+        "permissive": {"promote": True},
+        "none": {"promote": False},
+    }
+
+
+def concat_tables(
+    tables: Iterable[pa.Table], promote_options: PromoteOptions = "none"
+) -> pa.Table:
+    return pa.concat_tables(tables, **_PROMOTE[promote_options])
+
+
 class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...
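Resolving `_PROMOTE` once at import time removes the per-call version check that `unpivot` used to perform: pyarrow 14 replaced the boolean `promote=` keyword with `promote_options=`, and the mapping translates the new spelling for older installs. A small sketch of the behaviour the wrapper forwards to, on toy tables with the pyarrow >= 14 spelling:

import pyarrow as pa

t1 = pa.table({"x": [1, 2]})
t2 = pa.table({"x": [3], "y": ["a"]})
# "permissive" unifies the differing schemas, null-filling t1's missing "y".
merged = pa.concat_tables([t1, t2], promote_options="permissive")
print(merged.to_pydict())  # {'x': [1, 2, 3], 'y': [None, None, 'a']}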

narwhals/_compliant/dataframe.py

Lines changed: 3 additions & 1 deletion
@@ -357,7 +357,9 @@ def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
         # NOTE: Ignore intermittent [False Negative]
         # Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr"
         # Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
-        return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))  # pyright: ignore[reportArgumentType]
+        return tuple(
+            chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)  # pyright: ignore[reportArgumentType]
+        )

     def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
         """Return list of raw columns.

narwhals/_compliant/expr.py

Lines changed: 4 additions & 3 deletions
@@ -386,12 +386,13 @@ def _reuse_series_inner(
             series._from_scalar(method(series)) if returns_scalar else method(series)
             for series in self(df)
         ]
-        aliases = self._evaluate_aliases(df)
-        if [s.name for s in out] != list(aliases):  # pragma: no cover
+        aliases, names = self._evaluate_aliases(df), (s.name for s in out)
+        if any(
+            alias != name for alias, name in zip_strict(aliases, names)
+        ):  # pragma: no cover
             msg = (
                 f"Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues\n"
                 f"Expression aliases: {aliases}\n"
-                f"Series names: {[s.name for s in out]}"
             )
             raise AssertionError(msg)
         return out
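The safety check now short-circuits on the first mismatching alias instead of materializing and comparing two full lists. Sketched with the stdlib `zip(..., strict=True)` (Python 3.10+) in place of narwhals' `zip_strict` helper, assuming the helper likewise raises on a length mismatch:

aliases = ["a", "b", "c"]
names = (name for name in ["a", "b", "x"])  # stand-in for (s.name for s in out)
# any() stops at the first differing pair; strict=True still guards the lengths.
print(any(alias != name for alias, name in zip(aliases, names, strict=True)))  # True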

narwhals/_duckdb/dataframe.py

Lines changed: 18 additions & 23 deletions
@@ -23,6 +23,7 @@
     Implementation,
     ValidateBackendVersion,
     Version,
+    extend_bool,
     generate_temporary_column_name,
     not_implemented,
     parse_columns_to_drop,
@@ -393,27 +394,22 @@ def unique(
         if error := self._check_columns_exist(subset_):
             raise error
         tmp_name = generate_temporary_column_name(8, self.columns, prefix="row_index_")
-        if order_by and keep == "last":
-            descending = [True] * len(order_by)
-            nulls_last = [True] * len(order_by)
-        else:
-            descending = None
-            nulls_last = None
+        flags = extend_bool(True, len(order_by)) if order_by and keep == "last" else None
         if keep == "none":
             expr = window_expression(
                 F("count", StarExpression()),
                 subset_,
                 order_by or (),
-                descending=descending,
-                nulls_last=nulls_last,
+                descending=flags,
+                nulls_last=flags,
             )
         else:
             expr = window_expression(
                 F("row_number"),
                 subset_,
                 order_by or (),
-                descending=descending,
-                nulls_last=nulls_last,
+                descending=flags,
+                nulls_last=flags,
             )
         return self._with_native(
             self.native.select(StarExpression(), expr.alias(tmp_name)).filter(
@@ -422,8 +418,7 @@ def unique(
         ).drop([tmp_name], strict=False)

     def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
-        if isinstance(descending, bool):
-            descending = [descending] * len(by)
+        descending = extend_bool(descending, len(by))
         if nulls_last:
             it = (
                 col(name).nulls_last() if not desc else col(name).desc().nulls_last()
@@ -437,23 +432,23 @@ def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) ->
         return self._with_native(self.native.sort(*it))

     def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
-        _df = self.native
+        _rel = self.native
         by = list(by)
         if isinstance(reverse, bool):
-            descending = [not reverse] * len(by)
+            descending = extend_bool(not reverse, len(by))
         else:
-            descending = [not rev for rev in reverse]
+            descending = tuple(not rev for rev in reverse)
         expr = window_expression(
             F("row_number"),
             order_by=by,
             descending=descending,
-            nulls_last=[True] * len(by),
+            nulls_last=extend_bool(True, len(by)),
         )
         condition = expr <= lit(k)
         query = f"""
-        SELECT *
-        FROM _df
-        QUALIFY {condition}
+            SELECT *
+            FROM _rel
+            QUALIFY {condition}
         """  # noqa: S608
         return self._with_native(duckdb.sql(query))
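The `_df` and `rel` to `_rel` renames settle on one name for the relation that each raw SQL string references. The reference resolves because `duckdb.sql` falls back to Python variables in the enclosing scope when a table name is unknown (a replacement scan), and the leading underscore spares the old `# noqa: F841` for a seemingly unused variable. For instance:

import duckdb

_rel = duckdb.sql("SELECT * FROM range(5) t(i)")
# DuckDB resolves `_rel` in the query text against the local variable above.
print(duckdb.sql("SELECT max(i) FROM _rel").fetchone())  # (4,)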

@@ -523,11 +518,11 @@ def unpivot(
             raise NotImplementedError(msg)

         unpivot_on = join_column_names(*on_)
-        rel = self.native  # noqa: F841
+        _rel = self.native
         # Replace with Python API once
         # https://github.com/duckdb/duckdb/discussions/16980 is addressed.
         query = f"""
-            unpivot rel
+            unpivot _rel
             on {unpivot_on}
             into
                 name {col(variable_name)}
@@ -548,9 +543,9 @@ def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
         return self._with_native(self.native.select(expr, StarExpression()))

     def sink_parquet(self, file: str | Path | BytesIO) -> None:
-        df = self.native  # noqa: F841
+        _rel = self.native
         query = f"""
-            COPY (SELECT * FROM df)
+            COPY (SELECT * FROM _rel)
             TO '{file}'
             (FORMAT parquet)
         """  # noqa: S608

narwhals/_duckdb/expr.py

Lines changed: 11 additions & 13 deletions
@@ -22,7 +22,7 @@
 )
 from narwhals._expression_parsing import ExprKind, ExprMetadata
 from narwhals._sql.expr import SQLExpr
-from narwhals._utils import Implementation, Version
+from narwhals._utils import Implementation, Version, extend_bool

 if TYPE_CHECKING:
     from collections.abc import Sequence
@@ -95,23 +95,21 @@ def _window_expression(
             nulls_last=nulls_last,
         )

-    def _first(self, expr: Expression, *order_by: str) -> Expression:
+    def _first_last(
+        self, function: str, expr: Expression, order_by: Sequence[str], /
+    ) -> Expression:
         # https://github.com/duckdb/duckdb/discussions/19252
+        flags = extend_bool(False, len(order_by))
         order_by_sql = generate_order_by_sql(
-            *order_by,
-            descending=[False] * len(order_by),
-            nulls_last=[False] * len(order_by),
+            *order_by, descending=flags, nulls_last=flags
         )
-        return sql_expression(f"first({expr} {order_by_sql})")
+        return sql_expression(f"{function}({expr} {order_by_sql})")
+
+    def _first(self, expr: Expression, *order_by: str) -> Expression:
+        return self._first_last("first", expr, order_by)

     def _last(self, expr: Expression, *order_by: str) -> Expression:
-        # https://github.com/duckdb/duckdb/discussions/19252
-        order_by_sql = generate_order_by_sql(
-            *order_by,
-            descending=[False] * len(order_by),
-            nulls_last=[False] * len(order_by),
-        )
-        return sql_expression(f"last({expr} {order_by_sql})")
+        return self._first_last("last", expr, order_by)

     def __narwhals_namespace__(self) -> DuckDBNamespace:  # pragma: no cover
         from narwhals._duckdb.namespace import DuckDBNamespace
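`_first` and `_last` differed only in the aggregate's name, so a single `_first_last` renders either; DuckDB accepts an `ORDER BY` clause inside an aggregate call, which is what the generated string leans on. A standalone sketch, where the exact clause is an assumption about what `generate_order_by_sql` emits for `descending=False` and `nulls_last=False` (ascending, nulls first):

def first_last_sql(function: str, expr: str, order_by: list[str]) -> str:
    # Render e.g. first("a" ORDER BY "idx" ASC NULLS FIRST) for DuckDB.
    order_by_sql = "ORDER BY " + ", ".join(f"{c} ASC NULLS FIRST" for c in order_by)
    return f"{function}({expr} {order_by_sql})"


print(first_last_sql("first", '"a"', ['"idx"']))
# first("a" ORDER BY "idx" ASC NULLS FIRST)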

narwhals/_duckdb/group_by.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def __init__(
         self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame

     def agg(self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
-        agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs)))
+        agg_columns = tuple(chain(self._keys, self._evaluate_exprs(exprs)))
         return self.compliant._with_native(
             self.compliant.native.aggregate(agg_columns)  # type: ignore[arg-type]
         ).rename(dict(zip(self._keys, self._output_key_names)))
