Skip to content

Commit 24bdf86

Browse files
authored
fix(typing): Narrow TypeVar(s) used in (Data|Lazy)Frame (#2356)
* fix(typing): Narrow `IntoDataFrame` Will close #2344 * fix(typing): Remove `DataFrame` from `IntoFrame` * fix(typing): Narrow `IntoLazyFrame`, `IntoFrame` - Revealed two new `[overload-cannot-match]` from `mypy` - I agreed with that and removed the conflict sources Will close #2345 * fix(typing): Annotate `DataFrame._compliant_frame` Revealed quite a few other issues * chore: Add missing `CompliantDataFrame.pivot` + fix related quirks * fix(typing): Ensure `__iter__` is available on group_by * chore(typing): Fix most of `DataFrame` * chore(typing): Ignore interchange `[type-var]` * test(typing): Barely fix dodgy spark typing - I think this whole test needs rewriting - We shouldn't be depending on the internals like this * fix: Implement `to_numpy` to catch args https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.to_numpy.html * fix(typing): Annotate `LazyFrame._compliant_frame` * chore(typing): Ignore and add note for `spark_like` cast * chore(typing): Partial `v1` backport Spent waaaaaay too long trying to get this working * fix(typing): Just preserve `v1` behavior #2356 (comment) * simplify * try old `Union` #2356 (comment) * docs(typing): Provide more context on what and why Expanded on (#2356 (comment)) * chore(typing): Use `Sequence[str]` in `pivot` https://github.com/narwhals-dev/narwhals/pull/2356/files/5dd782522f23ed2aef3554a2aa89fc9903abd094#r2040702116 * refactor(typing): Use `PivotAgg` #2352
1 parent 78f27af commit 24bdf86

File tree

13 files changed

+228
-101
lines changed

13 files changed

+228
-101
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,3 +845,5 @@ def unpivot(
845845
)
846846
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
847847
# upcast numeric to non-numeric (e.g. string) datatypes
848+
849+
pivot = not_implemented()

narwhals/_compliant/dataframe.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@
3737
from typing_extensions import TypeAlias
3838

3939
from narwhals._compliant.group_by import CompliantGroupBy
40+
from narwhals._compliant.group_by import DataFrameGroupBy
4041
from narwhals._translate import IntoArrowTable
4142
from narwhals.dtypes import DType
4243
from narwhals.schema import Schema
4344
from narwhals.typing import AsofJoinStrategy
4445
from narwhals.typing import JoinStrategy
4546
from narwhals.typing import LazyUniqueKeepStrategy
47+
from narwhals.typing import PivotAgg
4648
from narwhals.typing import SizeUnit
4749
from narwhals.typing import UniqueKeepStrategy
4850
from narwhals.typing import _2DArray
@@ -136,7 +138,7 @@ def gather_every(self, n: int, offset: int) -> Self: ...
136138
def get_column(self, name: str) -> CompliantSeriesT: ...
137139
def group_by(
138140
self, *keys: str, drop_null_keys: bool
139-
) -> CompliantGroupBy[Self, Any]: ...
141+
) -> DataFrameGroupBy[Self, Any]: ...
140142
def head(self, n: int) -> Self: ...
141143
def item(self, row: int | None, column: int | str | None) -> Any: ...
142144
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
@@ -165,6 +167,16 @@ def join_asof(
165167
suffix: str,
166168
) -> Self: ...
167169
def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrame[Any, Any]: ...
170+
def pivot(
171+
self,
172+
on: Sequence[str],
173+
*,
174+
index: Sequence[str] | None,
175+
values: Sequence[str] | None,
176+
aggregate_function: PivotAgg | None,
177+
sort_columns: bool,
178+
separator: str,
179+
) -> Self: ...
168180
def rename(self, mapping: Mapping[str, str]) -> Self: ...
169181
def row(self, index: int) -> tuple[Any, ...]: ...
170182
def rows(

narwhals/_compliant/group_by.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def __init__(
7676
def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...
7777

7878

79+
class DataFrameGroupBy(
80+
CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
81+
Protocol38[CompliantDataFrameT_co, CompliantExprT_contra],
82+
):
83+
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
84+
85+
7986
class DepthTrackingGroupBy(
8087
CompliantGroupBy[CompliantFrameT_co, DepthTrackingExprT_contra],
8188
Protocol38[CompliantFrameT_co, DepthTrackingExprT_contra, NativeAggregationT_co],
@@ -132,9 +139,9 @@ def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
132139

133140
class EagerGroupBy(
134141
DepthTrackingGroupBy[CompliantDataFrameT_co, EagerExprT_contra, str],
142+
DataFrameGroupBy[CompliantDataFrameT_co, EagerExprT_contra],
135143
Protocol38[CompliantDataFrameT_co, EagerExprT_contra],
136-
):
137-
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
144+
): ...
138145

139146

140147
class LazyGroupBy(

narwhals/_pandas_like/dataframe.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from narwhals.typing import CompliantLazyFrame
6666
from narwhals.typing import DTypeBackend
6767
from narwhals.typing import JoinStrategy
68+
from narwhals.typing import PivotAgg
6869
from narwhals.typing import SizeUnit
6970
from narwhals.typing import UniqueKeepStrategy
7071
from narwhals.typing import _1DArray
@@ -1017,12 +1018,12 @@ def gather_every(self: Self, n: int, offset: int) -> Self:
10171018
return self._with_native(self.native.iloc[offset::n], validate_column_names=False)
10181019

10191020
def pivot(
1020-
self: Self,
1021-
on: list[str],
1021+
self,
1022+
on: Sequence[str],
10221023
*,
1023-
index: list[str] | None,
1024-
values: list[str] | None,
1025-
aggregate_function: Any | None,
1024+
index: Sequence[str] | None,
1025+
values: Sequence[str] | None,
1026+
aggregate_function: PivotAgg | None,
10261027
sort_columns: bool,
10271028
separator: str,
10281029
) -> Self:

narwhals/_pandas_like/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from contextlib import suppress
66
from typing import TYPE_CHECKING
77
from typing import Any
8+
from typing import Sequence
89
from typing import Sized
910
from typing import TypeVar
1011
from typing import cast
@@ -652,9 +653,9 @@ def select_columns_by_name(
652653

653654
def pivot_table(
654655
df: PandasLikeDataFrame,
655-
values: list[str],
656-
index: list[str],
657-
columns: list[str],
656+
values: Sequence[str],
657+
index: Sequence[str],
658+
columns: Sequence[str],
658659
aggregate_function: str | None,
659660
) -> Any:
660661
dtypes = import_dtypes_module(df._version)

narwhals/_polars/dataframe.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from narwhals.typing import CompliantDataFrame
4545
from narwhals.typing import CompliantLazyFrame
4646
from narwhals.typing import JoinStrategy
47+
from narwhals.typing import PivotAgg
4748
from narwhals.typing import _2DArray
4849
from narwhals.utils import Version
4950
from narwhals.utils import _FullContext
@@ -77,7 +78,6 @@ class PolarsDataFrame:
7778
select: Method[Self]
7879
sort: Method[Self]
7980
to_arrow: Method[pa.Table]
80-
to_numpy: Method[_2DArray]
8181
to_pandas: Method[pd.DataFrame]
8282
unique: Method[Self]
8383
with_columns: Method[Self]
@@ -232,6 +232,9 @@ def __array__(
232232
return self.native.__array__(dtype)
233233
return self.native.__array__(dtype)
234234

235+
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
236+
return self.native.to_numpy()
237+
235238
def collect_schema(self: Self) -> dict[str, DType]:
236239
if self._backend_version < (1,):
237240
return {
@@ -412,15 +415,12 @@ def unpivot(
412415
)
413416

414417
def pivot(
415-
self: Self,
416-
on: list[str],
418+
self,
419+
on: Sequence[str],
417420
*,
418-
index: list[str] | None,
419-
values: list[str] | None,
420-
aggregate_function: Literal[
421-
"min", "max", "first", "last", "sum", "mean", "median", "len"
422-
]
423-
| None,
421+
index: Sequence[str] | None,
422+
values: Sequence[str] | None,
423+
aggregate_function: PivotAgg | None,
424424
sort_columns: bool,
425425
separator: str,
426426
) -> Self:

0 commit comments

Comments
 (0)