Skip to content

Commit 6a5ed1d

Browse files
committed
Merge remote-tracking branch 'upstream/main' into series-from-numpy
2 parents adb6b7a + 14c0c78 commit 6a5ed1d

File tree

28 files changed

+248
-337
lines changed

28 files changed

+248
-337
lines changed

docs/api-reference/lazyframe.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
handler: python
55
options:
66
members:
7-
- clone
87
- collect
98
- collect_schema
109
- columns

narwhals/_arrow/namespace.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
from __future__ import annotations
22

33
import operator
4-
from functools import partial
54
from functools import reduce
65
from itertools import chain
76
from typing import TYPE_CHECKING
87
from typing import Any
98
from typing import Callable
10-
from typing import Container
119
from typing import Iterable
1210
from typing import Literal
1311
from typing import Sequence
@@ -29,10 +27,7 @@
2927
from narwhals._expression_parsing import combine_alias_output_names
3028
from narwhals._expression_parsing import combine_evaluate_output_names
3129
from narwhals.utils import Implementation
32-
from narwhals.utils import exclude_column_names
33-
from narwhals.utils import get_column_names
3430
from narwhals.utils import import_dtypes_module
35-
from narwhals.utils import passthrough_column_names
3631

3732
if TYPE_CHECKING:
3833
from typing import Callable
@@ -41,7 +36,6 @@
4136
from typing_extensions import TypeAlias
4237

4338
from narwhals._arrow.typing import Incomplete
44-
from narwhals._arrow.typing import IntoArrowExpr
4539
from narwhals.dtypes import DType
4640
from narwhals.utils import Version
4741

@@ -66,20 +60,6 @@ def __init__(
6660
self._version = version
6761

6862
# --- selection ---
69-
def col(self: Self, *column_names: str) -> ArrowExpr:
70-
return self._expr.from_column_names(
71-
passthrough_column_names(column_names), function_name="col", context=self
72-
)
73-
74-
def exclude(self: Self, excluded_names: Container[str]) -> ArrowExpr:
75-
return self._expr.from_column_names(
76-
partial(exclude_column_names, names=excluded_names),
77-
function_name="exclude",
78-
context=self,
79-
)
80-
81-
def nth(self: Self, *column_indices: int) -> ArrowExpr:
82-
return self._expr.from_column_indices(*column_indices, context=self)
8363

8464
def len(self: Self) -> ArrowExpr:
8565
# coverage bug? this is definitely hit
@@ -97,11 +77,6 @@ def len(self: Self) -> ArrowExpr:
9777
version=self._version,
9878
)
9979

100-
def all(self: Self) -> ArrowExpr:
101-
return self._expr.from_column_names(
102-
get_column_names, function_name="all", context=self
103-
)
104-
10580
def lit(self: Self, value: Any, dtype: DType | None) -> ArrowExpr:
10681
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
10782
arrow_series = ArrowSeries._from_iterable(
@@ -164,7 +139,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
164139
context=self,
165140
)
166141

167-
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> IntoArrowExpr:
142+
def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
168143
dtypes = import_dtypes_module(self._version)
169144

170145
def func(df: ArrowDataFrame) -> list[ArrowSeries]:

narwhals/_compliant/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from narwhals._compliant.selectors import LazySelectorNamespace
1717
from narwhals._compliant.series import CompliantSeries
1818
from narwhals._compliant.series import EagerSeries
19+
from narwhals._compliant.typing import CompliantExprT
1920
from narwhals._compliant.typing import CompliantFrameT
2021
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
2122
from narwhals._compliant.typing import CompliantSeriesT_co
@@ -26,6 +27,7 @@
2627
__all__ = [
2728
"CompliantDataFrame",
2829
"CompliantExpr",
30+
"CompliantExprT",
2931
"CompliantFrameT",
3032
"CompliantLazyFrame",
3133
"CompliantNamespace",

narwhals/_compliant/expr.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,19 @@ def __call__(
9191
def __narwhals_expr__(self) -> None: ...
9292
def __narwhals_namespace__(
9393
self,
94-
) -> CompliantNamespace[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
94+
) -> CompliantNamespace[CompliantFrameT, Self]: ...
95+
@classmethod
96+
def from_column_names(
97+
cls,
98+
evaluate_column_names: Callable[[CompliantFrameT], Sequence[str]],
99+
/,
100+
*,
101+
function_name: str,
102+
context: _FullContext,
103+
) -> Self: ...
104+
@classmethod
105+
def from_column_indices(cls, *column_indices: int, context: _FullContext) -> Self: ...
106+
95107
def is_null(self) -> Self: ...
96108
def abs(self) -> Self: ...
97109
def all(self) -> Self: ...
@@ -332,22 +344,6 @@ def _from_series(cls, series: EagerSeriesT) -> Self:
332344
version=series._version,
333345
)
334346

335-
@classmethod
336-
def from_column_names(
337-
cls,
338-
evaluate_column_names: Callable[[EagerDataFrameT], Sequence[str]],
339-
/,
340-
*,
341-
function_name: str,
342-
context: _FullContext,
343-
) -> Self: ...
344-
@classmethod
345-
def from_column_indices(
346-
cls,
347-
*column_indices: int,
348-
context: _FullContext,
349-
) -> Self: ...
350-
351347
def _reuse_series(
352348
self: Self,
353349
method_name: str,

narwhals/_compliant/namespace.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from __future__ import annotations
22

3+
from functools import partial
34
from typing import TYPE_CHECKING
45
from typing import Any
6+
from typing import Container
7+
from typing import Iterable
8+
from typing import Literal
59
from typing import Protocol
610

11+
from narwhals._compliant.typing import CompliantExprT
712
from narwhals._compliant.typing import CompliantFrameT
8-
from narwhals._compliant.typing import CompliantSeriesOrNativeExprT_co
913
from narwhals._compliant.typing import EagerDataFrameT
1014
from narwhals._compliant.typing import EagerExprT
1115
from narwhals._compliant.typing import EagerSeriesT_co
16+
from narwhals.utils import exclude_column_names
17+
from narwhals.utils import get_column_names
18+
from narwhals.utils import passthrough_column_names
1219

1320
if TYPE_CHECKING:
14-
from narwhals._compliant.expr import CompliantExpr
1521
from narwhals._compliant.selectors import CompliantSelectorNamespace
1622
from narwhals.dtypes import DType
1723
from narwhals.utils import Implementation
@@ -20,27 +26,61 @@
2026
__all__ = ["CompliantNamespace", "EagerNamespace"]
2127

2228

23-
class CompliantNamespace(Protocol[CompliantFrameT, CompliantSeriesOrNativeExprT_co]):
24-
def col(
25-
self, *column_names: str
26-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
27-
def lit(
28-
self, value: Any, dtype: DType | None
29-
) -> CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co]: ...
29+
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
30+
_implementation: Implementation
31+
_backend_version: tuple[int, ...]
32+
_version: Version
33+
34+
def all(self) -> CompliantExprT:
35+
return self._expr.from_column_names(
36+
get_column_names, function_name="all", context=self
37+
)
38+
39+
def col(self, *column_names: str) -> CompliantExprT:
40+
return self._expr.from_column_names(
41+
passthrough_column_names(column_names), function_name="col", context=self
42+
)
43+
44+
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
45+
return self._expr.from_column_names(
46+
partial(exclude_column_names, names=excluded_names),
47+
function_name="exclude",
48+
context=self,
49+
)
50+
51+
def nth(self, *column_indices: int) -> CompliantExprT:
52+
return self._expr.from_column_indices(*column_indices, context=self)
53+
54+
def len(self) -> CompliantExprT: ...
55+
def lit(self, value: Any, dtype: DType | None) -> CompliantExprT: ...
56+
def all_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
57+
def any_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
58+
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
59+
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
60+
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
61+
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
62+
def concat(
63+
self,
64+
items: Iterable[CompliantFrameT],
65+
*,
66+
how: Literal["horizontal", "vertical", "diagonal"],
67+
) -> CompliantFrameT: ...
68+
def when(self, predicate: CompliantExprT) -> Any: ...
69+
def concat_str(
70+
self,
71+
*exprs: CompliantExprT,
72+
separator: str,
73+
ignore_nulls: bool,
74+
) -> CompliantExprT: ...
3075
@property
3176
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
77+
@property
78+
def _expr(self) -> type[CompliantExprT]: ...
3279

3380

3481
class EagerNamespace(
35-
CompliantNamespace[EagerDataFrameT, EagerSeriesT_co],
82+
CompliantNamespace[EagerDataFrameT, EagerExprT],
3683
Protocol[EagerDataFrameT, EagerSeriesT_co, EagerExprT],
3784
):
38-
_implementation: Implementation
39-
_backend_version: tuple[int, ...]
40-
_version: Version
41-
42-
@property
43-
def _expr(self) -> type[EagerExprT]: ...
4485
@property
4586
def _series(self) -> type[EagerSeriesT_co]: ...
46-
def all_horizontal(self, *exprs: EagerExprT) -> EagerExprT: ...

narwhals/_compliant/typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound="CompliantDataFrame[Any]")
4444
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound="CompliantLazyFrame")
4545
IntoCompliantExpr: TypeAlias = "CompliantExpr[CompliantFrameT, CompliantSeriesOrNativeExprT_co] | CompliantSeriesOrNativeExprT_co"
46+
CompliantExprT = TypeVar("CompliantExprT", bound="CompliantExpr[Any, Any]")
4647

4748
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any]")
4849
EagerSeriesT = TypeVar("EagerSeriesT", bound="EagerSeries[Any]")

narwhals/_dask/dataframe.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ def collect(
9494
backend: Implementation | None,
9595
**kwargs: Any,
9696
) -> CompliantDataFrame[Any]:
97-
import pandas as pd
98-
9997
result = self._native_frame.compute(**kwargs)
10098

10199
if backend is None or backend is Implementation.PANDAS:
@@ -162,15 +160,6 @@ def aggregate(self: Self, *exprs: DaskExpr) -> Self:
162160

163161
def select(self: Self, *exprs: DaskExpr) -> Self:
164162
new_series = evaluate_exprs(self, *exprs)
165-
166-
if not new_series:
167-
# return empty dataframe, like Polars does
168-
return self._from_native_frame(
169-
dd.from_pandas(
170-
pd.DataFrame(), npartitions=self._native_frame.npartitions
171-
),
172-
)
173-
174163
df = select_columns_by_name(
175164
self._native_frame.assign(**dict(new_series)),
176165
[s[0] for s in new_series],

narwhals/_dask/expr.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from narwhals._dask.namespace import DaskNamespace
4040
from narwhals.dtypes import DType
4141
from narwhals.utils import Version
42+
from narwhals.utils import _FullContext
4243

4344

4445
class DaskExpr(LazyExpr["DaskLazyFrame", "dx.Series"]):
@@ -100,8 +101,7 @@ def from_column_names(
100101
/,
101102
*,
102103
function_name: str,
103-
backend_version: tuple[int, ...],
104-
version: Version,
104+
context: _FullContext,
105105
) -> Self:
106106
def func(df: DaskLazyFrame) -> list[dx.Series]:
107107
try:
@@ -124,16 +124,13 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
124124
function_name=function_name,
125125
evaluate_output_names=evaluate_column_names,
126126
alias_output_names=None,
127-
backend_version=backend_version,
128-
version=version,
127+
backend_version=context._backend_version,
128+
version=context._version,
129129
)
130130

131131
@classmethod
132132
def from_column_indices(
133-
cls: type[Self],
134-
*column_indices: int,
135-
backend_version: tuple[int, ...],
136-
version: Version,
133+
cls: type[Self], *column_indices: int, context: _FullContext
137134
) -> Self:
138135
def func(df: DaskLazyFrame) -> list[dx.Series]:
139136
return [
@@ -146,8 +143,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
146143
function_name="nth",
147144
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
148145
alias_output_names=None,
149-
backend_version=backend_version,
150-
version=version,
146+
backend_version=context._backend_version,
147+
version=context._version,
151148
)
152149

153150
def _from_call(

0 commit comments

Comments
 (0)