Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 1 addition & 26 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import operator
from functools import partial
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Container
from typing import Iterable
from typing import Literal
from typing import Sequence
Expand All @@ -29,10 +27,7 @@
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.utils import Implementation
from narwhals.utils import exclude_column_names
from narwhals.utils import get_column_names
from narwhals.utils import import_dtypes_module
from narwhals.utils import passthrough_column_names

if TYPE_CHECKING:
from typing import Callable
Expand All @@ -41,7 +36,6 @@
from typing_extensions import TypeAlias

from narwhals._arrow.typing import Incomplete
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.utils import Version

Expand Down Expand Up @@ -69,20 +63,6 @@ def __init__(
self._version = version

# --- selection ---
def col(self: Self, *column_names: str) -> ArrowExpr:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)

def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)

def nth(self: Self, *column_indices: int) -> ArrowExpr:
return self._expr.from_column_indices(*column_indices, context=self)

def len(self: Self) -> ArrowExpr:
# coverage bug? this is definitely hit
Expand All @@ -100,11 +80,6 @@ def len(self: Self) -> ArrowExpr:
version=self._version,
)

def all(self: Self) -> ArrowExpr:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)

def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
arrow_series = ArrowSeries._from_iterable(
Expand Down Expand Up @@ -167,7 +142,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
context=self,
)

def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr:
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
Copy link
Member Author

@dangotbanned dangotbanned Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only got flagged after filling out the protocol.

dtypes = import_dtypes_module(self._version)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand Down
2 changes: 2 additions & 0 deletions narwhals/_compliant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from narwhals._compliant.selectors import LazySelectorNamespace
from narwhals._compliant.series import CompliantSeries
from narwhals._compliant.series import EagerSeries
from narwhals._compliant.typing import CompliantExprT
from narwhals._compliant.typing import CompliantFrameT
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
from narwhals._compliant.typing import CompliantSeriesT_co
Expand All @@ -26,6 +27,7 @@
__all__ = [
"CompliantDataFrame",
"CompliantExpr",
"CompliantExprT",
"CompliantFrameT",
"CompliantLazyFrame",
"CompliantNamespace",
Expand Down
30 changes: 13 additions & 17 deletions narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,19 @@ def __call__(
def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(
self,
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
) -> CompliantNamespace[CompliantFrameT, Self]: ...
@classmethod
def from_column_names(
cls,
evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]],
/,
*,
function_name: str,
context: _FullContext,
) -> Self: ...
@classmethod
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ...

def is_null(self) -> Self: ...
def abs(self) -> Self: ...
def all(self) -> Self: ...
Expand Down Expand Up @@ -330,22 +342,6 @@ def _from_series(cls, series: EagerSeriesT) -> Self:
version=series._version,
)

@classmethod
def from_column_names(
cls,
evaluate_column_names: Callable[[EagerDataFrameT], Sequence[str]],
/,
*,
function_name: str,
context: _FullContext,
) -> Self: ...
@classmethod
def from_column_indices(
cls,
*column_indices: int,
context: _FullContext,
) -> Self: ...

def _reuse_series(
self: Self,
method_name: str,
Expand Down
72 changes: 59 additions & 13 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,90 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Container
from typing import Iterable
from typing import Literal
from typing import Protocol

from narwhals._compliant.typing import CompliantExprT
from narwhals._compliant.typing import CompliantFrameT
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
from narwhals._compliant.typing import EagerDataFrameT
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT_co
from narwhals.utils import deprecated
from narwhals.utils import exclude_column_names
from narwhals.utils import get_column_names
from narwhals.utils import passthrough_column_names

if TYPE_CHECKING:
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals.dtypes import DType
from narwhals.utils import Implementation
from narwhals.utils import Version

__all__ = ["CompliantNamespace", "EagerNamespace"]


class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
def col(
self, *column_names: str
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
def lit(
self, value: Any, dtype: DType | None
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version

def all(self) -> CompliantExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)

def col(self, *column_names: str) -> CompliantExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)

def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)

def nth(self, *column_indices: int) -> CompliantExprT:
return self._expr.from_column_indices(*column_indices, context=self)

def len(self) -> CompliantExprT: ...
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def concat(
self,
items: Iterable[CompliantFrameT],
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> CompliantFrameT: ...
def when(self, predicate: CompliantExprT) -> Any: ...
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Return type would require another generic protocol for When
  • Doing that is a pretty low priority for now

def concat_str(
self,
*exprs: CompliantExprT,
separator: str,
ignore_nulls: bool,
) -> CompliantExprT: ...
@property
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
@property
def _expr(self) -> type[CompliantExprT]: ...


class EagerNamespace(
CompliantNamespace[EagerDataFrameT, EagerSeriesT_co],
CompliantNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
):
@property
def _expr(self) -> type[EagerExprT]: ...
@property
def _series(self) -> type[EagerSeriesT_co]: ...
def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ...

@deprecated(
"Internally used for `numpy.ndarray` -> `CompliantSeries`\n"
Expand Down
1 change: 1 addition & 0 deletions narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]")
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame")
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]")

EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]")
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")
Expand Down
17 changes: 7 additions & 10 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from narwhals._dask.namespace import DaskNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.utils import _FullContext


class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]):
Expand Down Expand Up @@ -100,8 +101,7 @@ def from_column_names(
/,
*,
function_name: str,
backend_version: tuple[int, ...],
version: Version,
context: _FullContext,
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
Expand All @@ -124,16 +124,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
backend_version=backend_version,
version=version,
backend_version=context._backend_version,
version=context._version,
)

@classmethod
def from_column_indices(
cls: type[Self],
*column_indices: int,
backend_version: tuple[int, ...],
version: Version,
cls: type[Self], *column_indices: int, context: _FullContext
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [
Expand All @@ -146,8 +143,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
function_name="nth",
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
alias_output_names=None,
backend_version=backend_version,
version=version,
backend_version=context._backend_version,
version=context._version,
)

def _from_call(
Expand Down
Loading
Loading