From 4eb66a075be0a4fa6a6b6f1c4e2db91f7ecf77a2 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 13 Mar 2025 19:46:21 +0000 Subject: [PATCH 1/6] feat(typing): Fill out `CompliantNamespace` protocol --- narwhals/_arrow/namespace.py | 3 +-- narwhals/_compliant/expr.py | 2 +- narwhals/_compliant/namespace.py | 43 ++++++++++++++++++++++++------- narwhals/_compliant/typing.py | 1 + narwhals/_dask/namespace.py | 6 ++++- narwhals/_duckdb/namespace.py | 10 +++++-- narwhals/_expression_parsing.py | 5 ++-- narwhals/_spark_like/namespace.py | 6 ++++- 8 files changed, 56 insertions(+), 20 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 22e2e56184..8a349f3c1c 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -40,7 +40,6 @@ from typing_extensions import TypeAlias from narwhals._arrow.typing import Incomplete - from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType from narwhals.utils import Version @@ -166,7 +165,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: context=self, ) - def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr: + def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: dtypes = import_dtypes_module(self._version) def func(df: ArrowDataFrame) -> list[ArrowSeries]: diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 30e9bfa0a7..1fd1ee5c64 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -89,7 +89,7 @@ def __call__( def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, - ) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... + ) -> CompliantNamespace[CompliantFrameT, Self]: ... def is_null(self) -> Self: ... def abs(self) -> Self: ... def all(self) -> Self: ... diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 688f2770c2..fe2e30112b 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -2,36 +2,59 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Container +from typing import Iterable +from typing import Literal from typing import Protocol +from narwhals._compliant.typing import CompliantExprT from narwhals._compliant.typing import CompliantFrameT -from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co from narwhals._compliant.typing import EagerDataFrameT from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT_co from narwhals.utils import deprecated if TYPE_CHECKING: - from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.selectors import CompliantSelectorNamespace from narwhals.dtypes import DType __all__ = ["CompliantNamespace", "EagerNamespace"] -class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): - def col( - self, *column_names: str - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... - def lit( - self, value: Any, dtype: DType | None - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... +class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): + def col(self, *column_names: str) -> CompliantExprT: ... + def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ... + def nth(self, *column_indices: int) -> CompliantExprT: ... + def len(self) -> CompliantExprT: ... + def all(self) -> CompliantExprT: ... + def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def concat( + self, + items: Iterable[CompliantFrameT], + *, + how: Literal["horizontal", "vertical", "diagonal"], + ) -> CompliantFrameT: ... + def when(self, predicate: CompliantExprT) -> Any: ... + def concat_str( + self, + *exprs: CompliantExprT, + separator: str, + ignore_nulls: bool, + ) -> CompliantExprT: ... @property def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... + @property + def _expr(self) -> type[CompliantExprT]: ... class EagerNamespace( - CompliantNamespace[EagerDataFrameT, EagerSeriesT_co], + CompliantNamespace[EagerDataFrameT, EagerExprT], Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], ): @property diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 2513097a50..ba970e4dbb 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -42,6 +42,7 @@ CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]") CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame") IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" +CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]") EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]") EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]") diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 10744d6bb0..511828c78d 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -42,13 +42,17 @@ import dask_expr as dx -class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): +class DaskNamespace(CompliantNamespace[DaskLazyFrame, "DaskExpr"]): _implementation: Implementation = Implementation.DASK @property def selectors(self: Self) -> DaskSelectorNamespace: return DaskSelectorNamespace(self) + @property + def _expr(self) -> type[DaskExpr]: + return DaskExpr + def __init__( self: Self, *, backend_version: tuple[int, ...], version: Version ) -> None: diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index f245c74e27..0e2f490b22 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -7,6 +7,7 @@ from typing import Any from typing import Callable from typing import Container +from typing import Iterable from typing import Literal from typing import Sequence @@ -38,7 +39,7 @@ from narwhals.utils import Version -class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]): +class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "DuckDBExpr"]): _implementation: Implementation = Implementation.DUCKDB def __init__( @@ -51,6 +52,10 @@ def __init__( def selectors(self: Self) -> DuckDBSelectorNamespace: return DuckDBSelectorNamespace(self) + @property + def _expr(self) -> type[DuckDBExpr]: + return DuckDBExpr + def all(self: Self) -> DuckDBExpr: return DuckDBExpr.from_column_names( get_column_names, @@ -61,7 +66,7 @@ def all(self: Self) -> DuckDBExpr: def concat( self: Self, - items: Sequence[DuckDBLazyFrame], + items: Iterable[DuckDBLazyFrame], *, how: Literal["horizontal", "vertical", "diagonal"], ) -> DuckDBLazyFrame: @@ -71,6 +76,7 @@ def concat( if how == "diagonal": msg = "Not implemented yet" raise NotImplementedError(msg) + items = list(items) first = items[0] schema = first.schema if how == "vertical" and not all(x.schema == schema for x in items[1:]): diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index f5d091c4eb..a8cb3518f9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -26,7 +26,6 @@ from narwhals._compliant import CompliantExpr from narwhals._compliant import CompliantFrameT from narwhals._compliant import CompliantNamespace - from narwhals._compliant import CompliantSeriesOrNativeExprT_co from narwhals.expr import Expr from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame @@ -91,11 +90,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: def extract_compliant( - plx: CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co], + plx: CompliantNamespace[CompliantFrameT, Any], other: Any, *, str_as_lit: bool, -) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | object: +) -> CompliantExpr[CompliantFrameT, Any] | object: if is_expr(other): return other._to_compliant_expr(plx) if isinstance(other, str) and not str_as_lit: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index ed414e196c..6c8802b820 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -32,7 +32,7 @@ from narwhals.utils import Version -class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]): +class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "SparkLikeExpr"]): def __init__( self: Self, *, @@ -48,6 +48,10 @@ def __init__( def selectors(self: Self) -> SparkLikeSelectorNamespace: return SparkLikeSelectorNamespace(self) + @property + def _expr(self) -> type[SparkLikeExpr]: + return SparkLikeExpr + def all(self: Self) -> SparkLikeExpr: return SparkLikeExpr.from_column_names( get_column_names, From 8319d0f2a2774fa9e5a6a84fe7999a7e2f48fb6e Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 13 Mar 2025 19:48:02 +0000 Subject: [PATCH 2/6] refactor(typing): Simplify `EagerNamespace` - Only needs to be the extra stuff - `_create_compliant_series` is removed in #2196 --- narwhals/_compliant/namespace.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index fe2e30112b..6ad8967f74 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -57,11 +57,8 @@ class EagerNamespace( CompliantNamespace[EagerDataFrameT, EagerExprT], Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], ): - @property - def _expr(self) -> type[EagerExprT]: ... @property def _series(self) -> type[EagerSeriesT_co]: ... - def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ... @deprecated( "Internally used for `numpy.ndarray` -> `CompliantSeries`\n" From 70170da64ebcbf021c9b567c1a85f182f77dbc42 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 13 Mar 2025 21:22:10 +0000 Subject: [PATCH 3/6] coverage https://github.com/narwhals-dev/narwhals/actions/runs/13844590933/job/38740030977?pr=2202 --- narwhals/_dask/namespace.py | 26 +++++++++++++------------- narwhals/_duckdb/namespace.py | 24 ++++++++++++------------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 511828c78d..86c7a4b013 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -60,7 +60,7 @@ def __init__( self._version = version def all(self: Self) -> DaskExpr: - return DaskExpr.from_column_names( + return self._expr.from_column_names( get_column_names, function_name="all", backend_version=self._backend_version, @@ -68,7 +68,7 @@ def all(self: Self) -> DaskExpr: ) def col(self: Self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( + return self._expr.from_column_names( passthrough_column_names(column_names), function_name="col", backend_version=self._backend_version, @@ -76,7 +76,7 @@ def col(self: Self, *column_names: str) -> DaskExpr: ) def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr: - return DaskExpr.from_column_names( + return self._expr.from_column_names( partial(exclude_column_names, names=excluded_names), function_name="exclude", backend_version=self._backend_version, @@ -84,7 +84,7 @@ def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr: ) def nth(self: Self, *column_indices: int) -> DaskExpr: - return DaskExpr.from_column_indices( + return self._expr.from_column_indices( *column_indices, backend_version=self._backend_version, version=self._version ) @@ -99,7 +99,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions) return [dask_series[0].to_series()] - return DaskExpr( + return self._expr( func, depth=0, function_name="lit", @@ -121,7 +121,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [df._native_frame[df.columns[0]].size.to_series()] # coverage bug? this is definitely hit - return DaskExpr( # pragma: no cover + return self._expr( # pragma: no cover func, depth=0, function_name="len", @@ -138,7 +138,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) return [reduce(operator.and_, series)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="all_horizontal", @@ -155,7 +155,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) return [reduce(operator.or_, series)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="any_horizontal", @@ -172,7 +172,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) return [dd.concat(series, axis=1).sum(axis=1)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="sum_horizontal", @@ -251,7 +251,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) ] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="mean_horizontal", @@ -269,7 +269,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [dd.concat(series, axis=1).min(axis=1)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="min_horizontal", @@ -287,7 +287,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [dd.concat(series, axis=1).max(axis=1)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="max_horizontal", @@ -335,7 +335,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [result] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="concat_str", diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 0e2f490b22..b9cb15fbe1 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -57,7 +57,7 @@ def _expr(self) -> type[DuckDBExpr]: return DuckDBExpr def all(self: Self) -> DuckDBExpr: - return DuckDBExpr.from_column_names( + return self._expr.from_column_names( get_column_names, function_name="all", backend_version=self._backend_version, @@ -131,7 +131,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [result] - return DuckDBExpr( + return self._expr( call=func, function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -145,7 +145,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.and_, cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -159,7 +159,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.or_, cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="or_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -173,7 +173,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [FunctionExpression("greatest", *cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -187,7 +187,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [FunctionExpression("least", *cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -201,7 +201,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (CoalesceOperator(col, lit(0)) for _expr in exprs for col in _expr(df)) return [reduce(operator.add, cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -237,7 +237,7 @@ def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen: ) def col(self: Self, *column_names: str) -> DuckDBExpr: - return DuckDBExpr.from_column_names( + return self._expr.from_column_names( passthrough_column_names(column_names), function_name="col", backend_version=self._backend_version, @@ -245,7 +245,7 @@ def col(self: Self, *column_names: str) -> DuckDBExpr: ) def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr: - return DuckDBExpr.from_column_names( + return self._expr.from_column_names( partial(exclude_column_names, names=excluded_names), function_name="exclude", backend_version=self._backend_version, @@ -253,7 +253,7 @@ def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr: ) def nth(self: Self, *column_indices: int) -> DuckDBExpr: - return DuckDBExpr.from_column_indices( + return self._expr.from_column_indices( *column_indices, backend_version=self._backend_version, version=self._version ) @@ -267,7 +267,7 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: ] return [lit(value)] - return DuckDBExpr( + return self._expr( func, function_name="lit", evaluate_output_names=lambda _df: ["literal"], @@ -280,7 +280,7 @@ def len(self: Self) -> DuckDBExpr: def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [FunctionExpression("count")] - return DuckDBExpr( + return self._expr( call=func, function_name="len", evaluate_output_names=lambda _df: ["len"], From 5168ae50e6268cc127f604475ee7caede424c3a7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 14 Mar 2025 11:01:09 +0000 Subject: [PATCH 4/6] more coverage https://github.com/narwhals-dev/narwhals/actions/runs/13854776857/job/38769088411?pr=2202 --- narwhals/_spark_like/namespace.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 6c8802b820..7be93a5ab5 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -53,7 +53,7 @@ def _expr(self) -> type[SparkLikeExpr]: return SparkLikeExpr def all(self: Self) -> SparkLikeExpr: - return SparkLikeExpr.from_column_names( + return self._expr.from_column_names( get_column_names, function_name="all", implementation=self._implementation, @@ -62,7 +62,7 @@ def all(self: Self) -> SparkLikeExpr: ) def col(self: Self, *column_names: str) -> SparkLikeExpr: - return SparkLikeExpr.from_column_names( + return self._expr.from_column_names( passthrough_column_names(column_names), function_name="col", implementation=self._implementation, @@ -71,7 +71,7 @@ def col(self: Self, *column_names: str) -> SparkLikeExpr: ) def exclude(self: Self, excluded_names: Container[str]) -> SparkLikeExpr: - return SparkLikeExpr.from_column_names( + return self._expr.from_column_names( partial(exclude_column_names, names=excluded_names), function_name="exclude", implementation=self._implementation, @@ -80,7 +80,7 @@ def exclude(self: Self, excluded_names: Container[str]) -> SparkLikeExpr: ) def nth(self: Self, *column_indices: int) -> SparkLikeExpr: - return SparkLikeExpr.from_column_indices( + return self._expr.from_column_indices( *column_indices, backend_version=self._backend_version, version=self._version, @@ -98,7 +98,7 @@ def _lit(df: SparkLikeLazyFrame) -> list[Column]: return [column] - return SparkLikeExpr( + return self._expr( call=_lit, function_name="lit", evaluate_output_names=lambda _df: ["literal"], @@ -112,7 +112,7 @@ def len(self: Self) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.count("*")] - return SparkLikeExpr( + return self._expr( func, function_name="len", evaluate_output_names=lambda _df: ["len"], @@ -127,7 +127,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.and_, cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -142,7 +142,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.or_, cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -159,7 +159,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) return [reduce(operator.add, cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -188,7 +188,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) ] - return SparkLikeExpr( + return self._expr( call=func, function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -203,7 +203,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [df._F.greatest(*cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -218,7 +218,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [df._F.least(*cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -313,7 +313,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return [result] - return SparkLikeExpr( + return self._expr( call=func, function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), From 5d609a70b253f0417da8a2ab66f30b7c0a03b2fd Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 14 Mar 2025 14:01:39 +0000 Subject: [PATCH 5/6] refactor: Implement `CompliantNamespace.(all|col|exclude|nth)` All backends (besides `polars`) now share the same implementation https://github.com/narwhals-dev/narwhals/pull/2202#discussion_r1994308977 --- narwhals/_arrow/namespace.py | 24 ------------------ narwhals/_compliant/expr.py | 28 +++++++++------------ narwhals/_compliant/namespace.py | 36 +++++++++++++++++++++++---- narwhals/_dask/expr.py | 17 ++++++------- narwhals/_dask/namespace.py | 34 ------------------------- narwhals/_duckdb/expr.py | 17 ++++++------- narwhals/_duckdb/namespace.py | 34 ------------------------- narwhals/_pandas_like/namespace.py | 25 ------------------- narwhals/_spark_like/expr.py | 23 +++++++---------- narwhals/_spark_like/namespace.py | 40 ------------------------------ 10 files changed, 66 insertions(+), 212 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 8a349f3c1c..730dc17fa4 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,13 +1,11 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -28,10 +26,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing import Callable @@ -67,20 +62,6 @@ def __init__( self._version = version # --- selection --- - def col(self: Self, *column_names: str) -> ArrowExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self: Self, *column_indices: int) -> ArrowExpr: - return self._expr.from_column_indices(*column_indices, context=self) def len(self: Self) -> ArrowExpr: # coverage bug? this is definitely hit @@ -98,11 +79,6 @@ def len(self: Self) -> ArrowExpr: version=self._version, ) - def all(self: Self) -> ArrowExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr: def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: arrow_series = ArrowSeries._from_iterable( diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index b531667dc0..ee56b07c18 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -90,6 +90,18 @@ def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, ) -> CompliantNamespace[CompliantFrameT, Self]: ... + @classmethod + def from_column_names( + cls, + evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]], + /, + *, + function_name: str, + context: _FullContext, + ) -> Self: ... + @classmethod + def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ... + def is_null(self) -> Self: ... def abs(self) -> Self: ... def all(self) -> Self: ... @@ -330,22 +342,6 @@ def _from_series(cls, series: EagerSeriesT) -> Self: version=series._version, ) - @classmethod - def from_column_names( - cls, - evaluate_column_names: Callable[[EagerDataFrameT], Sequence[str]], - /, - *, - function_name: str, - context: _FullContext, - ) -> Self: ... - @classmethod - def from_column_indices( - cls, - *column_indices: int, - context: _FullContext, - ) -> Self: ... - def _reuse_series( self: Self, method_name: str, diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 6ad8967f74..742f0274ed 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Container @@ -13,21 +14,46 @@ from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT_co from narwhals.utils import deprecated +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from narwhals._compliant.selectors import CompliantSelectorNamespace from narwhals.dtypes import DType + from narwhals.utils import Implementation + from narwhals.utils import Version __all__ = ["CompliantNamespace", "EagerNamespace"] class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): - def col(self, *column_names: str) -> CompliantExprT: ... - def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... - def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ... - def nth(self, *column_indices: int) -> CompliantExprT: ... + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + def all(self) -> CompliantExprT: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> CompliantExprT: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + def nth(self, *column_indices: int) -> CompliantExprT: + return self._expr.from_column_indices(*column_indices, context=self) + def len(self) -> CompliantExprT: ... - def all(self) -> CompliantExprT: ... + def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index a5f673f9d1..8d87a96e41 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -39,6 +39,7 @@ from narwhals._dask.namespace import DaskNamespace from narwhals.dtypes import DType from narwhals.utils import Version + from narwhals.utils import _FullContext class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]): @@ -100,8 +101,7 @@ def from_column_names( /, *, function_name: str, - backend_version: tuple[int, ...], - version: Version, + context: _FullContext, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: @@ -124,16 +124,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) @classmethod def from_column_indices( - cls: type[Self], - *column_indices: int, - backend_version: tuple[int, ...], - version: Version, + cls: type[Self], *column_indices: int, context: _FullContext ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: return [ @@ -146,8 +143,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) def _from_call( diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 86c7a4b013..93c3cfaa99 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -26,9 +24,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -59,35 +54,6 @@ def __init__( self._backend_version = backend_version self._version = version - def all(self: Self) -> DaskExpr: - return self._expr.from_column_names( - get_column_names, - function_name="all", - backend_version=self._backend_version, - version=self._version, - ) - - def col(self: Self, *column_names: str) -> DaskExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), - function_name="col", - backend_version=self._backend_version, - version=self._version, - ) - - def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - backend_version=self._backend_version, - version=self._version, - ) - - def nth(self: Self, *column_indices: int) -> DaskExpr: - return self._expr.from_column_indices( - *column_indices, backend_version=self._backend_version, version=self._version - ) - def lit(self: Self, value: Any, dtype: DType | None) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 6dc21b8ffc..f450143afb 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -35,6 +35,7 @@ from narwhals._duckdb.namespace import DuckDBNamespace from narwhals.dtypes import DType from narwhals.utils import Version + from narwhals.utils import _FullContext class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): @@ -85,8 +86,7 @@ def from_column_names( /, *, function_name: str, - backend_version: tuple[int, ...], - version: Version, + context: _FullContext, ) -> Self: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ColumnExpression(col_name) for col_name in evaluate_column_names(df)] @@ -96,16 +96,13 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) @classmethod def from_column_indices( - cls: type[Self], - *column_indices: int, - backend_version: tuple[int, ...], - version: Version, + cls: type[Self], *column_indices: int, context: _FullContext ) -> Self: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: columns = df.columns @@ -117,8 +114,8 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) def _from_call( diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index b9cb15fbe1..121b91916d 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -26,9 +24,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: import duckdb @@ -56,14 +51,6 @@ def selectors(self: Self) -> DuckDBSelectorNamespace: def _expr(self) -> type[DuckDBExpr]: return DuckDBExpr - def all(self: Self) -> DuckDBExpr: - return self._expr.from_column_names( - get_column_names, - function_name="all", - backend_version=self._backend_version, - version=self._version, - ) - def concat( self: Self, items: Iterable[DuckDBLazyFrame], @@ -236,27 +223,6 @@ def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen: version=self._version, ) - def col(self: Self, *column_names: str) -> DuckDBExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), - function_name="col", - backend_version=self._backend_version, - version=self._version, - ) - - def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - backend_version=self._backend_version, - version=self._version, - ) - - def nth(self: Self, *column_indices: int) -> DuckDBExpr: - return self._expr.from_column_indices( - *column_indices, backend_version=self._backend_version, version=self._version - ) - def lit(self: Self, value: Any, dtype: DType | None) -> DuckDBExpr: def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: if dtype is not None: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index e9902e6f8b..33f2543e28 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -24,10 +22,7 @@ from narwhals._pandas_like.utils import extract_dataframe_comparand from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -75,26 +70,6 @@ def _create_compliant_series(self: Self, value: Any) -> PandasLikeSeries: ) # --- selection --- - def col(self: Self, *column_names: str) -> PandasLikeExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self: Self, excluded_names: Container[str]) -> PandasLikeExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self: Self, *column_indices: int) -> PandasLikeExpr: - return self._expr.from_column_indices(*column_indices, context=self) - - def all(self: Self) -> PandasLikeExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - def lit(self: Self, value: Any, dtype: DType | None) -> PandasLikeExpr: def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: pandas_series = self._series._from_iterable( diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 7f6a07de83..f7190fc060 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -35,6 +35,7 @@ from narwhals._spark_like.typing import WindowFunction from narwhals.dtypes import DType from narwhals.utils import Version + from narwhals.utils import _FullContext class SparkLikeExpr(LazyExpr["SparkLikeLazyFrame", "Column"]): @@ -129,9 +130,7 @@ def from_column_names( /, *, function_name: str, - implementation: Implementation, - backend_version: tuple[int, ...], - version: Version, + context: _FullContext, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.col(col_name) for col_name in evaluate_column_names(df)] @@ -141,18 +140,14 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, - backend_version=backend_version, - version=version, - implementation=implementation, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, ) @classmethod def from_column_indices( - cls: type[Self], - *column_indices: int, - backend_version: tuple[int, ...], - version: Version, - implementation: Implementation, + cls: type[Self], *column_indices: int, context: _FullContext ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: columns = df.columns @@ -163,9 +158,9 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - backend_version=backend_version, - version=version, - implementation=implementation, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, ) def _from_call( diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 7be93a5ab5..6988ba9cc3 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -19,9 +17,6 @@ from narwhals._spark_like.selectors import SparkLikeSelectorNamespace from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from sqlframe.base.column import Column @@ -52,41 +47,6 @@ def selectors(self: Self) -> SparkLikeSelectorNamespace: def _expr(self) -> type[SparkLikeExpr]: return SparkLikeExpr - def all(self: Self) -> SparkLikeExpr: - return self._expr.from_column_names( - get_column_names, - function_name="all", - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - - def col(self: Self, *column_names: str) -> SparkLikeExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), - function_name="col", - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - - def exclude(self: Self, excluded_names: Container[str]) -> SparkLikeExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - - def nth(self: Self, *column_indices: int) -> SparkLikeExpr: - return self._expr.from_column_indices( - *column_indices, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) - def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: def _lit(df: SparkLikeLazyFrame) -> list[Column]: column = df._F.lit(value) From a621fa315793961d6adab074495ac424b2aa6686 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Fri, 14 Mar 2025 14:14:25 +0000 Subject: [PATCH 6/6] refactor(typing): Simplify `extract_compliant` - `FrameT` isn't relevant to what is used - Only needed before as `CompliantNamespace` didn't expose an `ExprT` --- narwhals/_compliant/__init__.py | 2 ++ narwhals/_expression_parsing.py | 8 +++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index c4dbdcd9ef..a81bd06ff8 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -16,6 +16,7 @@ from narwhals._compliant.selectors import LazySelectorNamespace from narwhals._compliant.series import CompliantSeries from narwhals._compliant.series import EagerSeries +from narwhals._compliant.typing import CompliantExprT from narwhals._compliant.typing import CompliantFrameT from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co from narwhals._compliant.typing import CompliantSeriesT_co @@ -26,6 +27,7 @@ __all__ = [ "CompliantDataFrame", "CompliantExpr", + "CompliantExprT", "CompliantFrameT", "CompliantLazyFrame", "CompliantNamespace", diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index a8cb3518f9..8ba9e776b9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -24,6 +24,7 @@ from typing_extensions import TypeIs from narwhals._compliant import CompliantExpr + from narwhals._compliant import CompliantExprT from narwhals._compliant import CompliantFrameT from narwhals._compliant import CompliantNamespace from narwhals.expr import Expr @@ -90,11 +91,8 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: def extract_compliant( - plx: CompliantNamespace[CompliantFrameT, Any], - other: Any, - *, - str_as_lit: bool, -) -> CompliantExpr[CompliantFrameT, Any] | object: + plx: CompliantNamespace[Any, CompliantExprT], other: Any, *, str_as_lit: bool +) -> CompliantExprT | object: if is_expr(other): return other._to_compliant_expr(plx) if isinstance(other, str) and not str_as_lit: