Skip to content

Commit 75b33c4

Browse files
committed
feat: Spec-out all generic parts of EagerExpr
Wild how much of this is identical and easy to share #2149 (comment)
1 parent 7d48726 commit 75b33c4

File tree

4 files changed

+294
-10
lines changed

4 files changed

+294
-10
lines changed

narwhals/_compliant/expr.py

Lines changed: 282 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@
77
from typing import Any
88
from typing import Callable
99
from typing import Literal
10+
from typing import Mapping
1011
from typing import Protocol
1112
from typing import Sequence
1213

14+
from narwhals._compliant.namespace import CompliantNamespace
1315
from narwhals._compliant.typing import CompliantFrameT
1416
from narwhals._compliant.typing import CompliantLazyFrameT
1517
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
1618
from narwhals._compliant.typing import EagerDataFrameT
1719
from narwhals._compliant.typing import EagerSeriesT
1820
from narwhals._compliant.typing import NativeExprT_co
1921
from narwhals._expression_parsing import evaluate_output_names_and_aliases
22+
from narwhals.dtypes import DType
2023
from narwhals.utils import deprecated
2124
from narwhals.utils import not_implemented
2225
from narwhals.utils import unstable
@@ -37,6 +40,7 @@
3740
from typing_extensions import Self
3841

3942
from narwhals._compliant.namespace import CompliantNamespace
43+
from narwhals._compliant.namespace import EagerNamespace
4044
from narwhals._compliant.series import CompliantSeries
4145
from narwhals._expression_parsing import ExprKind
4246
from narwhals.dtypes import DType
@@ -236,6 +240,7 @@ class EagerExpr(
236240
CompliantExpr[EagerDataFrameT, EagerSeriesT],
237241
Protocol38[EagerDataFrameT, EagerSeriesT],
238242
):
243+
_call: Callable[[EagerDataFrameT], Sequence[EagerSeriesT]]
239244
_depth: int
240245
_function_name: str
241246
_evaluate_output_names: Any
@@ -259,6 +264,15 @@ def __init__(
259264
call_kwargs: dict[str, Any] | None = None,
260265
) -> None: ...
261266

267+
def __call__(self, df: EagerDataFrameT) -> Sequence[EagerSeriesT]:
    """Evaluate this expression against `df` by invoking the stored `_call`."""
    evaluate = self._call
    return evaluate(df)
269+
270+
def __repr__(self) -> str:  # pragma: no cover
    """Render the concrete class name together with `_depth` and `_function_name`."""
    cls_name = type(self).__name__
    depth = self._depth
    function_name = self._function_name
    return f"{cls_name}(depth={depth}, function_name={function_name})"
272+
273+
def __narwhals_namespace__(self) -> EagerNamespace[EagerDataFrameT, EagerSeriesT]:
    """Protocol stub: implementers return the matching `EagerNamespace`."""
274+
def __narwhals_expr__(self) -> None:
    """Protocol stub with no behavior of its own."""
275+
262276
@classmethod
263277
def _from_callable(
264278
cls,
@@ -383,6 +397,34 @@ def _reuse_series_namespace_implementation(
383397
context=self,
384398
)
385399

400+
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
    """Return a copy of this expression whose output series are flagged for broadcasting.

    Every series produced by evaluating `self` gets `_broadcast = True`;
    per the original note, `extract_native` later relies on that flag when
    pulling out native objects. All other metadata is carried over as-is.
    `kind` is accepted but not read by this implementation.
    """

    def _flag(series: EagerSeriesT) -> EagerSeriesT:
        # Mutates the series in place and hands it back unchanged otherwise.
        series._broadcast = True
        return series

    def _evaluate(df: EagerDataFrameT) -> list[EagerSeriesT]:
        return [_flag(series) for series in self(df)]

    constructor = type(self)
    return constructor(
        _evaluate,
        depth=self._depth,
        function_name=self._function_name,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        backend_version=self._backend_version,
        implementation=self._implementation,
        version=self._version,
        call_kwargs=self._call_kwargs,
    )
422+
423+
def cast(self, dtype: DType | type[DType]) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `cast` to `_reuse_series_implementation`, passing `dtype` through."""
    delegate = self._reuse_series_implementation
    return delegate("cast", dtype=dtype)
427+
386428
def __eq__(self, other: Self | Any) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:  # type: ignore[override]
    """Forward `__eq__` to `_reuse_series_implementation` with `other` as operand."""
    delegate = self._reuse_series_implementation
    return delegate("__eq__", other=other)
388430

@@ -457,14 +499,249 @@ def __rmod__(self, other: Self | Any) -> EagerExpr[EagerDataFrameT, EagerSeriesT
457499
def __invert__(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `__invert__` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("__invert__")
459501

460-
def cast(
461-
self, dtype: DType | type[DType]
462-
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
463-
return self._reuse_series_implementation("cast", dtype=dtype)
464-
502+
# Reductions
465503
def null_count(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `null_count` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("null_count", returns_scalar=True)
467505

506+
def n_unique(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `n_unique` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("n_unique", returns_scalar=True)
508+
509+
def sum(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `sum` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("sum", returns_scalar=True)
511+
512+
def mean(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `mean` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("mean", returns_scalar=True)
514+
515+
def median(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `median` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("median", returns_scalar=True)
517+
518+
def std(self, *, ddof: int) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `std` to `_reuse_series_implementation` as a scalar-returning call.

    `ddof` is not passed as a direct keyword; it travels inside the
    `call_kwargs` mapping.
    """
    extra = {"ddof": ddof}
    return self._reuse_series_implementation("std", returns_scalar=True, call_kwargs=extra)
522+
523+
def var(self, *, ddof: int) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `var` to `_reuse_series_implementation` as a scalar-returning call.

    `ddof` is not passed as a direct keyword; it travels inside the
    `call_kwargs` mapping.
    """
    extra = {"ddof": ddof}
    return self._reuse_series_implementation("var", returns_scalar=True, call_kwargs=extra)
527+
528+
def skew(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `skew` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("skew", returns_scalar=True)
530+
531+
def any(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `any` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("any", returns_scalar=True)
533+
534+
def all(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `all` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("all", returns_scalar=True)
536+
537+
def max(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `max` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("max", returns_scalar=True)
539+
540+
def min(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `min` to `_reuse_series_implementation` with `returns_scalar=True`.

    The method was previously spelled `mix`, which did not match the `"min"`
    attribute it delegates to (every sibling reduction is named after its
    delegated attribute).
    """
    return self._reuse_series_implementation("min", returns_scalar=True)
542+
543+
def arg_min(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `arg_min` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("arg_min", returns_scalar=True)
545+
546+
def arg_max(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `arg_max` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("arg_max", returns_scalar=True)
548+
549+
# Other
550+
551+
def clip(self, lower_bound: Any, upper_bound: Any) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `clip` to `_reuse_series_implementation` with both bounds as keywords."""
    bounds = {"lower_bound": lower_bound, "upper_bound": upper_bound}
    return self._reuse_series_implementation("clip", **bounds)
557+
558+
def is_null(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_null` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("is_null")
560+
561+
def is_nan(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_nan` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("is_nan")
563+
564+
def fill_null(
    self,
    value: Any | None,
    strategy: Literal["forward", "backward"] | None,
    limit: int | None,
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `fill_null` to `_reuse_series_implementation`.

    `value`, `strategy` and `limit` are passed through unchanged as keywords.
    """
    options = {"value": value, "strategy": strategy, "limit": limit}
    return self._reuse_series_implementation("fill_null", **options)
573+
574+
def is_in(self, other: Any) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_in` to `_reuse_series_implementation`, passing `other` through.

    The caller's `other` argument itself is forwarded; the previous code
    passed the string literal `"other"` and ignored the argument.
    """
    return self._reuse_series_implementation("is_in", other=other)
576+
577+
def arg_true(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `arg_true` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("arg_true")
579+
580+
# NOTE: `ewm_mean` not implemented `pyarrow`
581+
582+
def filter(self, *predicates: Self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `filter` to `_reuse_series_implementation` with one combined predicate.

    The predicates are first merged through the namespace's `all_horizontal`,
    and the merged expression is passed as `other`.
    """
    combined = self.__narwhals_namespace__().all_horizontal(*predicates)
    return self._reuse_series_implementation("filter", other=combined)
586+
587+
def drop_nulls(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `drop_nulls` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("drop_nulls")
589+
590+
def replace_strict(
    self,
    old: Sequence[Any] | Mapping[Any, Any],
    new: Sequence[Any],
    *,
    return_dtype: DType | type[DType] | None,
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `replace_strict` to `_reuse_series_implementation`.

    `old`, `new` and `return_dtype` are passed through unchanged as keywords.
    """
    options = {"old": old, "new": new, "return_dtype": return_dtype}
    return self._reuse_series_implementation("replace_strict", **options)
600+
601+
def sort(self, *, descending: bool, nulls_last: bool) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `sort` to `_reuse_series_implementation` with both flags as keywords."""
    flags = {"descending": descending, "nulls_last": nulls_last}
    return self._reuse_series_implementation("sort", **flags)
607+
608+
def abs(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `abs` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("abs")
610+
611+
def unique(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `unique` to `_reuse_series_implementation`, always with `maintain_order=False`."""
    delegate = self._reuse_series_implementation
    return delegate("unique", maintain_order=False)
613+
614+
def diff(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `diff` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("diff")
616+
617+
# NOTE: `shift` differs
618+
619+
def sample(
    self,
    n: int | None,
    *,
    fraction: float | None,
    with_replacement: bool,
    seed: int | None,
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `sample` to `_reuse_series_implementation`.

    `n`, `fraction`, `with_replacement` and `seed` are passed through
    unchanged as keywords.
    """
    options = {
        "n": n,
        "fraction": fraction,
        "with_replacement": with_replacement,
        "seed": seed,
    }
    return self._reuse_series_implementation("sample", **options)
630+
631+
def alias(self: Self, name: str) -> Self:
    """Return a copy of this expression whose single output is renamed to `name`.

    The new expression evaluates `self` and calls `.alias(name)` on every
    resulting series. Its `alias_output_names` callback returns `[name]` and
    raises `ValueError` when given anything other than exactly one name.
    """

    def _single_alias(names: Sequence[str]) -> Sequence[str]:
        if len(names) == 1:
            return [name]
        msg = f"Expected function with single output, found output names: {names}"
        raise ValueError(msg)

    def _evaluate(df: EagerDataFrameT) -> list[EagerSeriesT]:
        return [series.alias(name) for series in self(df)]

    # Built by hand (rather than through the generic reuse helper) so that
    # the output names can be overridden without increasing `depth`.
    constructor = type(self)
    return constructor(
        _evaluate,
        depth=self._depth,
        function_name=self._function_name,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=_single_alias,
        backend_version=self._backend_version,
        implementation=self._implementation,
        version=self._version,
        call_kwargs=self._call_kwargs,
    )
651+
652+
# NOTE: `over` differs
653+
654+
def is_unique(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_unique` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("is_unique")
656+
657+
def is_first_distinct(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_first_distinct` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("is_first_distinct")
659+
660+
def is_last_distinct(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_last_distinct` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("is_last_distinct")
662+
663+
def quantile(
    self,
    quantile: float,
    interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `quantile` to `_reuse_series_implementation` as a scalar-returning call.

    `quantile` and `interpolation` are passed through unchanged as keywords,
    together with `returns_scalar=True`.
    """
    options = {
        "quantile": quantile,
        "interpolation": interpolation,
        "returns_scalar": True,
    }
    return self._reuse_series_implementation("quantile", **options)
674+
675+
def head(self, n: int) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `head` to `_reuse_series_implementation`, passing `n` through."""
    delegate = self._reuse_series_implementation
    return delegate("head", n=n)
677+
678+
def tail(self, n: int) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `tail` to `_reuse_series_implementation`, passing `n` through."""
    delegate = self._reuse_series_implementation
    return delegate("tail", n=n)
680+
681+
def round(self, decimals: int) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `round` to `_reuse_series_implementation`, passing `decimals` through."""
    delegate = self._reuse_series_implementation
    return delegate("round", decimals=decimals)
683+
684+
def len(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `len` to `_reuse_series_implementation` with `returns_scalar=True`."""
    delegate = self._reuse_series_implementation
    return delegate("len", returns_scalar=True)
686+
687+
def gather_every(self, n: int, offset: int) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `gather_every` to `_reuse_series_implementation` with `n` and `offset`."""
    options = {"n": n, "offset": offset}
    return self._reuse_series_implementation("gather_every", **options)
691+
692+
def mode(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `mode` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("mode")
694+
695+
# NOTE: `map_batches` differs
696+
697+
def is_finite(self) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `is_finite` to `_reuse_series_implementation` with no extra arguments."""
    delegate = self._reuse_series_implementation
    return delegate("is_finite")
699+
700+
# NOTE: `cum_(sum|count|min|max|prod)` differ
701+
702+
def rolling_mean(
    self, window_size: int, *, min_samples: int | None, center: bool
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `rolling_mean` to `_reuse_series_implementation`.

    `window_size`, `min_samples` and `center` are passed through unchanged
    as keywords.
    """
    window = {"window_size": window_size, "min_samples": min_samples, "center": center}
    return self._reuse_series_implementation("rolling_mean", **window)
711+
712+
def rolling_std(
    self, window_size: int, *, min_samples: int | None, center: bool, ddof: int
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `rolling_std` to `_reuse_series_implementation`.

    `window_size`, `min_samples`, `center` and `ddof` are passed through
    unchanged as keywords.
    """
    window = {
        "window_size": window_size,
        "min_samples": min_samples,
        "center": center,
        "ddof": ddof,
    }
    return self._reuse_series_implementation("rolling_std", **window)
722+
723+
def rolling_sum(
    self, window_size: int, *, min_samples: int | None, center: bool
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `rolling_sum` to `_reuse_series_implementation`.

    `window_size`, `min_samples` and `center` are passed through unchanged
    as keywords.
    """
    window = {"window_size": window_size, "min_samples": min_samples, "center": center}
    return self._reuse_series_implementation("rolling_sum", **window)
729+
730+
def rolling_var(
    self, window_size: int, *, min_samples: int | None, center: bool, ddof: int
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Forward `rolling_var` to `_reuse_series_implementation`.

    `window_size`, `min_samples`, `center` and `ddof` are passed through
    unchanged as keywords.
    """
    window = {
        "window_size": window_size,
        "min_samples": min_samples,
        "center": center,
        "ddof": ddof,
    }
    return self._reuse_series_implementation("rolling_var", **window)
740+
741+
# NOTE: `rank` differs
742+
743+
# NOTE: All namespaces differ
744+
468745

469746
# NOTE: See (https://github.com/narwhals-dev/narwhals/issues/2044#issuecomment-2674262833)
470747
class LazyExpr(

narwhals/_compliant/namespace.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from typing import Any
55
from typing import Protocol
66

7-
from narwhals._compliant.typing import CompliantDataFrameT
87
from narwhals._compliant.typing import CompliantFrameT
98
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
10-
from narwhals._compliant.typing import CompliantSeriesT_co
9+
from narwhals._compliant.typing import EagerDataFrameT
10+
from narwhals._compliant.typing import EagerSeriesT
1111

1212
if TYPE_CHECKING:
1313
from narwhals._compliant.expr import CompliantExpr
14+
from narwhals._compliant.expr import EagerExpr
1415
from narwhals._compliant.selectors import CompliantSelectorNamespace
1516
from narwhals.dtypes import DType
1617

@@ -29,6 +30,9 @@ def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
2930

3031

3132
class EagerNamespace(
32-
CompliantNamespace[CompliantDataFrameT, CompliantSeriesT_co],
33-
Protocol[CompliantDataFrameT, CompliantSeriesT_co],
34-
): ...
33+
CompliantNamespace[EagerDataFrameT, EagerSeriesT],
34+
Protocol[EagerDataFrameT, EagerSeriesT],
35+
):
36+
def all_horizontal(
    self, *exprs: EagerExpr[EagerDataFrameT, EagerSeriesT]
) -> EagerExpr[EagerDataFrameT, EagerSeriesT]:
    """Protocol stub: implementers combine `exprs` into a single `EagerExpr`."""

narwhals/_compliant/series.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class EagerSeries(CompliantSeries, Protocol[NativeSeriesT_co]):
3434
_implementation: Implementation
3535
_backend_version: tuple[int, ...]
3636
_version: Version
37+
_broadcast: bool
3738

3839
@property
3940
def native(self) -> NativeSeriesT_co: ...

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ omit = [
230230
'narwhals/_spark_like/*',
231231
# we don't run these in every environment
232232
'tests/ibis_test.py',
233+
# Remove after finishing eager sub-protocols
234+
'narwhals/_compliant/*',
233235
]
234236
exclude_also = [
235237
"if sys.version_info() <",

0 commit comments

Comments
 (0)