|
2 | 2 |
|
3 | 3 | import operator |
4 | 4 | from functools import reduce |
| 5 | +from itertools import chain |
5 | 6 | from typing import TYPE_CHECKING, Any |
6 | 7 |
|
7 | 8 | from narwhals._expression_parsing import ( |
|
19 | 20 | ) |
20 | 21 | from narwhals._sql.namespace import SQLNamespace |
21 | 22 | from narwhals._sql.when_then import SQLThen, SQLWhen |
22 | | -from narwhals._utils import zip_strict |
23 | 23 |
|
24 | 24 | if TYPE_CHECKING: |
25 | 25 | from collections.abc import Iterable |
@@ -173,37 +173,13 @@ def concat_str( |
173 | 173 | self, *exprs: SparkLikeExpr, separator: str, ignore_nulls: bool |
174 | 174 | ) -> SparkLikeExpr: |
175 | 175 | def func(df: SparkLikeLazyFrame) -> list[Column]: |
176 | | - cols = [s for _expr in exprs for s in _expr(df)] |
177 | | - cols_casted = [s.cast(df._native_dtypes.StringType()) for s in cols] |
178 | | - null_mask = [df._F.isnull(s) for s in cols] |
| 176 | + F = self._F |
| 177 | + cols = tuple(chain.from_iterable(e(df) for e in exprs)) |
| 178 | + result = F.concat_ws(separator, *cols) |
179 | 179 |
|
180 | 180 | if not ignore_nulls: |
181 | | - null_mask_result = reduce(operator.or_, null_mask) |
182 | | - result = df._F.when( |
183 | | - ~null_mask_result, |
184 | | - reduce( |
185 | | - lambda x, y: df._F.format_string(f"%s{separator}%s", x, y), |
186 | | - cols_casted, |
187 | | - ), |
188 | | - ).otherwise(df._F.lit(None)) |
189 | | - else: |
190 | | - init_value, *values = [ |
191 | | - df._F.when(~nm, col).otherwise(df._F.lit("")) |
192 | | - for col, nm in zip_strict(cols_casted, null_mask) |
193 | | - ] |
194 | | - |
195 | | - separators = ( |
196 | | - df._F.when(nm, df._F.lit("")).otherwise(df._F.lit(separator)) |
197 | | - for nm in null_mask[:-1] |
198 | | - ) |
199 | | - result = reduce( |
200 | | - lambda x, y: df._F.format_string("%s%s", x, y), |
201 | | - ( |
202 | | - df._F.format_string("%s%s", s, v) |
203 | | - for s, v in zip_strict(separators, values) |
204 | | - ), |
205 | | - init_value, |
206 | | - ) |
| 181 | + null_mask = reduce(operator.or_, (F.isnull(s) for s in cols)) |
| 182 | + result = F.when(~null_mask, result).otherwise(F.lit(None)) |
207 | 183 |
|
208 | 184 | return [result] |
209 | 185 |
|
|
0 commit comments