Skip to content

Commit 8888e2c

Browse files
authored
perf: avoid full broadcast for otherwise_value in when/then/otherwise (#2098)
1 parent 490d029 commit 8888e2c

File tree

4 files changed

+8
-7
lines changed

4 files changed

+8
-7
lines changed

narwhals/_arrow/namespace.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -424,10 +424,10 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
424424
if isinstance(self._otherwise_value, ArrowExpr):
425425
otherwise_series = self._otherwise_value(df)[0]
426426
else:
427-
otherwise_series = plx._create_series_from_scalar(
428-
self._otherwise_value, reference_series=condition.alias("literal")
427+
native_result = pc.if_else(
428+
condition_native, value_series_native, self._otherwise_value
429429
)
430-
otherwise_series._broadcast = True
430+
return [value_series._from_native_series(native_result)]
431431

432432
otherwise_series_native = extract_dataframe_comparand(
433433
len(df), otherwise_series, self._backend_version

narwhals/_dask/namespace.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -363,7 +363,7 @@ def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
363363
if isinstance(self._otherwise_value, DaskExpr):
364364
otherwise_value = self._otherwise_value(df)[0]
365365
else:
366-
otherwise_value = self._otherwise_value
366+
return [then_series.where(condition, self._otherwise_value)]
367367
(otherwise_series,) = align_series_full_broadcast(df, otherwise_value)
368368
validate_comparand(condition, otherwise_series)
369369
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]

narwhals/_pandas_like/namespace.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -445,10 +445,10 @@ def __call__(self: Self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
445445
if isinstance(self._otherwise_value, PandasLikeExpr):
446446
otherwise_series = self._otherwise_value(df)[0]
447447
else:
448-
otherwise_series = plx._create_series_from_scalar(
449-
self._otherwise_value, reference_series=condition.alias("literal")
448+
native_result = value_series_native.where(
449+
condition_native, self._otherwise_value
450450
)
451-
otherwise_series._broadcast = True
451+
return [value_series._from_native_series(native_result)]
452452
otherwise_series_native = extract_dataframe_comparand(
453453
df._native_frame.index, otherwise_series
454454
)

narwhals/_spark_like/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
from narwhals.utils import Version
2020

2121

22+
# NOTE: don't lru_cache this as `ModuleType` isn't hashable
2223
def native_to_narwhals_dtype(
2324
dtype: pyspark_types.DataType,
2425
version: Version,

0 commit comments

Comments
 (0)