Skip to content

Commit 14c0c78

Browse files
authored
chore(typing): Fill out CompliantNamespace protocol (#2202)
1 parent 7865972 commit 14c0c78

File tree

13 files changed

+144
-257
lines changed

13 files changed

+144
-257
lines changed

narwhals/_arrow/namespace.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from __future__ import annotations
22

33
import operator
4-
from functools import partial
54
from functools import reduce
65
from itertools import chain
76
from typing import TYPE_CHECKING
87
from typing import Any
98
from typing import Callable
10-
from typing import Container
119
from typing import Iterable
1210
from typing import Literal
1311
from typing import Sequence
@@ -29,10 +27,7 @@
2927
from narwhals._expression_parsing import combine_alias_output_names
3028
from narwhals._expression_parsing import combine_evaluate_output_names
3129
from narwhals.utils import Implementation
32-
from narwhals.utils import exclude_column_names
33-
from narwhals.utils import get_column_names
3430
from narwhals.utils import import_dtypes_module
35-
from narwhals.utils import passthrough_column_names
3631

3732
if TYPE_CHECKING:
3833
from typing import Callable
@@ -41,7 +36,6 @@
4136
from typing_extensions import TypeAlias
4237

4338
from narwhals._arrow.typing import Incomplete
44-
from narwhals._arrow.typing import IntoArrowExpr
4539
from narwhals.dtypes import DType
4640
from narwhals.utils import Version
4741

@@ -69,20 +63,6 @@ def __init__(
6963
self._version = version
7064

7165
# --- selection ---
72-
def col(self: Self, *column_names: str) -> ArrowExpr:
73-
return self._expr.from_column_names(
74-
passthrough_column_names(column_names), function_name="col", context=self
75-
)
76-
77-
def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr:
78-
return self._expr.from_column_names(
79-
partial(exclude_column_names, names=excluded_names),
80-
function_name="exclude",
81-
context=self,
82-
)
83-
84-
def nth(self: Self, *column_indices: int) -> ArrowExpr:
85-
return self._expr.from_column_indices(*column_indices, context=self)
8666

8767
def len(self: Self) -> ArrowExpr:
8868
# coverage bug? this is definitely hit
@@ -100,11 +80,6 @@ def len(self: Self) -> ArrowExpr:
10080
version=self._version,
10181
)
10282

103-
def all(self: Self) -> ArrowExpr:
104-
return self._expr.from_column_names(
105-
get_column_names, function_name="all", context=self
106-
)
107-
10883
def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
10984
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
11085
arrow_series = ArrowSeries._from_iterable(
@@ -167,7 +142,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
167142
context=self,
168143
)
169144

170-
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr:
145+
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
171146
dtypes = import_dtypes_module(self._version)
172147

173148
def func(df: ArrowDataFrame) -> list[ArrowSeries]:

narwhals/_compliant/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from narwhals._compliant.selectors import LazySelectorNamespace
1717
from narwhals._compliant.series import CompliantSeries
1818
from narwhals._compliant.series import EagerSeries
19+
from narwhals._compliant.typing import CompliantExprT
1920
from narwhals._compliant.typing import CompliantFrameT
2021
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
2122
from narwhals._compliant.typing import CompliantSeriesT_co
@@ -26,6 +27,7 @@
2627
__all__ = [
2728
"CompliantDataFrame",
2829
"CompliantExpr",
30+
"CompliantExprT",
2931
"CompliantFrameT",
3032
"CompliantLazyFrame",
3133
"CompliantNamespace",

narwhals/_compliant/expr.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,19 @@ def __call__(
8989
def __narwhals_expr__(self) -> None: ...
9090
def __narwhals_namespace__(
9191
self,
92-
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
92+
) -> CompliantNamespace[CompliantFrameT, Self]: ...
93+
@classmethod
94+
def from_column_names(
95+
cls,
96+
evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]],
97+
/,
98+
*,
99+
function_name: str,
100+
context: _FullContext,
101+
) -> Self: ...
102+
@classmethod
103+
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ...
104+
93105
def is_null(self) -> Self: ...
94106
def abs(self) -> Self: ...
95107
def all(self) -> Self: ...
@@ -330,22 +342,6 @@ def _from_series(cls, series: EagerSeriesT) -> Self:
330342
version=series._version,
331343
)
332344

333-
@classmethod
334-
def from_column_names(
335-
cls,
336-
evaluate_column_names: Callable[[EagerDataFrameT], Sequence[str]],
337-
/,
338-
*,
339-
function_name: str,
340-
context: _FullContext,
341-
) -> Self: ...
342-
@classmethod
343-
def from_column_indices(
344-
cls,
345-
*column_indices: int,
346-
context: _FullContext,
347-
) -> Self: ...
348-
349345
def _reuse_series(
350346
self: Self,
351347
method_name: str,

narwhals/_compliant/namespace.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,90 @@
11
from __future__ import annotations
22

3+
from functools import partial
34
from typing import TYPE_CHECKING
45
from typing import Any
6+
from typing import Container
7+
from typing import Iterable
8+
from typing import Literal
59
from typing import Protocol
610

11+
from narwhals._compliant.typing import CompliantExprT
712
from narwhals._compliant.typing import CompliantFrameT
8-
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
913
from narwhals._compliant.typing import EagerDataFrameT
1014
from narwhals._compliant.typing import EagerExprT
1115
from narwhals._compliant.typing import EagerSeriesT_co
1216
from narwhals.utils import deprecated
17+
from narwhals.utils import exclude_column_names
18+
from narwhals.utils import get_column_names
19+
from narwhals.utils import passthrough_column_names
1320

1421
if TYPE_CHECKING:
15-
from narwhals._compliant.expr import CompliantExpr
1622
from narwhals._compliant.selectors import CompliantSelectorNamespace
1723
from narwhals.dtypes import DType
24+
from narwhals.utils import Implementation
25+
from narwhals.utils import Version
1826

1927
__all__ = ["CompliantNamespace", "EagerNamespace"]
2028

2129

22-
class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
23-
def col(
24-
self, *column_names: str
25-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
26-
def lit(
27-
self, value: Any, dtype: DType | None
28-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
30+
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
31+
_implementation: Implementation
32+
_backend_version: tuple[int, ...]
33+
_version: Version
34+
35+
def all(self) -> CompliantExprT:
36+
return self._expr.from_column_names(
37+
get_column_names, function_name="all", context=self
38+
)
39+
40+
def col(self, *column_names: str) -> CompliantExprT:
41+
return self._expr.from_column_names(
42+
passthrough_column_names(column_names), function_name="col", context=self
43+
)
44+
45+
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
46+
return self._expr.from_column_names(
47+
partial(exclude_column_names, names=excluded_names),
48+
function_name="exclude",
49+
context=self,
50+
)
51+
52+
def nth(self, *column_indices: int) -> CompliantExprT:
53+
return self._expr.from_column_indices(*column_indices, context=self)
54+
55+
def len(self) -> CompliantExprT: ...
56+
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
57+
def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
58+
def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
59+
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
60+
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
61+
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
62+
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
63+
def concat(
64+
self,
65+
items: Iterable[CompliantFrameT],
66+
*,
67+
how: Literal["horizontal", "vertical", "diagonal"],
68+
) -> CompliantFrameT: ...
69+
def when(self, predicate: CompliantExprT) -> Any: ...
70+
def concat_str(
71+
self,
72+
*exprs: CompliantExprT,
73+
separator: str,
74+
ignore_nulls: bool,
75+
) -> CompliantExprT: ...
2976
@property
3077
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
78+
@property
79+
def _expr(self) -> type[CompliantExprT]: ...
3180

3281

3382
class EagerNamespace(
34-
CompliantNamespace[EagerDataFrameT, EagerSeriesT_co],
83+
CompliantNamespace[EagerDataFrameT, EagerExprT],
3584
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
3685
):
37-
@property
38-
def _expr(self) -> type[EagerExprT]: ...
3986
@property
4087
def _series(self) -> type[EagerSeriesT_co]: ...
41-
def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ...
4288

4389
@deprecated(
4490
"Internally used for `numpy.ndarray` -> `CompliantSeries`\n"

narwhals/_compliant/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]")
4343
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame")
4444
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
45+
CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]")
4546

4647
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]")
4748
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")

narwhals/_dask/expr.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from narwhals._dask.namespace import DaskNamespace
4040
from narwhals.dtypes import DType
4141
from narwhals.utils import Version
42+
from narwhals.utils import _FullContext
4243

4344

4445
class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]):
@@ -100,8 +101,7 @@ def from_column_names(
100101
/,
101102
*,
102103
function_name: str,
103-
backend_version: tuple[int, ...],
104-
version: Version,
104+
context: _FullContext,
105105
) -> Self:
106106
def func(df: DaskLazyFrame) -> list[dx.Series]:
107107
try:
@@ -124,16 +124,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
124124
function_name=function_name,
125125
evaluate_output_names=evaluate_column_names,
126126
alias_output_names=None,
127-
backend_version=backend_version,
128-
version=version,
127+
backend_version=context._backend_version,
128+
version=context._version,
129129
)
130130

131131
@classmethod
132132
def from_column_indices(
133-
cls: type[Self],
134-
*column_indices: int,
135-
backend_version: tuple[int, ...],
136-
version: Version,
133+
cls: type[Self], *column_indices: int, context: _FullContext
137134
) -> Self:
138135
def func(df: DaskLazyFrame) -> list[dx.Series]:
139136
return [
@@ -146,8 +143,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
146143
function_name="nth",
147144
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
148145
alias_output_names=None,
149-
backend_version=backend_version,
150-
version=version,
146+
backend_version=context._backend_version,
147+
version=context._version,
151148
)
152149

153150
def _from_call(

0 commit comments

Comments
 (0)