Skip to content

Commit 4eb66a0

Browse files
committed
feat(typing): Fill out CompliantNamespace protocol
1 parent 7611bd4 commit 4eb66a0

File tree

8 files changed

+56
-20
lines changed

8 files changed

+56
-20
lines changed

narwhals/_arrow/namespace.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from typing_extensions import TypeAlias
4141

4242
from narwhals._arrow.typing import Incomplete
43-
from narwhals._arrow.typing import IntoArrowExpr
4443
from narwhals.dtypes import DType
4544
from narwhals.utils import Version
4645

@@ -166,7 +165,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
166165
context=self,
167166
)
168167

169-
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr:
168+
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
170169
dtypes = import_dtypes_module(self._version)
171170

172171
def func(df: ArrowDataFrame) -> list[ArrowSeries]:

narwhals/_compliant/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __call__(
8989
def __narwhals_expr__(self) -> None: ...
9090
def __narwhals_namespace__(
9191
self,
92-
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
92+
) -> CompliantNamespace[CompliantFrameT, Self]: ...
9393
def is_null(self) -> Self: ...
9494
def abs(self) -> Self: ...
9595
def all(self) -> Self: ...

narwhals/_compliant/namespace.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,59 @@
22

33
from typing import TYPE_CHECKING
44
from typing import Any
5+
from typing import Container
6+
from typing import Iterable
7+
from typing import Literal
58
from typing import Protocol
69

10+
from narwhals._compliant.typing import CompliantExprT
711
from narwhals._compliant.typing import CompliantFrameT
8-
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
912
from narwhals._compliant.typing import EagerDataFrameT
1013
from narwhals._compliant.typing import EagerExprT
1114
from narwhals._compliant.typing import EagerSeriesT_co
1215
from narwhals.utils import deprecated
1316

1417
if TYPE_CHECKING:
15-
from narwhals._compliant.expr import CompliantExpr
1618
from narwhals._compliant.selectors import CompliantSelectorNamespace
1719
from narwhals.dtypes import DType
1820

1921
__all__ = ["CompliantNamespace", "EagerNamespace"]
2022

2123

22-
class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
23-
def col(
24-
self, *column_names: str
25-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
26-
def lit(
27-
self, value: Any, dtype: DType | None
28-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
24+
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
25+
def col(self, *column_names: str) -> CompliantExprT: ...
26+
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
27+
def exclude(self, excluded_names: Container[str]) -> CompliantExprT: ...
28+
def nth(self, *column_indices: int) -> CompliantExprT: ...
29+
def len(self) -> CompliantExprT: ...
30+
def all(self) -> CompliantExprT: ...
31+
def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
32+
def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
33+
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
34+
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
35+
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
36+
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
37+
def concat(
38+
self,
39+
items: Iterable[CompliantFrameT],
40+
*,
41+
how: Literal["horizontal", "vertical", "diagonal"],
42+
) -> CompliantFrameT: ...
43+
def when(self, predicate: CompliantExprT) -> Any: ...
44+
def concat_str(
45+
self,
46+
*exprs: CompliantExprT,
47+
separator: str,
48+
ignore_nulls: bool,
49+
) -> CompliantExprT: ...
2950
@property
3051
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
52+
@property
53+
def _expr(self) -> type[CompliantExprT]: ...
3154

3255

3356
class EagerNamespace(
34-
CompliantNamespace[EagerDataFrameT, EagerSeriesT_co],
57+
CompliantNamespace[EagerDataFrameT, EagerExprT],
3558
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
3659
):
3760
@property

narwhals/_compliant/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]")
4343
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame")
4444
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
45+
CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]")
4546

4647
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]")
4748
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")

narwhals/_dask/namespace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,17 @@
4242
import dask_expr as dx
4343

4444

45-
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]):
45+
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "DaskExpr"]):
4646
_implementation: Implementation = Implementation.DASK
4747

4848
@property
4949
def selectors(self: Self) -> DaskSelectorNamespace:
5050
return DaskSelectorNamespace(self)
5151

52+
@property
53+
def _expr(self) -> type[DaskExpr]:
54+
return DaskExpr
55+
5256
def __init__(
5357
self: Self, *, backend_version: tuple[int, ...], version: Version
5458
) -> None:

narwhals/_duckdb/namespace.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any
88
from typing import Callable
99
from typing import Container
10+
from typing import Iterable
1011
from typing import Literal
1112
from typing import Sequence
1213

@@ -38,7 +39,7 @@
3839
from narwhals.utils import Version
3940

4041

41-
class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]):
42+
class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "DuckDBExpr"]):
4243
_implementation: Implementation = Implementation.DUCKDB
4344

4445
def __init__(
@@ -51,6 +52,10 @@ def __init__(
5152
def selectors(self: Self) -> DuckDBSelectorNamespace:
5253
return DuckDBSelectorNamespace(self)
5354

55+
@property
56+
def _expr(self) -> type[DuckDBExpr]:
57+
return DuckDBExpr
58+
5459
def all(self: Self) -> DuckDBExpr:
5560
return DuckDBExpr.from_column_names(
5661
get_column_names,
@@ -61,7 +66,7 @@ def all(self: Self) -> DuckDBExpr:
6166

6267
def concat(
6368
self: Self,
64-
items: Sequence[DuckDBLazyFrame],
69+
items: Iterable[DuckDBLazyFrame],
6570
*,
6671
how: Literal["horizontal", "vertical", "diagonal"],
6772
) -> DuckDBLazyFrame:
@@ -71,6 +76,7 @@ def concat(
7176
if how == "diagonal":
7277
msg = "Not implemented yet"
7378
raise NotImplementedError(msg)
79+
items = list(items)
7480
first = items[0]
7581
schema = first.schema
7682
if how == "vertical" and not all(x.schema == schema for x in items[1:]):

narwhals/_expression_parsing.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from narwhals._compliant import CompliantExpr
2727
from narwhals._compliant import CompliantFrameT
2828
from narwhals._compliant import CompliantNamespace
29-
from narwhals._compliant import CompliantSeriesOrNativeExprT_co
3029
from narwhals.expr import Expr
3130
from narwhals.typing import CompliantDataFrame
3231
from narwhals.typing import CompliantLazyFrame
@@ -91,11 +90,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:
9190

9291

9392
def extract_compliant(
94-
plx: CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co],
93+
plx: CompliantNamespace[CompliantFrameT, Any],
9594
other: Any,
9695
*,
9796
str_as_lit: bool,
98-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | object:
97+
) -> CompliantExpr[CompliantFrameT, Any] | object:
9998
if is_expr(other):
10099
return other._to_compliant_expr(plx)
101100
if isinstance(other, str) and not str_as_lit:

narwhals/_spark_like/namespace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from narwhals.utils import Version
3333

3434

35-
class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]):
35+
class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "SparkLikeExpr"]):
3636
def __init__(
3737
self: Self,
3838
*,
@@ -48,6 +48,10 @@ def __init__(
4848
def selectors(self: Self) -> SparkLikeSelectorNamespace:
4949
return SparkLikeSelectorNamespace(self)
5050

51+
@property
52+
def _expr(self) -> type[SparkLikeExpr]:
53+
return SparkLikeExpr
54+
5155
def all(self: Self) -> SparkLikeExpr:
5256
return SparkLikeExpr.from_column_names(
5357
get_column_names,

0 commit comments

Comments
 (0)