diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 7097c37767..34ad9ea7c2 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -1,13 +1,11 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -29,10 +27,7 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing import Callable @@ -41,7 +36,6 @@ from typing_extensions import TypeAlias from narwhals._arrow.typing import Incomplete - from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType from narwhals.utils import Version @@ -69,20 +63,6 @@ def __init__( self._version = version # --- selection --- - def col(self: Self, *column_names: str) -> ArrowExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self: Self, *column_indices: int) -> ArrowExpr: - return self._expr.from_column_indices(*column_indices, context=self) def len(self: Self) -> ArrowExpr: # coverage bug? this is definitely hit @@ -100,11 +80,6 @@ def len(self: Self) -> ArrowExpr: version=self._version, ) - def all(self: Self) -> ArrowExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr: def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: arrow_series = ArrowSeries._from_iterable( @@ -167,7 +142,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: context=self, ) - def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr: + def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: dtypes = import_dtypes_module(self._version) def func(df: ArrowDataFrame) -> list[ArrowSeries]: diff --git a/narwhals/_compliant/__init__.py b/narwhals/_compliant/__init__.py index c4dbdcd9ef..a81bd06ff8 100644 --- a/narwhals/_compliant/__init__.py +++ b/narwhals/_compliant/__init__.py @@ -16,6 +16,7 @@ from narwhals._compliant.selectors import LazySelectorNamespace from narwhals._compliant.series import CompliantSeries from narwhals._compliant.series import EagerSeries +from narwhals._compliant.typing import CompliantExprT from narwhals._compliant.typing import CompliantFrameT from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co from narwhals._compliant.typing import CompliantSeriesT_co @@ -26,6 +27,7 @@ __all__ = [ "CompliantDataFrame", "CompliantExpr", + "CompliantExprT", "CompliantFrameT", "CompliantLazyFrame", "CompliantNamespace", diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index 0079f6faeb..ee56b07c18 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -89,7 +89,19 @@ def __call__( def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__( self, - ) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... + ) -> CompliantNamespace[CompliantFrameT, Self]: ... + @classmethod + def from_column_names( + cls, + evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]], + /, + *, + function_name: str, + context: _FullContext, + ) -> Self: ... + @classmethod + def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ... + def is_null(self) -> Self: ... def abs(self) -> Self: ... def all(self) -> Self: ... @@ -330,22 +342,6 @@ def _from_series(cls, series: EagerSeriesT) -> Self: version=series._version, ) - @classmethod - def from_column_names( - cls, - evaluate_column_names: Callable[[EagerDataFrameT], Sequence[str]], - /, - *, - function_name: str, - context: _FullContext, - ) -> Self: ... - @classmethod - def from_column_indices( - cls, - *column_indices: int, - context: _FullContext, - ) -> Self: ... - def _reuse_series( self: Self, method_name: str, diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index 688f2770c2..742f0274ed 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -1,44 +1,90 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING from typing import Any +from typing import Container +from typing import Iterable +from typing import Literal from typing import Protocol +from narwhals._compliant.typing import CompliantExprT from narwhals._compliant.typing import CompliantFrameT -from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co from narwhals._compliant.typing import EagerDataFrameT from narwhals._compliant.typing import EagerExprT from narwhals._compliant.typing import EagerSeriesT_co from narwhals.utils import deprecated +from narwhals.utils import exclude_column_names +from narwhals.utils import get_column_names +from narwhals.utils import passthrough_column_names if TYPE_CHECKING: - from narwhals._compliant.expr import CompliantExpr from narwhals._compliant.selectors import CompliantSelectorNamespace from narwhals.dtypes import DType + from narwhals.utils import Implementation + from narwhals.utils import Version __all__ = ["CompliantNamespace", "EagerNamespace"] -class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]): - def col( - self, *column_names: str - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... - def lit( - self, value: Any, dtype: DType | None - ) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ... +class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]): + _implementation: Implementation + _backend_version: tuple[int, ...] + _version: Version + + def all(self) -> CompliantExprT: + return self._expr.from_column_names( + get_column_names, function_name="all", context=self + ) + + def col(self, *column_names: str) -> CompliantExprT: + return self._expr.from_column_names( + passthrough_column_names(column_names), function_name="col", context=self + ) + + def exclude(self, excluded_names: Container[str]) -> CompliantExprT: + return self._expr.from_column_names( + partial(exclude_column_names, names=excluded_names), + function_name="exclude", + context=self, + ) + + def nth(self, *column_indices: int) -> CompliantExprT: + return self._expr.from_column_indices(*column_indices, context=self) + + def len(self) -> CompliantExprT: ... + def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ... + def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... + def concat( + self, + items: Iterable[CompliantFrameT], + *, + how: Literal["horizontal", "vertical", "diagonal"], + ) -> CompliantFrameT: ... + def when(self, predicate: CompliantExprT) -> Any: ... + def concat_str( + self, + *exprs: CompliantExprT, + separator: str, + ignore_nulls: bool, + ) -> CompliantExprT: ... @property def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ... + @property + def _expr(self) -> type[CompliantExprT]: ... class EagerNamespace( - CompliantNamespace[EagerDataFrameT, EagerSeriesT_co], + CompliantNamespace[EagerDataFrameT, EagerExprT], Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT], ): - @property - def _expr(self) -> type[EagerExprT]: ... @property def _series(self) -> type[EagerSeriesT_co]: ... - def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ... @deprecated( "Internally used for `numpy.ndarray` -> `CompliantSeries`\n" diff --git a/narwhals/_compliant/typing.py b/narwhals/_compliant/typing.py index 2513097a50..ba970e4dbb 100644 --- a/narwhals/_compliant/typing.py +++ b/narwhals/_compliant/typing.py @@ -42,6 +42,7 @@ CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]") CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame") IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co" +CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]") EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]") EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]") diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index a5f673f9d1..8d87a96e41 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -39,6 +39,7 @@ from narwhals._dask.namespace import DaskNamespace from narwhals.dtypes import DType from narwhals.utils import Version + from narwhals.utils import _FullContext class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]): @@ -100,8 +101,7 @@ def from_column_names( /, *, function_name: str, - backend_version: tuple[int, ...], - version: Version, + context: _FullContext, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: try: @@ -124,16 +124,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) @classmethod def from_column_indices( - cls: type[Self], - *column_indices: int, - backend_version: tuple[int, ...], - version: Version, + cls: type[Self], *column_indices: int, context: _FullContext ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: return [ @@ -146,8 +143,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) def _from_call( diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 48627382c4..9ea92b75c9 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -26,9 +24,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -42,48 +37,23 @@ import dask_expr as dx -class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): +class DaskNamespace(CompliantNamespace[DaskLazyFrame, "DaskExpr"]): _implementation: Implementation = Implementation.DASK @property def selectors(self: Self) -> DaskSelectorNamespace: return DaskSelectorNamespace(self) + @property + def _expr(self) -> type[DaskExpr]: + return DaskExpr + def __init__( self: Self, *, backend_version: tuple[int, ...], version: Version ) -> None: self._backend_version = backend_version self._version = version - def all(self: Self) -> DaskExpr: - return DaskExpr.from_column_names( - get_column_names, - function_name="all", - backend_version=self._backend_version, - version=self._version, - ) - - def col(self: Self, *column_names: str) -> DaskExpr: - return DaskExpr.from_column_names( - passthrough_column_names(column_names), - function_name="col", - backend_version=self._backend_version, - version=self._version, - ) - - def exclude(self: Self, excluded_names: Container[str]) -> DaskExpr: - return DaskExpr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - backend_version=self._backend_version, - version=self._version, - ) - - def nth(self: Self, *column_indices: int) -> DaskExpr: - return DaskExpr.from_column_indices( - *column_indices, backend_version=self._backend_version, version=self._version - ) - def lit(self: Self, value: Any, dtype: DType | None) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: if dtype is not None: @@ -95,7 +65,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions) return [dask_series[0].to_series()] - return DaskExpr( + return self._expr( func, depth=0, function_name="lit", @@ -110,7 +80,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: # We don't allow dataframes with 0 columns, so `[0]` is safe. return [df._native_frame[df.columns[0]].size.to_series()] - return DaskExpr( + return self._expr( func, depth=0, function_name="len", @@ -127,7 +97,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) return [reduce(operator.and_, series)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="all_horizontal", @@ -144,7 +114,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) return [reduce(operator.or_, series)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="any_horizontal", @@ -161,7 +131,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) return [dd.concat(series, axis=1).sum(axis=1)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="sum_horizontal", @@ -240,7 +210,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) ] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="mean_horizontal", @@ -258,7 +228,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [dd.concat(series, axis=1).min(axis=1)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="min_horizontal", @@ -276,7 +246,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [dd.concat(series, axis=1).max(axis=1)] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="max_horizontal", @@ -324,7 +294,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: return [result] - return DaskExpr( + return self._expr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="concat_str", diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 6dc21b8ffc..f450143afb 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -35,6 +35,7 @@ from narwhals._duckdb.namespace import DuckDBNamespace from narwhals.dtypes import DType from narwhals.utils import Version + from narwhals.utils import _FullContext class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]): @@ -85,8 +86,7 @@ def from_column_names( /, *, function_name: str, - backend_version: tuple[int, ...], - version: Version, + context: _FullContext, ) -> Self: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [ColumnExpression(col_name) for col_name in evaluate_column_names(df)] @@ -96,16 +96,13 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) @classmethod def from_column_indices( - cls: type[Self], - *column_indices: int, - backend_version: tuple[int, ...], - version: Version, + cls: type[Self], *column_indices: int, context: _FullContext ) -> Self: def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: columns = df.columns @@ -117,8 +114,8 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - backend_version=backend_version, - version=version, + backend_version=context._backend_version, + version=context._version, ) def _from_call( diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index f245c74e27..121b91916d 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -1,12 +1,11 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container +from typing import Iterable from typing import Literal from typing import Sequence @@ -25,9 +24,6 @@ from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.utils import Implementation -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: import duckdb @@ -38,7 +34,7 @@ from narwhals.utils import Version -class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]): +class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "DuckDBExpr"]): _implementation: Implementation = Implementation.DUCKDB def __init__( @@ -51,17 +47,13 @@ def __init__( def selectors(self: Self) -> DuckDBSelectorNamespace: return DuckDBSelectorNamespace(self) - def all(self: Self) -> DuckDBExpr: - return DuckDBExpr.from_column_names( - get_column_names, - function_name="all", - backend_version=self._backend_version, - version=self._version, - ) + @property + def _expr(self) -> type[DuckDBExpr]: + return DuckDBExpr def concat( self: Self, - items: Sequence[DuckDBLazyFrame], + items: Iterable[DuckDBLazyFrame], *, how: Literal["horizontal", "vertical", "diagonal"], ) -> DuckDBLazyFrame: @@ -71,6 +63,7 @@ def concat( if how == "diagonal": msg = "Not implemented yet" raise NotImplementedError(msg) + items = list(items) first = items[0] schema = first.schema if how == "vertical" and not all(x.schema == schema for x in items[1:]): @@ -125,7 +118,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [result] - return DuckDBExpr( + return self._expr( call=func, function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -139,7 +132,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.and_, cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -153,7 +146,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.or_, cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="or_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -167,7 +160,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [FunctionExpression("greatest", *cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -181,7 +174,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (c for _expr in exprs for c in _expr(df)) return [FunctionExpression("least", *cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -195,7 +188,7 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]: cols = (CoalesceOperator(col, lit(0)) for _expr in exprs for col in _expr(df)) return [reduce(operator.add, cols)] - return DuckDBExpr( + return self._expr( call=func, function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -230,27 +223,6 @@ def when(self: Self, predicate: DuckDBExpr) -> DuckDBWhen: version=self._version, ) - def col(self: Self, *column_names: str) -> DuckDBExpr: - return DuckDBExpr.from_column_names( - passthrough_column_names(column_names), - function_name="col", - backend_version=self._backend_version, - version=self._version, - ) - - def exclude(self: Self, excluded_names: Container[str]) -> DuckDBExpr: - return DuckDBExpr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - backend_version=self._backend_version, - version=self._version, - ) - - def nth(self: Self, *column_indices: int) -> DuckDBExpr: - return DuckDBExpr.from_column_indices( - *column_indices, backend_version=self._backend_version, version=self._version - ) - def lit(self: Self, value: Any, dtype: DType | None) -> DuckDBExpr: def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: if dtype is not None: @@ -261,7 +233,7 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: ] return [lit(value)] - return DuckDBExpr( + return self._expr( func, function_name="lit", evaluate_output_names=lambda _df: ["literal"], @@ -274,7 +246,7 @@ def len(self: Self) -> DuckDBExpr: def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]: return [FunctionExpression("count")] - return DuckDBExpr( + return self._expr( call=func, function_name="len", evaluate_output_names=lambda _df: ["len"], diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index f5d091c4eb..8ba9e776b9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -24,9 +24,9 @@ from typing_extensions import TypeIs from narwhals._compliant import CompliantExpr + from narwhals._compliant import CompliantExprT from narwhals._compliant import CompliantFrameT from narwhals._compliant import CompliantNamespace - from narwhals._compliant import CompliantSeriesOrNativeExprT_co from narwhals.expr import Expr from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame @@ -91,11 +91,8 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: def extract_compliant( - plx: CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co], - other: Any, - *, - str_as_lit: bool, -) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | object: + plx: CompliantNamespace[Any, CompliantExprT], other: Any, *, str_as_lit: bool +) -> CompliantExprT | object: if is_expr(other): return other._to_compliant_expr(plx) if isinstance(other, str) and not str_as_lit: diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index e9902e6f8b..33f2543e28 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -24,10 +22,7 @@ from narwhals._pandas_like.utils import extract_dataframe_comparand from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import vertical_concat -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names from narwhals.utils import import_dtypes_module -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from typing_extensions import Self @@ -75,26 +70,6 @@ def _create_compliant_series(self: Self, value: Any) -> PandasLikeSeries: ) # --- selection --- - def col(self: Self, *column_names: str) -> PandasLikeExpr: - return self._expr.from_column_names( - passthrough_column_names(column_names), function_name="col", context=self - ) - - def exclude(self: Self, excluded_names: Container[str]) -> PandasLikeExpr: - return self._expr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - context=self, - ) - - def nth(self: Self, *column_indices: int) -> PandasLikeExpr: - return self._expr.from_column_indices(*column_indices, context=self) - - def all(self: Self) -> PandasLikeExpr: - return self._expr.from_column_names( - get_column_names, function_name="all", context=self - ) - def lit(self: Self, value: Any, dtype: DType | None) -> PandasLikeExpr: def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: pandas_series = self._series._from_iterable( diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 7f6a07de83..f7190fc060 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -35,6 +35,7 @@ from narwhals._spark_like.typing import WindowFunction from narwhals.dtypes import DType from narwhals.utils import Version + from narwhals.utils import _FullContext class SparkLikeExpr(LazyExpr["SparkLikeLazyFrame", "Column"]): @@ -129,9 +130,7 @@ def from_column_names( /, *, function_name: str, - implementation: Implementation, - backend_version: tuple[int, ...], - version: Version, + context: _FullContext, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.col(col_name) for col_name in evaluate_column_names(df)] @@ -141,18 +140,14 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name=function_name, evaluate_output_names=evaluate_column_names, alias_output_names=None, - backend_version=backend_version, - version=version, - implementation=implementation, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, ) @classmethod def from_column_indices( - cls: type[Self], - *column_indices: int, - backend_version: tuple[int, ...], - version: Version, - implementation: Implementation, + cls: type[Self], *column_indices: int, context: _FullContext ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: columns = df.columns @@ -163,9 +158,9 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: function_name="nth", evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], alias_output_names=None, - backend_version=backend_version, - version=version, - implementation=implementation, + backend_version=context._backend_version, + version=context._version, + implementation=context._implementation, ) def _from_call( diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index ed414e196c..6988ba9cc3 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -1,12 +1,10 @@ from __future__ import annotations import operator -from functools import partial from functools import reduce from typing import TYPE_CHECKING from typing import Any from typing import Callable -from typing import Container from typing import Iterable from typing import Literal from typing import Sequence @@ -19,9 +17,6 @@ from narwhals._spark_like.selectors import SparkLikeSelectorNamespace from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype -from narwhals.utils import exclude_column_names -from narwhals.utils import get_column_names -from narwhals.utils import passthrough_column_names if TYPE_CHECKING: from sqlframe.base.column import Column @@ -32,7 +27,7 @@ from narwhals.utils import Version -class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]): +class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "SparkLikeExpr"]): def __init__( self: Self, *, @@ -48,40 +43,9 @@ def __init__( def selectors(self: Self) -> SparkLikeSelectorNamespace: return SparkLikeSelectorNamespace(self) - def all(self: Self) -> SparkLikeExpr: - return SparkLikeExpr.from_column_names( - get_column_names, - function_name="all", - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - - def col(self: Self, *column_names: str) -> SparkLikeExpr: - return SparkLikeExpr.from_column_names( - passthrough_column_names(column_names), - function_name="col", - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - - def exclude(self: Self, excluded_names: Container[str]) -> SparkLikeExpr: - return SparkLikeExpr.from_column_names( - partial(exclude_column_names, names=excluded_names), - function_name="exclude", - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ) - - def nth(self: Self, *column_indices: int) -> SparkLikeExpr: - return SparkLikeExpr.from_column_indices( - *column_indices, - backend_version=self._backend_version, - version=self._version, - implementation=self._implementation, - ) + @property + def _expr(self) -> type[SparkLikeExpr]: + return SparkLikeExpr def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: def _lit(df: SparkLikeLazyFrame) -> list[Column]: @@ -94,7 +58,7 @@ def _lit(df: SparkLikeLazyFrame) -> list[Column]: return [column] - return SparkLikeExpr( + return self._expr( call=_lit, function_name="lit", evaluate_output_names=lambda _df: ["literal"], @@ -108,7 +72,7 @@ def len(self: Self) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: return [df._F.count("*")] - return SparkLikeExpr( + return self._expr( func, function_name="len", evaluate_output_names=lambda _df: ["len"], @@ -123,7 +87,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.and_, cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="all_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -138,7 +102,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [reduce(operator.or_, cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="any_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -155,7 +119,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) return [reduce(operator.add, cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="sum_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -184,7 +148,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) ] - return SparkLikeExpr( + return self._expr( call=func, function_name="mean_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -199,7 +163,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [df._F.greatest(*cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="max_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -214,7 +178,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = (c for _expr in exprs for c in _expr(df)) return [df._F.least(*cols)] - return SparkLikeExpr( + return self._expr( call=func, function_name="min_horizontal", evaluate_output_names=combine_evaluate_output_names(*exprs), @@ -309,7 +273,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: return [result] - return SparkLikeExpr( + return self._expr( call=func, function_name="concat_str", evaluate_output_names=combine_evaluate_output_names(*exprs),