
Commit 9240530

chore: Simplify spark-like and ibis concat_str (#3240)
* spark-like
* ibis
* is it pyarrow?

Co-authored-by: Marco Edward Gorelli <[email protected]>
1 parent 2cba4a0 commit 9240530

2 files changed: +13 additions, −35 deletions


narwhals/_ibis/namespace.py

Lines changed: 7 additions & 5 deletions
@@ -83,12 +83,14 @@ def func(df: IbisLazyFrame) -> list[ir.Value]:
             cols = chain.from_iterable(expr(df) for expr in exprs)
             cols_casted = [s.cast("string") for s in cols]
 
-            if not ignore_nulls:
-                result = cols_casted[0]
-                for col in cols_casted[1:]:
-                    result = result + separator + col
-            else:
+            if ignore_nulls:
                 result = lit(separator).join(cols_casted)
+            else:
+                result = reduce(
+                    lambda acc, col: acc.concat(separator, col),
+                    cols_casted[1:],
+                    cols_casted[0],
+                )
 
             return [result]
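For context, here is a minimal standalone sketch (not part of the commit) of the two branches above, written against the public ibis API. The table, column names, and separator are invented for illustration, and the exact null semantics of each path ultimately depend on the backend.

# Hypothetical example: reproduce both concat_str code paths on a toy table.
from functools import reduce

import ibis

t = ibis.memtable({"a": ["x", "x"], "b": ["y", None]})
cols_casted = [t.a.cast("string"), t.b.cast("string")]
separator = "-"

# ignore_nulls=True path: join every column with the separator literal
joined = ibis.literal(separator).join(cols_casted)

# ignore_nulls=False path: fold with StringValue.concat, relying on the usual
# SQL behaviour that concatenating a null yields null
folded = reduce(
    lambda acc, col: acc.concat(separator, col),
    cols_casted[1:],
    cols_casted[0],
)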

narwhals/_spark_like/namespace.py

Lines changed: 6 additions & 30 deletions
@@ -2,6 +2,7 @@
 
 import operator
 from functools import reduce
+from itertools import chain
 from typing import TYPE_CHECKING, Any
 
 from narwhals._expression_parsing import (
@@ -19,7 +20,6 @@
 )
 from narwhals._sql.namespace import SQLNamespace
 from narwhals._sql.when_then import SQLThen, SQLWhen
-from narwhals._utils import zip_strict
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -173,37 +173,13 @@ def concat_str(
         self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool
     ) -> SparkLikeExpr:
         def func(df: SparkLikeLazyFrame) -> list[Column]:
-            cols = [s for _expr in exprs for s in _expr(df)]
-            cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols]
-            null_mask = [df._F.isnull(s) for s in cols]
+            F = self._F
+            cols = tuple(chain.from_iterable(e(df) for e in exprs))
+            result = F.concat_ws(separator, *cols)
 
             if not ignore_nulls:
-                null_mask_result = reduce(operator.or_, null_mask)
-                result = df._F.when(
-                    ~null_mask_result,
-                    reduce(
-                        lambda x, y: df._F.format_string(f"%s{separator}%s", x, y),
-                        cols_casted,
-                    ),
-                ).otherwise(df._F.lit(None))
-            else:
-                init_value, *values = [
-                    df._F.when(~nm, col).otherwise(df._F.lit(""))
-                    for col, nm in zip_strict(cols_casted, null_mask)
-                ]
-
-                separators = (
-                    df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator))
-                    for nm in null_mask[:-1]
-                )
-                result = reduce(
-                    lambda x, y: df._F.format_string("%s%s", x, y),
-                    (
-                        df._F.format_string("%s%s", s, v)
-                        for s, v in zip_strict(separators, values)
-                    ),
-                    init_value,
-                )
+                null_mask = reduce(operator.or_, (F.isnull(s) for s in cols))
+                result = F.when(~null_mask, result).otherwise(F.lit(None))
 
             return [result]
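For reference, the same null handling can be sketched in plain PySpark (a hypothetical standalone example, not code from the commit; the DataFrame, column names, and separator are made up): concat_ws skips null inputs, which already gives the ignore_nulls=True behaviour, and the ignore_nulls=False behaviour is recovered by nulling out rows where any input is null.

# Hypothetical PySpark sketch of the simplified concat_str logic.
import operator
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("x", "y"), ("x", None)], ["a", "b"])
cols = [F.col("a"), F.col("b")]
separator = "-"

# concat_ws ignores nulls, matching ignore_nulls=True
result = F.concat_ws(separator, *cols)

# ignore_nulls=False: OR the per-column null masks together and null out the
# concatenated result for any row that had a null input
null_mask = reduce(operator.or_, (F.isnull(c) for c in cols))
strict = F.when(~null_mask, result).otherwise(F.lit(None))

df.select(result.alias("ignore_nulls"), strict.alias("strict")).show()
# ("x", "y")  -> "x-y" | "x-y"
# ("x", None) -> "x"   | null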

0 commit comments
