Skip to content

Commit 2869349

Browse files
authored
feat(typing): Make Implementation less opaque (#3016)
1 parent 140b961 commit 2869349

File tree

14 files changed

+488
-102
lines changed

14 files changed

+488
-102
lines changed

.github/workflows/typing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131
run: uv venv .venv
3232
- name: install-reqs
3333
# TODO: add more dependencies/backends incrementally
34-
run: uv pip install -e ".[pyspark]" --group core --group typing
34+
run: uv pip install -e ".[pyspark]" --group core --group typing-ci
3535
- name: show-deps
3636
run: uv pip freeze
3737
- name: Run mypy and pyright

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ help: ## Display this help screen
2020

2121
.PHONY: typing
2222
typing: ## Run typing checks
23-
$(VENV_BIN)/uv pip install -e . --group typing
23+
$(VENV_BIN)/uv pip install -e . --group typing-ci
2424
$(VENV_BIN)/pyright
2525
$(VENV_BIN)/mypy

docs/api-reference/dataframe.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,4 @@
5858
- write_parquet
5959
show_source: false
6060
show_bases: false
61+
inherited_members: true

docs/api-reference/lazyframe.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@
3434
show_root_heading: false
3535
show_source: false
3636
show_bases: false
37+
inherited_members: true

narwhals/_namespace.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
Any,
88
Callable,
99
Generic,
10-
Literal,
1110
Protocol,
1211
TypeVar,
12+
cast,
1313
overload,
1414
)
1515

@@ -37,8 +37,6 @@
3737
import pandas as pd
3838
import polars as pl
3939
import pyarrow as pa
40-
import pyspark.sql as pyspark_sql
41-
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
4240
from typing_extensions import Self, TypeAlias, TypeIs
4341

4442
from narwhals._arrow.namespace import ArrowNamespace
@@ -68,30 +66,33 @@
6866
_Guard: TypeAlias = "Callable[[Any], TypeIs[T]]"
6967

7068
EagerAllowedNamespace: TypeAlias = "Namespace[PandasLikeNamespace] | Namespace[ArrowNamespace] | Namespace[PolarsNamespace]"
69+
Incomplete: TypeAlias = Any
7170

7271
class _BasePandasLike(Sized, Protocol):
7372
index: Any
7473
"""`mypy` doesn't like the asymmetric `property` setter in `pandas`."""
7574

7675
def __getitem__(self, key: Any, /) -> Any: ...
77-
def __mul__(self, other: float | Collection[float] | Self) -> Self: ...
78-
def __floordiv__(self, other: float | Collection[float] | Self) -> Self: ...
76+
def __mul__(self, other: float | Collection[float] | Self, /) -> Self: ...
77+
def __floordiv__(self, other: float | Collection[float] | Self, /) -> Self: ...
7978
@property
8079
def loc(self) -> Any: ...
8180
@property
8281
def shape(self) -> tuple[int, ...]: ...
8382
def set_axis(self, labels: Any, *, axis: Any = ..., copy: bool = ...) -> Self: ...
8483
def copy(self, deep: bool = ...) -> Self: ... # noqa: FBT001
85-
def rename(self, *args: Any, inplace: Literal[False], **kwds: Any) -> Self:
86-
"""`inplace=False` is required to avoid (incorrect?) default overloads."""
87-
...
84+
def rename(self, *args: Any, **kwds: Any) -> Self | Incomplete:
85+
"""`mypy` & `pyright` disagree on overloads.
86+
87+
`Incomplete` used to fix [more important issue](https://github.com/narwhals-dev/narwhals/pull/3016#discussion_r2296139744).
88+
"""
8889

8990
class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): ...
9091

9192
class _BasePandasLikeSeries(NativeSeries, _BasePandasLike, Protocol):
92-
def where(self, cond: Any, other: Any = ..., **kwds: Any) -> Any: ...
93+
def where(self, cond: Any, other: Any = ..., /) -> Self | Incomplete: ...
9394

94-
class _NativeDask(Protocol):
95+
class _NativeDask(NativeLazyFrame, Protocol):
9596
_partition_type: type[pd.DataFrame]
9697

9798
class _CuDFDataFrame(_BasePandasLikeFrame, Protocol):
@@ -112,6 +113,12 @@ class _ModinDataFrame(_BasePandasLikeFrame, Protocol):
112113
class _ModinSeries(_BasePandasLikeSeries, Protocol):
113114
_pandas_class: type[pd.Series[Any]]
114115

116+
# NOTE: Using `pyspark.sql.DataFrame` creates false positives in overloads when not installed
117+
class _PySparkDataFrame(NativeLazyFrame, Protocol):
118+
# Arbitrary method that `sqlframe` doesn't have and unlikely to appear anywhere else
119+
# https://github.com/apache/spark/blob/8530444e25b83971da4314c608aa7d763adeceb3/python/pyspark/sql/dataframe.py#L4875
120+
def dropDuplicatesWithinWatermark(self, *arg: Any, **kwargs: Any) -> Any: ... # noqa: N802
121+
115122
_NativePolars: TypeAlias = "pl.DataFrame | pl.LazyFrame | pl.Series"
116123
_NativeArrow: TypeAlias = "pa.Table | pa.ChunkedArray[Any]"
117124
_NativeDuckDB: TypeAlias = "duckdb.DuckDBPyRelation"
@@ -124,8 +131,8 @@ class _ModinSeries(_BasePandasLikeSeries, Protocol):
124131
)
125132
_NativePandasLike: TypeAlias = "_NativePandasLikeDataFrame |_NativePandasLikeSeries"
126133
_NativeSQLFrame: TypeAlias = "SQLFrameDataFrame"
127-
_NativePySpark: TypeAlias = "pyspark_sql.DataFrame"
128-
_NativePySparkConnect: TypeAlias = "PySparkConnectDataFrame"
134+
_NativePySpark: TypeAlias = _PySparkDataFrame
135+
_NativePySparkConnect: TypeAlias = _PySparkDataFrame
129136
_NativeSparkLike: TypeAlias = (
130137
"_NativeSQLFrame | _NativePySpark | _NativePySparkConnect"
131138
)
@@ -371,8 +378,10 @@ def is_native_dask(obj: Any) -> TypeIs[_NativeDask]:
371378

372379
is_native_duckdb: _Guard[_NativeDuckDB] = is_duckdb_relation
373380
is_native_sqlframe: _Guard[_NativeSQLFrame] = is_sqlframe_dataframe
374-
is_native_pyspark: _Guard[_NativePySpark] = is_pyspark_dataframe
375-
is_native_pyspark_connect: _Guard[_NativePySparkConnect] = is_pyspark_connect_dataframe
381+
is_native_pyspark = cast("_Guard[_NativePySpark]", is_pyspark_dataframe)
382+
is_native_pyspark_connect = cast(
383+
"_Guard[_NativePySparkConnect]", is_pyspark_connect_dataframe
384+
)
376385

377386

378387
def is_native_pandas(obj: Any) -> TypeIs[_NativePandas]:

narwhals/_pandas_like/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import functools
44
import operator
55
import re
6-
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
6+
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast
77

88
import pandas as pd
99

@@ -202,8 +202,10 @@ def rename(
202202
if implementation is Implementation.PANDAS and (
203203
implementation._backend_version() >= (3,)
204204
): # pragma: no cover
205-
return obj.rename(*args, **kwargs, inplace=False)
206-
return obj.rename(*args, **kwargs, copy=False, inplace=False)
205+
result = obj.rename(*args, **kwargs, inplace=False)
206+
else:
207+
result = obj.rename(*args, **kwargs, copy=False, inplace=False)
208+
return cast("NativeNDFrameT", result) # type: ignore[redundant-cast]
207209

208210

209211
@functools.lru_cache(maxsize=16)

narwhals/_utils.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,40 @@
7373
NativeSeriesT_co,
7474
)
7575
from narwhals._compliant.typing import EvalNames, NativeDataFrameT, NativeLazyFrameT
76-
from narwhals._namespace import Namespace
76+
from narwhals._namespace import (
77+
Namespace,
78+
_NativeArrow,
79+
_NativeCuDF,
80+
_NativeDask,
81+
_NativeDuckDB,
82+
_NativeIbis,
83+
_NativeModin,
84+
_NativePandas,
85+
_NativePandasLike,
86+
_NativePolars,
87+
_NativePySpark,
88+
_NativePySparkConnect,
89+
_NativeSQLFrame,
90+
)
7791
from narwhals._translate import ArrowStreamExportable, IntoArrowTable, ToNarwhalsT_co
7892
from narwhals._typing import (
7993
Backend,
8094
IntoBackend,
95+
_ArrowImpl,
96+
_CudfImpl,
97+
_DaskImpl,
98+
_DuckDBImpl,
8199
_EagerAllowedImpl,
100+
_IbisImpl,
82101
_LazyAllowedImpl,
83102
_LazyFrameCollectImpl,
103+
_ModinImpl,
104+
_PandasImpl,
105+
_PandasLikeImpl,
106+
_PolarsImpl,
107+
_PySparkConnectImpl,
108+
_PySparkImpl,
109+
_SQLFrameImpl,
84110
)
85111
from narwhals.dataframe import DataFrame, LazyFrame
86112
from narwhals.dtypes import DType
@@ -141,7 +167,7 @@ def columns(self) -> Sequence[str]: ...
141167
_Constructor: TypeAlias = "Callable[Concatenate[_T, P], R2]"
142168

143169

144-
class _StoresNative(Protocol[NativeT_co]): # noqa: PYI046
170+
class _StoresNative(Protocol[NativeT_co]):
145171
"""Provides access to a native object.
146172
147173
Native objects have types like:
@@ -2034,3 +2060,91 @@ def deep_attrgetter(attr: str, *nested: str) -> attrgetter[Any]:
20342060
def deep_getattr(obj: Any, name_1: str, *nested: str) -> Any:
20352061
"""Perform a nested attribute lookup on `obj`."""
20362062
return deep_attrgetter(name_1, *nested)(obj)
2063+
2064+
2065+
class Compliant(
2066+
_StoresNative[NativeT_co], _StoresImplementation, Protocol[NativeT_co]
2067+
): ...
2068+
2069+
2070+
class Narwhals(Protocol[NativeT_co]):
2071+
"""Minimal *Narwhals-level* protocol.
2072+
2073+
Provides access to a compliant object:
2074+
2075+
obj: Narwhals[NativeT_co]]
2076+
compliant: Compliant[NativeT_co] = obj._compliant
2077+
2078+
Which itself exposes:
2079+
2080+
implementation: Implementation = compliant.implementation
2081+
native: NativeT_co = compliant.native
2082+
2083+
This interface is used for revealing which `Implementation` member is associated with **either**:
2084+
- One or more [nominal] native type(s)
2085+
- One or more [structural] type(s)
2086+
- where the true native type(s) are [assignable to] *at least* one of them
2087+
2088+
These relationships are defined in the `@overload`s of `_Implementation.__get__(...)`.
2089+
2090+
[nominal]: https://typing.python.org/en/latest/spec/glossary.html#term-nominal
2091+
[structural]: https://typing.python.org/en/latest/spec/glossary.html#term-structural
2092+
[assignable to]: https://typing.python.org/en/latest/spec/glossary.html#term-assignable
2093+
"""
2094+
2095+
@property
2096+
def _compliant(self) -> Compliant[NativeT_co]: ...
2097+
2098+
2099+
class _Implementation:
2100+
"""Descriptor for matching an opaque `Implementation` on a generic class.
2101+
2102+
Based on [pyright comment](https://github.com/microsoft/pyright/issues/3071#issuecomment-1043978070)
2103+
"""
2104+
2105+
def __set_name__(self, owner: type[Any], name: str) -> None:
2106+
self.__name__: str = name
2107+
2108+
@overload
2109+
def __get__(self, instance: Narwhals[_NativePolars], owner: Any) -> _PolarsImpl: ...
2110+
@overload
2111+
def __get__(self, instance: Narwhals[_NativePandas], owner: Any) -> _PandasImpl: ...
2112+
@overload
2113+
def __get__(self, instance: Narwhals[_NativeModin], owner: Any) -> _ModinImpl: ...
2114+
@overload # TODO @dangotbanned: Rename `_typing` `*Cudf*` aliases to `*CuDF*`
2115+
def __get__(self, instance: Narwhals[_NativeCuDF], owner: Any) -> _CudfImpl: ...
2116+
@overload
2117+
def __get__(
2118+
self, instance: Narwhals[_NativePandasLike], owner: Any
2119+
) -> _PandasLikeImpl: ...
2120+
@overload
2121+
def __get__(self, instance: Narwhals[_NativeArrow], owner: Any) -> _ArrowImpl: ...
2122+
@overload
2123+
def __get__(
2124+
self, instance: Narwhals[_NativePolars | _NativeArrow | _NativePandas], owner: Any
2125+
) -> _PolarsImpl | _PandasImpl | _ArrowImpl: ...
2126+
@overload
2127+
def __get__(self, instance: Narwhals[_NativeDuckDB], owner: Any) -> _DuckDBImpl: ...
2128+
@overload
2129+
def __get__(
2130+
self, instance: Narwhals[_NativeSQLFrame], owner: Any
2131+
) -> _SQLFrameImpl: ...
2132+
@overload
2133+
def __get__(self, instance: Narwhals[_NativeDask], owner: Any) -> _DaskImpl: ...
2134+
@overload
2135+
def __get__(self, instance: Narwhals[_NativeIbis], owner: Any) -> _IbisImpl: ...
2136+
@overload
2137+
def __get__(
2138+
self, instance: Narwhals[_NativePySpark | _NativePySparkConnect], owner: Any
2139+
) -> _PySparkImpl | _PySparkConnectImpl: ...
2140+
# NOTE: https://docs.python.org/3/howto/descriptor.html#invocation-from-a-class
2141+
@overload
2142+
def __get__(self, instance: None, owner: type[Narwhals[Any]]) -> Self: ...
2143+
@overload
2144+
def __get__(
2145+
self, instance: DataFrame[Any] | Series[Any], owner: Any
2146+
) -> _EagerAllowedImpl: ...
2147+
@overload
2148+
def __get__(self, instance: LazyFrame[Any], owner: Any) -> _LazyAllowedImpl: ...
2149+
def __get__(self, instance: Narwhals[Any] | None, owner: Any) -> Any:
2150+
return self if instance is None else instance._compliant._implementation

0 commit comments

Comments
 (0)