Skip to content

Commit 35fb578

Browse files
committed
feat: Implement chained when-then-otherwise 🥳
Related (#668 (comment)) - This would be how we should model it in *actual narwhals* - Almost identical to the `rust` version - See this in particular (https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130)
1 parent 72c33ce commit 35fb578

File tree

3 files changed

+133
-2
lines changed

3 files changed

+133
-2
lines changed

narwhals/_plan/demo.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
expr_parsing as parse,
1010
functions as F, # noqa: N812
1111
)
12-
from narwhals._plan.common import ExprIR, IntoExpr, is_non_nested_literal
12+
from narwhals._plan.common import ExprIR, IntoExpr, is_expr, is_non_nested_literal
1313
from narwhals._plan.dummy import DummySeries
1414
from narwhals._plan.expr import All, Column, Columns, IndexColumns, Len, Nth
1515
from narwhals._plan.literal import ScalarLiteral, SeriesLiteral
1616
from narwhals._plan.strings import ConcatHorizontal
17+
from narwhals._plan.when_then import When
1718
from narwhals.dtypes import DType
1819
from narwhals.exceptions import OrderDependentExprError
1920
from narwhals.utils import Version, flatten
@@ -131,6 +132,32 @@ def concat_str(
131132
)
132133

133134

135+
def when(*predicates: IntoExpr | t.Iterable[IntoExpr]) -> When:
136+
"""Start a `when-then-otherwise` expression.
137+
138+
Examples:
139+
>>> from narwhals._plan import demo as nwd
140+
141+
>>> when_then_many = (
142+
... nwd.when(nwd.col("x") == "a")
143+
... .then(1)
144+
... .when(nwd.col("x") == "b")
145+
... .then(2)
146+
... .when(nwd.col("x") == "c")
147+
... .then(3)
148+
... .otherwise(4)
149+
... )
150+
>>> when_then_many
151+
Narwhals DummyExpr (main):
152+
.when([(col('x')) == (lit(str: a))]).then(lit(int: 1)).otherwise(.when([(col('x')) == (lit(str: b))]).then(lit(int: 2)).otherwise(.when([(col('x')) == (lit(str: c))]).then(lit(int: 3)).otherwise(lit(int: 4))))
153+
"""
154+
if builtins.len(predicates) == 1 and is_expr(predicates[0]):
155+
expr = predicates[0]
156+
else:
157+
expr = all_horizontal(*predicates)
158+
return When._from_expr(expr)
159+
160+
134161
def _is_order_enforcing_previous(obj: t.Any) -> TypeIs[SortBy]:
135162
"""In theory, we could add other nodes to this check."""
136163
from narwhals._plan.expr import SortBy

narwhals/_plan/expr.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import typing as t
1414

1515
from narwhals._plan.aggregation import Agg, OrderableAgg
16-
from narwhals._plan.common import ExprIR, SelectorIR
16+
from narwhals._plan.common import ExprIR, SelectorIR, _field_str
1717
from narwhals._plan.name import KeepName, RenameAlias
1818
from narwhals._plan.typing import (
1919
FunctionT,
@@ -475,3 +475,23 @@ class Ternary(ExprIR):
475475
476476
Deferring this for now.
477477
"""
478+
479+
__slots__ = ("falsy", "predicate", "truthy")
480+
481+
predicate: ExprIR
482+
truthy: ExprIR
483+
falsy: ExprIR
484+
485+
def __str__(self) -> str:
486+
# NOTE: Default slot ordering made it difficult to read
487+
fields = (
488+
_field_str("predicate", self.predicate),
489+
_field_str("truthy", self.truthy),
490+
_field_str("falsy", self.falsy),
491+
)
492+
return f"{type(self).__name__}({', '.join(fields)})"
493+
494+
def __repr__(self) -> str:
495+
return (
496+
f".when({self.predicate!r}).then({self.truthy!r}).otherwise({self.falsy!r})"
497+
)

narwhals/_plan/when_then.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from narwhals._plan.common import Immutable
6+
from narwhals._plan.expr_parsing import parse_into_expr_ir
7+
8+
if TYPE_CHECKING:
9+
from narwhals._plan.common import ExprIR, IntoExpr, Seq
10+
from narwhals._plan.dummy import DummyExpr
11+
from narwhals._plan.expr import Ternary
12+
13+
14+
class When(Immutable):
15+
__slots__ = ("condition",)
16+
17+
condition: ExprIR
18+
19+
def then(self, expr: IntoExpr, /) -> Then:
20+
return Then(condition=self.condition, statement=parse_into_expr_ir(expr))
21+
22+
@staticmethod
23+
def _from_expr(expr: DummyExpr, /) -> When:
24+
return When(condition=expr._ir)
25+
26+
27+
class Then(Immutable):
28+
__slots__ = ("condition", "statement")
29+
30+
condition: ExprIR
31+
statement: ExprIR
32+
33+
def when(self, condition: IntoExpr, /) -> ChainedWhen:
34+
return ChainedWhen(
35+
conditions=(self.condition, parse_into_expr_ir(condition)),
36+
statements=(self.statement,),
37+
)
38+
39+
def otherwise(self, statement: IntoExpr, /) -> DummyExpr:
40+
return ternary_expr(
41+
self.condition, self.condition, parse_into_expr_ir(statement)
42+
).to_narwhals()
43+
44+
45+
class ChainedWhen(Immutable):
46+
__slots__ = ("conditions", "statements")
47+
48+
conditions: Seq[ExprIR]
49+
statements: Seq[ExprIR]
50+
51+
def then(self, statement: IntoExpr, /) -> ChainedThen:
52+
return ChainedThen(
53+
conditions=self.conditions,
54+
statements=(*self.statements, parse_into_expr_ir(statement)),
55+
)
56+
57+
58+
class ChainedThen(Immutable):
59+
"""https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130."""
60+
61+
__slots__ = ("conditions", "statements")
62+
63+
conditions: Seq[ExprIR]
64+
statements: Seq[ExprIR]
65+
66+
def when(self, condition: IntoExpr, /) -> ChainedWhen:
67+
return ChainedWhen(
68+
conditions=(*self.conditions, parse_into_expr_ir(condition)),
69+
statements=self.statements,
70+
)
71+
72+
def otherwise(self, statement: IntoExpr, /) -> DummyExpr:
73+
otherwise = parse_into_expr_ir(statement)
74+
it_conditions = reversed(self.conditions)
75+
it_statements = reversed(self.statements)
76+
for e in it_conditions:
77+
otherwise = ternary_expr(e, next(it_statements), otherwise)
78+
return otherwise.to_narwhals()
79+
80+
81+
def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary:
82+
from narwhals._plan.expr import Ternary
83+
84+
return Ternary(predicate=predicate, truthy=truthy, falsy=falsy)

0 commit comments

Comments
 (0)