Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
from typing_extensions import TypeAlias

from narwhals._arrow.typing import Incomplete
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.utils import Version

Expand Down Expand Up @@ -166,7 +165,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
context=self,
)

def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr:
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
Copy link
Member Author

@dangotbanned dangotbanned Mar 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only got flagged after filling out the protocol.

dtypes = import_dtypes_module(self._version)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(
def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(
self,
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
) -> CompliantNamespace[CompliantFrameT, Self]: ...
def is_null(self) -> Self: ...
def abs(self) -> Self: ...
def all(self) -> Self: ...
Expand Down
46 changes: 33 additions & 13 deletions narwhals/_compliant/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,63 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Container
from typing import Iterable
from typing import Literal
from typing import Protocol

from narwhals._compliant.typing import CompliantExprT
from narwhals._compliant.typing import CompliantFrameT
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
from narwhals._compliant.typing import EagerDataFrameT
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT_co
from narwhals.utils import deprecated

if TYPE_CHECKING:
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals.dtypes import DType

__all__ = ["CompliantNamespace", "EagerNamespace"]


class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
def col(
self, *column_names: str
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
def lit(
self, value: Any, dtype: DType | None
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
def col(self, *column_names: str) -> CompliantExprT: ...
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ...
def nth(self, *column_indices: int) -> CompliantExprT: ...
def len(self) -> CompliantExprT: ...
def all(self) -> CompliantExprT: ...
def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def concat(
self,
items: Iterable[CompliantFrameT],
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> CompliantFrameT: ...
def when(self, predicate: CompliantExprT) -> Any: ...
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Return type would require another generic protocol for When
  • Doing that is a pretty low priority for now

def concat_str(
self,
*exprs: CompliantExprT,
separator: str,
ignore_nulls: bool,
) -> CompliantExprT: ...
@property
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
@property
def _expr(self) -> type[CompliantExprT]: ...


class EagerNamespace(
CompliantNamespace[EagerDataFrameT, EagerSeriesT_co],
CompliantNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
):
@property
def _expr(self) -> type[EagerExprT]: ...
@property
def _series(self) -> type[EagerSeriesT_co]: ...
def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ...

@deprecated(
"Internally used for `numpy.ndarray` -> `CompliantSeries`\n"
Expand Down
1 change: 1 addition & 0 deletions narwhals/_compliant/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]")
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame")
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]")

EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]")
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")
Expand Down
6 changes: 5 additions & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
import dask_expr as dx


class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]):
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "DaskExpr"]):
_implementation: Implementation = Implementation.DASK

@property
def selectors(self: Self) -> DaskSelectorNamespace:
return DaskSelectorNamespace(self)

@property
def _expr(self) -> type[DaskExpr]:
return DaskExpr

def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
) -> None:
Expand Down
10 changes: 8 additions & 2 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any
from typing import Callable
from typing import Container
from typing import Iterable
from typing import Literal
from typing import Sequence

Expand Down Expand Up @@ -38,7 +39,7 @@
from narwhals.utils import Version


class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]):
class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "DuckDBExpr"]):
_implementation: Implementation = Implementation.DUCKDB

def __init__(
Expand All @@ -51,6 +52,10 @@ def __init__(
def selectors(self: Self) -> DuckDBSelectorNamespace:
return DuckDBSelectorNamespace(self)

@property
def _expr(self) -> type[DuckDBExpr]:
return DuckDBExpr

def all(self: Self) -> DuckDBExpr:
return DuckDBExpr.from_column_names(
get_column_names,
Expand All @@ -61,7 +66,7 @@ def all(self: Self) -> DuckDBExpr:

def concat(
self: Self,
items: Sequence[DuckDBLazyFrame],
items: Iterable[DuckDBLazyFrame],
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> DuckDBLazyFrame:
Expand All @@ -71,6 +76,7 @@ def concat(
if how == "diagonal":
msg = "Not implemented yet"
raise NotImplementedError(msg)
items = list(items)
first = items[0]
schema = first.schema
if how == "vertical" and not all(x.schema == schema for x in items[1:]):
Expand Down
5 changes: 2 additions & 3 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from narwhals._compliant import CompliantExpr
from narwhals._compliant import CompliantFrameT
from narwhals._compliant import CompliantNamespace
from narwhals._compliant import CompliantSeriesOrNativeExprT_co
from narwhals.expr import Expr
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantLazyFrame
Expand Down Expand Up @@ -91,11 +90,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:


def extract_compliant(
plx: CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co],
plx: CompliantNamespace[CompliantFrameT, Any],
other: Any,
*,
str_as_lit: bool,
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | object:
) -> CompliantExpr[CompliantFrameT, Any] | object:
if is_expr(other):
return other._to_compliant_expr(plx)
if isinstance(other, str) and not str_as_lit:
Expand Down
6 changes: 5 additions & 1 deletion narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from narwhals.utils import Version


class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]):
class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "SparkLikeExpr"]):
def __init__(
self: Self,
*,
Expand All @@ -48,6 +48,10 @@ def __init__(
def selectors(self: Self) -> SparkLikeSelectorNamespace:
return SparkLikeSelectorNamespace(self)

@property
def _expr(self) -> type[SparkLikeExpr]:
return SparkLikeExpr

def all(self: Self) -> SparkLikeExpr:
return SparkLikeExpr.from_column_names(
get_column_names,
Copy link
Member Author

@dangotbanned dangotbanned Mar 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oooh we might be able to implement this (and maybe a few other methods) higher up in the protocol 🤯

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from narwhals.utils import get_column_names

class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
    """Sketch of a namespace protocol carrying a default ``all`` implementation.

    Generic over a frame type (``CompliantFrameT``) and an expression type
    (``CompliantExprT``).
    """

    # The expression class for this namespace; ``all`` below calls its
    # ``from_column_names`` constructor.
    @property
    def _expr(self) -> type[CompliantExprT]: ...

    def all(self) -> CompliantExprT:
        """Return an expression built via ``self._expr.from_column_names``.

        Passes ``get_column_names`` as the first argument, tags the expression
        with ``function_name="all"``, and forwards this namespace's
        ``_implementation``, ``_backend_version`` and ``_version``.
        """
        # NOTE(review): `_implementation`, `_backend_version` and `_version`
        # are not declared in this snippet; presumably supplied by the
        # concrete namespace. Confirm before relying on this default.
        return self._expr.from_column_names(
            get_column_names,
            function_name="all",
            implementation=self._implementation,
            backend_version=self._backend_version,
            version=self._version,
        )

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already most of the way there in

# --- selection ---
def col(self: Self, *column_names: str) -> PandasLikeExpr:
    """Create an expression from the given column names.

    Delegates to ``self._expr.from_column_names`` with
    ``function_name="col"`` and this namespace as ``context``.
    """
    make_expr = self._expr.from_column_names
    names_fn = passthrough_column_names(column_names)
    return make_expr(names_fn, function_name="col", context=self)
def exclude(self: Self, excluded_names: Container[str]) -> PandasLikeExpr:
    """Create an expression from ``exclude_column_names`` bound to the given names.

    ``excluded_names`` is bound as the ``names`` keyword of
    ``exclude_column_names``; the resulting callable is handed to
    ``self._expr.from_column_names`` with ``function_name="exclude"``.
    """
    make_expr = self._expr.from_column_names
    names_fn = partial(exclude_column_names, names=excluded_names)
    return make_expr(names_fn, function_name="exclude", context=self)
def nth(self: Self, *column_indices: int) -> PandasLikeExpr:
    """Create an expression from positional column indices.

    Forwards the indices to ``self._expr.from_column_indices`` with this
    namespace as ``context``.
    """
    expr_cls = self._expr
    return expr_cls.from_column_indices(*column_indices, context=self)
def all(self: Self) -> PandasLikeExpr:
    """Create an expression using ``get_column_names``.

    Delegates to ``self._expr.from_column_names`` with
    ``function_name="all"`` and this namespace as ``context``.
    """
    make_expr = self._expr.from_column_names
    return make_expr(get_column_names, function_name="all", context=self)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@FBruzzesi ignore this if you're not fully back yet 🙂

Thought you might like to see this new possibility - following the recent spec-ing of the Compliant* protocols

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went ahead with that in (5d609a7)

We now have default implementations for 4x CompliantNamespace methods 😁

Expand Down
Loading