Skip to content

Commit 563076d

Browse files
committed
feat: Support optional .otherwise(...)
Somewhat of typing nightmare but gets the job done for now
1 parent 35fb578 commit 563076d

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

narwhals/_plan/demo.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ def when(*predicates: IntoExpr | t.Iterable[IntoExpr]) -> When:
150150
>>> when_then_many
151151
Narwhals DummyExpr (main):
152152
.when([(col('x')) == (lit(str: a))]).then(lit(int: 1)).otherwise(.when([(col('x')) == (lit(str: b))]).then(lit(int: 2)).otherwise(.when([(col('x')) == (lit(str: c))]).then(lit(int: 3)).otherwise(lit(int: 4))))
153+
>>>
154+
>>> nwd.when(nwd.col("y") == "b").then(1)
155+
Narwhals DummyExpr (main):
156+
.when([(col('y')) == (lit(str: b))]).then(lit(int: 1)).otherwise(lit(null))
153157
"""
154158
if builtins.len(predicates) == 1 and is_expr(predicates[0]):
155159
expr = predicates[0]

narwhals/_plan/when_then.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from typing import TYPE_CHECKING
44

5-
from narwhals._plan.common import Immutable
5+
from narwhals._plan.common import Immutable, is_expr
6+
from narwhals._plan.dummy import DummyExpr
67
from narwhals._plan.expr_parsing import parse_into_expr_ir
78

89
if TYPE_CHECKING:
910
from narwhals._plan.common import ExprIR, IntoExpr, Seq
10-
from narwhals._plan.dummy import DummyExpr
1111
from narwhals._plan.expr import Ternary
1212

1313

@@ -24,7 +24,7 @@ def _from_expr(expr: DummyExpr, /) -> When:
2424
return When(condition=expr._ir)
2525

2626

27-
class Then(Immutable):
27+
class Then(Immutable, DummyExpr):
2828
__slots__ = ("condition", "statement")
2929

3030
condition: ExprIR
@@ -37,9 +37,23 @@ def when(self, condition: IntoExpr, /) -> ChainedWhen:
3737
)
3838

3939
def otherwise(self, statement: IntoExpr, /) -> DummyExpr:
40-
return ternary_expr(
41-
self.condition, self.condition, parse_into_expr_ir(statement)
42-
).to_narwhals()
40+
return self._from_ir(self._otherwise(statement))
41+
42+
def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR:
43+
return ternary_expr(self.condition, self.statement, parse_into_expr_ir(statement))
44+
45+
@property
46+
def _ir(self) -> ExprIR: # type: ignore[override]
47+
return self._otherwise()
48+
49+
@classmethod
50+
def _from_ir(cls, ir: ExprIR, /) -> DummyExpr: # type: ignore[override]
51+
return DummyExpr._from_ir(ir)
52+
53+
def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override]
54+
if is_expr(value):
55+
return super(DummyExpr, self).__eq__(value)
56+
return super().__eq__(value)
4357

4458

4559
class ChainedWhen(Immutable):
@@ -55,7 +69,7 @@ def then(self, statement: IntoExpr, /) -> ChainedThen:
5569
)
5670

5771

58-
class ChainedThen(Immutable):
72+
class ChainedThen(Immutable, DummyExpr):
5973
"""https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130."""
6074

6175
__slots__ = ("conditions", "statements")
@@ -70,12 +84,28 @@ def when(self, condition: IntoExpr, /) -> ChainedWhen:
7084
)
7185

7286
def otherwise(self, statement: IntoExpr, /) -> DummyExpr:
87+
return self._from_ir(self._otherwise(statement))
88+
89+
def _otherwise(self, statement: IntoExpr = None, /) -> ExprIR:
7390
otherwise = parse_into_expr_ir(statement)
7491
it_conditions = reversed(self.conditions)
7592
it_statements = reversed(self.statements)
7693
for e in it_conditions:
7794
otherwise = ternary_expr(e, next(it_statements), otherwise)
78-
return otherwise.to_narwhals()
95+
return otherwise
96+
97+
@property
98+
def _ir(self) -> ExprIR: # type: ignore[override]
99+
return self._otherwise()
100+
101+
@classmethod
102+
def _from_ir(cls, ir: ExprIR, /) -> DummyExpr: # type: ignore[override]
103+
return DummyExpr._from_ir(ir)
104+
105+
def __eq__(self, value: object) -> DummyExpr | bool: # type: ignore[override]
106+
if is_expr(value):
107+
return super(DummyExpr, self).__eq__(value)
108+
return super().__eq__(value)
79109

80110

81111
def ternary_expr(predicate: ExprIR, truthy: ExprIR, falsy: ExprIR, /) -> Ternary:

0 commit comments

Comments
 (0)