Skip to content

Commit a194a1d

Browse files
committed
simpler, top-level solution
1 parent 8b52650 commit a194a1d

File tree

11 files changed

+164
-454
lines changed

11 files changed

+164
-454
lines changed

narwhals/_arrow/namespace.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -247,23 +247,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
247247

248248
def _if_then_else(
249249
self,
250-
when: ChunkedArrayAny | list[ChunkedArrayAny],
251-
then: ChunkedArrayAny | list[ChunkedArrayAny],
250+
when: ChunkedArrayAny,
251+
then: ChunkedArrayAny,
252252
otherwise: ChunkedArrayAny | None = None,
253253
) -> ChunkedArrayAny:
254-
if not isinstance(when, list):
255-
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
256-
return pc.if_else(when, then, otherwise)
257-
258-
conditions = when
259-
values = then
260-
261-
if otherwise is None:
262-
result = pa.nulls(len(conditions[-1]), values[-1].type)
263-
else:
264-
result = otherwise
265-
266-
for cond, val in reversed(list(zip(conditions, values))):
267-
result = pc.if_else(cond, val, result)
268-
269-
return result
254+
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
255+
return pc.if_else(when, then, otherwise)

narwhals/_compliant/namespace.py

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -201,69 +201,40 @@ def _dataframe(self) -> type[EagerDataFrameT]: ...
201201
def _series(self) -> type[EagerSeriesT_co]: ...
202202
def _if_then_else(
203203
self,
204-
when: NativeSeriesT | list[NativeSeriesT],
205-
then: NativeSeriesT | list[NativeSeriesT],
204+
when: NativeSeriesT,
205+
then: NativeSeriesT,
206206
otherwise: NativeSeriesT | None = None,
207207
) -> NativeSeriesT: ...
208-
def _when_then_simple(
209-
self, df: EagerDataFrameT, predicate: EagerExprT, then: EagerExprT
210-
) -> Sequence[EagerSeriesT_co]:
211-
predicate_s = df._evaluate_single_output_expr(predicate)
212-
then_s = df._evaluate_single_output_expr(then)
213-
predicate_s, then_s = predicate_s._align_full_broadcast(predicate_s, then_s)
214-
result = self._if_then_else(predicate_s.native, then_s.native)
215-
return [then_s._with_native(result)]
216-
217-
def _when_then_chained(
218-
self, df: EagerDataFrameT, args: tuple[EagerExprT, ...]
219-
) -> Sequence[EagerSeriesT_co]:
220-
evaluated = [df._evaluate_single_output_expr(arg) for arg in args]
221-
*pairs_list, otherwise_s = (
222-
evaluated if len(evaluated) % 2 == 1 else (*evaluated, None)
223-
)
224-
conditions = pairs_list[::2]
225-
values = pairs_list[1::2]
226-
227-
all_series = (
228-
[*conditions, *values, otherwise_s]
229-
if otherwise_s is not None
230-
else [*conditions, *values]
231-
)
232-
aligned = conditions[0]._align_full_broadcast(*all_series)
233-
234-
num_conditions = len(conditions)
235-
aligned_conditions = aligned[:num_conditions]
236-
aligned_values = aligned[num_conditions : num_conditions * 2]
237-
aligned_otherwise = aligned[-1] if otherwise_s is not None else None
238-
239-
if len(conditions) == 1:
240-
result = self._if_then_else(
241-
aligned_conditions[0].native,
242-
aligned_values[0].native,
243-
aligned_otherwise.native if aligned_otherwise is not None else None,
244-
)
245-
else:
246-
result = self._if_then_else(
247-
[c.native for c in aligned_conditions],
248-
[v.native for v in aligned_values],
249-
aligned_otherwise.native if aligned_otherwise is not None else None,
250-
)
251-
252-
return [values[0]._with_native(result)]
253-
254-
def when_then(self, *args: EagerExprT) -> EagerExprT:
208+
def when_then(
209+
self, predicate: EagerExprT, then: EagerExprT, otherwise: EagerExprT | None = None
210+
) -> EagerExprT:
255211
def func(df: EagerDataFrameT) -> Sequence[EagerSeriesT_co]:
256-
if len(args) == 2:
257-
return self._when_then_simple(df, args[0], args[1])
258-
return self._when_then_chained(df, args)
212+
predicate_s = df._evaluate_single_output_expr(predicate)
213+
align = predicate_s._align_full_broadcast
214+
215+
then_s = df._evaluate_single_output_expr(then)
216+
if otherwise is None:
217+
predicate_s, then_s = align(predicate_s, then_s)
218+
result = self._if_then_else(predicate_s.native, then_s.native)
219+
220+
if otherwise is None:
221+
predicate_s, then_s = align(predicate_s, then_s)
222+
result = self._if_then_else(predicate_s.native, then_s.native)
223+
else:
224+
otherwise_s = df._evaluate_single_output_expr(otherwise)
225+
predicate_s, then_s, otherwise_s = align(predicate_s, then_s, otherwise_s)
226+
result = self._if_then_else(
227+
predicate_s.native, then_s.native, otherwise_s.native
228+
)
229+
return [then_s._with_native(result)]
259230

260231
return self._expr._from_callable(
261232
func=func,
262233
evaluate_output_names=getattr(
263-
args[1], "_evaluate_output_names", lambda _df: ["literal"]
234+
then, "_evaluate_output_names", lambda _df: ["literal"]
264235
),
265-
alias_output_names=getattr(args[1], "_alias_output_names", None),
266-
context=args[0],
236+
alias_output_names=getattr(then, "_alias_output_names", None),
237+
context=predicate,
267238
)
268239

269240
def is_native(self, obj: Any, /) -> TypeIs[NativeFrameT | NativeSeriesT]:

narwhals/_dask/namespace.py

Lines changed: 37 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -273,114 +273,50 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
273273
version=self._version,
274274
)
275275

276-
def _when_then_simple(
277-
self,
278-
df: DaskLazyFrame,
279-
predicate: DaskExpr,
280-
then: DaskExpr,
281-
otherwise: DaskExpr | None,
282-
) -> list[dx.Series]:
283-
then_value = df._evaluate_single_output_expr(then)
284-
otherwise_value = (
285-
df._evaluate_single_output_expr(otherwise)
286-
if otherwise is not None
287-
else otherwise
288-
)
289-
290-
condition = df._evaluate_single_output_expr(predicate)
291-
if all(
292-
x._metadata.is_scalar_like
293-
for x in (
294-
(predicate, then) if otherwise is None else (predicate, then, otherwise)
295-
)
296-
):
297-
new_df = df._with_native(condition.to_frame())
298-
condition = df._evaluate_single_output_expr(predicate.broadcast())
299-
df = new_df
300-
301-
if otherwise is None:
302-
(condition, then_series) = align_series_full_broadcast(
303-
df, condition, then_value
304-
)
305-
validate_comparand(condition, then_series)
306-
return [then_series.where(condition)] # pyright: ignore[reportArgumentType]
307-
(condition, then_series, otherwise_series) = align_series_full_broadcast(
308-
df, condition, then_value, otherwise_value
309-
)
310-
validate_comparand(condition, then_series)
311-
validate_comparand(condition, otherwise_series)
312-
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
313-
314-
def _when_then_chained(
315-
self, df: DaskLazyFrame, args: tuple[DaskExpr, ...]
316-
) -> list[dx.Series]:
317-
import numpy as np # ignore-banned-import
318-
319-
evaluated = [df._evaluate_single_output_expr(arg) for arg in args]
320-
*pairs_list, otherwise_value = (
321-
evaluated if len(evaluated) % 2 == 1 else (*evaluated, None)
322-
)
323-
conditions = pairs_list[::2]
324-
values = pairs_list[1::2]
325-
326-
all_series = (
327-
[*conditions, *values, otherwise_value]
328-
if otherwise_value is not None
329-
else [*conditions, *values]
330-
)
331-
aligned = align_series_full_broadcast(df, *all_series)
332-
333-
num_conditions = len(conditions)
334-
aligned_conditions = aligned[:num_conditions]
335-
aligned_values = aligned[num_conditions : num_conditions * 2]
336-
aligned_otherwise = aligned[-1] if otherwise_value is not None else None
337-
338-
for cond, val in zip_strict(aligned_conditions, aligned_values):
339-
validate_comparand(cond, val)
340-
if aligned_otherwise is not None:
341-
validate_comparand(aligned_conditions[0], aligned_otherwise)
342-
343-
def apply_select(*partition_series: pd.Series) -> pd.Series:
344-
num_conds = len(aligned_conditions)
345-
cond_parts = partition_series[:num_conds]
346-
val_parts = partition_series[num_conds : num_conds * 2]
347-
otherwise_part = (
348-
partition_series[-1] if aligned_otherwise is not None else None
349-
)
350-
result = np.select(
351-
list(cond_parts),
352-
list(val_parts),
353-
default=otherwise_part if otherwise_part is not None else np.nan,
354-
)
355-
return pd.Series(result, index=cond_parts[0].index)
356-
357-
map_args = [*aligned_conditions, *aligned_values]
358-
if aligned_otherwise is not None:
359-
map_args.append(aligned_otherwise)
360-
361-
return [
362-
dd.map_partitions(
363-
apply_select,
364-
*map_args,
365-
meta=(aligned_values[0].name, aligned_values[0].dtype),
276+
def when_then(
277+
self, predicate: DaskExpr, then: DaskExpr, otherwise: DaskExpr | None = None
278+
) -> DaskExpr:
279+
def func(df: DaskLazyFrame) -> list[dx.Series]:
280+
then_value = df._evaluate_single_output_expr(then)
281+
otherwise_value = (
282+
df._evaluate_single_output_expr(otherwise)
283+
if otherwise is not None
284+
else otherwise
366285
)
367-
]
368286

369-
def when_then(self, *args: DaskExpr) -> DaskExpr:
370-
def func(df: DaskLazyFrame) -> list[dx.Series]:
371-
if len(args) <= 3:
372-
return self._when_then_simple(
373-
df, args[0], args[1], args[2] if len(args) == 3 else None
287+
condition = df._evaluate_single_output_expr(predicate)
288+
# re-evaluate DataFrame if the condition aggregates to force
289+
# then/otherwise to be evaluated against the aggregated frame
290+
if all(
291+
x._metadata.is_scalar_like
292+
for x in (
293+
(predicate, then)
294+
if otherwise is None
295+
else (predicate, then, otherwise)
374296
)
375-
return self._when_then_chained(df, args)
376-
377-
then_arg = args[1]
297+
):
298+
new_df = df._with_native(condition.to_frame())
299+
condition = df._evaluate_single_output_expr(predicate.broadcast())
300+
df = new_df
301+
302+
if otherwise is None:
303+
(condition, then_series) = align_series_full_broadcast(
304+
df, condition, then_value
305+
)
306+
validate_comparand(condition, then_series)
307+
return [then_series.where(condition)] # pyright: ignore[reportArgumentType]
308+
(condition, then_series, otherwise_series) = align_series_full_broadcast(
309+
df, condition, then_value, otherwise_value
310+
)
311+
validate_comparand(condition, then_series)
312+
validate_comparand(condition, otherwise_series)
313+
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
378314

379315
return self._expr(
380316
call=func,
381317
evaluate_output_names=getattr(
382-
then_arg, "_evaluate_output_names", lambda _df: ["literal"]
318+
then, "_evaluate_output_names", lambda _df: ["literal"]
383319
),
384-
alias_output_names=getattr(then_arg, "_alias_output_names", None),
320+
alias_output_names=getattr(then, "_alias_output_names", None),
385321
version=self._version,
386322
)

narwhals/_duckdb/namespace.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -67,34 +67,14 @@ def _lit(self, value: Any) -> Expression:
6767
return lit(value)
6868

6969
def _when(
70-
self, condition: Expression, value: Expression, *args: Expression
70+
self,
71+
condition: Expression,
72+
value: Expression,
73+
otherwise: Expression | None = None,
7174
) -> Expression:
72-
if not args:
75+
if otherwise is None:
7376
return when(condition, value)
74-
75-
if len(args) == 1:
76-
otherwise = args[0]
77-
return when(condition, value).otherwise(otherwise)
78-
79-
all_exprs = [condition, value, *args]
80-
81-
has_otherwise = len(all_exprs) % 2 == 1
82-
83-
if has_otherwise:
84-
*pairs, otherwise_expr = all_exprs
85-
else:
86-
pairs = all_exprs
87-
otherwise_expr = None
88-
89-
result = when(pairs[0], pairs[1])
90-
91-
for cond, val in zip(pairs[2::2], pairs[3::2]):
92-
result = result.when(cond, val)
93-
94-
if has_otherwise:
95-
result = result.otherwise(otherwise_expr)
96-
97-
return result
77+
return when(condition, value).otherwise(otherwise)
9878

9979
def _coalesce(self, *exprs: Expression) -> Expression:
10080
return CoalesceOperator(*exprs)

narwhals/_ibis/namespace.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,29 +54,12 @@ def _function(self, name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
5454
def _lit(self, value: Any) -> ir.Value:
5555
return lit(value)
5656

57-
def _when(self, condition: ir.Value, value: ir.Value, *args: ir.Value) -> ir.Value:
58-
if not args:
57+
def _when(
58+
self, condition: ir.Value, value: ir.Value, otherwise: ir.Expr | None = None
59+
) -> ir.Value:
60+
if otherwise is None:
5961
return ibis.cases((condition, value))
60-
61-
if len(args) == 1:
62-
otherwise = args[0]
63-
return ibis.cases((condition, value), else_=otherwise) # pragma: no cover
64-
65-
all_exprs = [condition, value, *args]
66-
67-
has_otherwise = len(all_exprs) % 2 == 1
68-
69-
if has_otherwise:
70-
*pairs, otherwise_expr = all_exprs
71-
else:
72-
pairs = all_exprs
73-
otherwise_expr = None
74-
75-
tuples = list(zip(pairs[::2], pairs[1::2]))
76-
77-
if has_otherwise:
78-
return ibis.cases(*tuples, else_=otherwise_expr)
79-
return ibis.cases(*tuples)
62+
return ibis.cases((condition, value), else_=otherwise) # pragma: no cover
8063

8164
def _coalesce(self, *exprs: ir.Value) -> ir.Value:
8265
return ibis.coalesce(*exprs)

narwhals/_pandas_like/namespace.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -366,24 +366,12 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
366366

367367
def _if_then_else(
368368
self,
369-
when: NativeSeriesT | list[NativeSeriesT],
370-
then: NativeSeriesT | list[NativeSeriesT],
369+
when: NativeSeriesT,
370+
then: NativeSeriesT,
371371
otherwise: NativeSeriesT | None = None,
372372
) -> NativeSeriesT:
373-
if not isinstance(when, list):
374-
where: Incomplete = then.where
375-
return where(when) if otherwise is None else where(when, otherwise)
376-
377-
import numpy as np
378-
379-
condlist = when
380-
choicelist = then
381-
382-
default = otherwise if otherwise is not None else np.nan
383-
384-
result = np.select(condlist, choicelist, default=default)
385-
386-
return then[0].__class__(result, index=then[0].index)
373+
where: Incomplete = then.where
374+
return where(when) if otherwise is None else where(when, otherwise)
387375

388376

389377
class _NativeConcat(Protocol[NativeDataFrameT, NativeSeriesT]):

0 commit comments

Comments
 (0)