Skip to content

Commit 7198e95

Browse files
committed
refactor: Move is_eager_expr and improve safety
- Now we don't accept any arbitrary expression - It must derive the current type - Simplifies the typing - Really wasn't happy with what I had previously - Glad that is gone now - Added a doc, since the single output part is really the most important
1 parent cc53df6 commit 7198e95

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

narwhals/_compliant/dataframe.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@
2727
import pyarrow as pa
2828
from typing_extensions import Self
2929
from typing_extensions import TypeAlias
30-
from typing_extensions import TypeIs
3130

32-
from narwhals._compliant.expr import EagerExpr
3331
from narwhals.dtypes import DType
3432
from narwhals.typing import SizeUnit
3533
from narwhals.typing import _2DArray
@@ -178,17 +176,16 @@ class EagerDataFrame(
178176
CompliantDataFrame[EagerSeriesT, EagerExprT_contra],
179177
Protocol[EagerSeriesT, EagerExprT_contra],
180178
):
181-
def _maybe_evaluate_expr(self, expr: EagerExprT_contra | T, /) -> EagerSeriesT | T:
182-
if is_eager_expr(expr):
183-
result: Sequence[EagerSeriesT] = expr(self)
184-
if len(result) > 1:
185-
msg = (
186-
"Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) "
187-
"are not supported in this context"
188-
)
189-
raise ValueError(msg)
190-
return result[0]
191-
return expr
179+
def _evaluate_expr(self, expr: EagerExprT_contra, /) -> EagerSeriesT:
180+
"""Evaluate `expr` and ensure it has a **single** output."""
181+
result: Sequence[EagerSeriesT] = expr(self)
182+
if len(result) > 1:
183+
msg = (
184+
"Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) "
185+
"are not supported in this context"
186+
)
187+
raise ValueError(msg)
188+
return result[0]
192189

193190
def _evaluate_into_exprs(self, *exprs: EagerExprT_contra) -> Sequence[EagerSeriesT]:
194191
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))
@@ -209,11 +206,3 @@ def _evaluate_into_expr(self, expr: EagerExprT_contra, /) -> Sequence[EagerSerie
209206
msg = f"Safety assertion failed, expected {aliases}, got {result}"
210207
raise AssertionError(msg)
211208
return result
212-
213-
214-
# NOTE: `mypy` is requiring the gymnastics here and is very fragile
215-
# DON'T CHANGE THIS or `EagerDataFrame._maybe_evaluate_expr`
216-
def is_eager_expr(
217-
obj: EagerExpr[Any, EagerSeriesT] | Any,
218-
) -> TypeIs[EagerExpr[Any, EagerSeriesT]]:
219-
return hasattr(obj, "__narwhals_expr__")

narwhals/_compliant/expr.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from typing import Mapping
5252

5353
from typing_extensions import Self
54+
from typing_extensions import TypeIs
5455

5556
from narwhals._compliant.namespace import CompliantNamespace
5657
from narwhals._compliant.namespace import EagerNamespace
@@ -390,6 +391,10 @@ def _reuse_series_extra_kwargs(
390391
) -> dict[str, Any]:
391392
return {}
392393

394+
@classmethod
395+
def _is_expr(cls, obj: Self | Any) -> TypeIs[Self]:
396+
return isinstance(obj, cls)
397+
393398
def _reuse_series_inner(
394399
self,
395400
df: EagerDataFrameT,
@@ -402,8 +407,8 @@ def _reuse_series_inner(
402407
kwargs = {
403408
**call_kwargs,
404409
**{
405-
arg_name: df._maybe_evaluate_expr(arg_value)
406-
for arg_name, arg_value in expressifiable_args.items()
410+
name: df._evaluate_expr(value) if self._is_expr(value) else value
411+
for name, value in expressifiable_args.items()
407412
},
408413
}
409414
method = methodcaller(

0 commit comments

Comments
 (0)