22
33from typing import TYPE_CHECKING
44
5- from narwhals ._plan .common import Immutable
5+ from narwhals ._plan .common import Immutable , is_expr
6+ from narwhals ._plan .dummy import DummyExpr
67from narwhals ._plan .expr_parsing import parse_into_expr_ir
78
89if TYPE_CHECKING :
910 from narwhals ._plan .common import ExprIR , IntoExpr , Seq
10- from narwhals ._plan .dummy import DummyExpr
1111 from narwhals ._plan .expr import Ternary
1212
1313
@@ -24,7 +24,7 @@ def _from_expr(expr: DummyExpr, /) -> When:
2424 return When (condition = expr ._ir )
2525
2626
27- class Then (Immutable ):
27+ class Then (Immutable , DummyExpr ):
2828 __slots__ = ("condition" , "statement" )
2929
3030 condition : ExprIR
@@ -37,9 +37,23 @@ def when(self, condition: IntoExpr, /) -> ChainedWhen:
3737 )
3838
3939 def otherwise (self , statement : IntoExpr , / ) -> DummyExpr :
40- return ternary_expr (
41- self .condition , self .condition , parse_into_expr_ir (statement )
42- ).to_narwhals ()
40+ return self ._from_ir (self ._otherwise (statement ))
41+
42+ def _otherwise (self , statement : IntoExpr = None , / ) -> ExprIR :
43+ return ternary_expr (self .condition , self .statement , parse_into_expr_ir (statement ))
44+
45+ @property
46+ def _ir (self ) -> ExprIR : # type: ignore[override]
47+ return self ._otherwise ()
48+
49+ @classmethod
50+ def _from_ir (cls , ir : ExprIR , / ) -> DummyExpr : # type: ignore[override]
51+ return DummyExpr ._from_ir (ir )
52+
53+ def __eq__ (self , value : object ) -> DummyExpr | bool : # type: ignore[override]
54+ if is_expr (value ):
55+ return super (DummyExpr , self ).__eq__ (value )
56+ return super ().__eq__ (value )
4357
4458
4559class ChainedWhen (Immutable ):
@@ -55,7 +69,7 @@ def then(self, statement: IntoExpr, /) -> ChainedThen:
5569 )
5670
5771
58- class ChainedThen (Immutable ):
72+ class ChainedThen (Immutable , DummyExpr ):
5973 """https://github.com/pola-rs/polars/blob/b9dd8cdbd6e6ec8373110536955ed5940b9460ec/crates/polars-plan/src/dsl/arity.rs#L89-L130."""
6074
6175 __slots__ = ("conditions" , "statements" )
@@ -70,12 +84,28 @@ def when(self, condition: IntoExpr, /) -> ChainedWhen:
7084 )
7185
7286 def otherwise (self , statement : IntoExpr , / ) -> DummyExpr :
87+ return self ._from_ir (self ._otherwise (statement ))
88+
89+ def _otherwise (self , statement : IntoExpr = None , / ) -> ExprIR :
7390 otherwise = parse_into_expr_ir (statement )
7491 it_conditions = reversed (self .conditions )
7592 it_statements = reversed (self .statements )
7693 for e in it_conditions :
7794 otherwise = ternary_expr (e , next (it_statements ), otherwise )
78- return otherwise .to_narwhals ()
95+ return otherwise
96+
97+ @property
98+ def _ir (self ) -> ExprIR : # type: ignore[override]
99+ return self ._otherwise ()
100+
101+ @classmethod
102+ def _from_ir (cls , ir : ExprIR , / ) -> DummyExpr : # type: ignore[override]
103+ return DummyExpr ._from_ir (ir )
104+
105+ def __eq__ (self , value : object ) -> DummyExpr | bool : # type: ignore[override]
106+ if is_expr (value ):
107+ return super (DummyExpr , self ).__eq__ (value )
108+ return super ().__eq__ (value )
79109
80110
81111def ternary_expr (predicate : ExprIR , truthy : ExprIR , falsy : ExprIR , / ) -> Ternary :
0 commit comments