Skip to content

Commit a211588

Browse files
committed
feat(typing): Annotate dynamic parts of PolarsDataFrame
Resolves (#2223 (comment))
1 parent 94699b9 commit a211588

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

narwhals/_polars/dataframe.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
from typing import Iterator
66
from typing import Literal
7+
from typing import Mapping
78
from typing import Sequence
89
from typing import overload
910

@@ -23,9 +24,13 @@
2324

2425
if TYPE_CHECKING:
2526
from types import ModuleType
27+
from typing import Callable
2628
from typing import TypeVar
2729

30+
import pandas as pd
31+
import pyarrow as pa
2832
from typing_extensions import Self
33+
from typing_extensions import TypeAlias
2934

3035
from narwhals._polars.group_by import PolarsGroupBy
3136
from narwhals._polars.group_by import PolarsLazyGroupBy
@@ -37,35 +42,42 @@
3742
from narwhals.utils import Version
3843

3944
T = TypeVar("T")
45+
R = TypeVar("R")
46+
47+
Method: TypeAlias = "Callable[..., R]"
48+
"""Generic alias representing all methods implemented via `__getattr__`.
49+
50+
Where `R` is the return type.
51+
"""
4052

4153

42-
# TODO @dangotbanned: Want to tell the type checker that `__getattr__` will satisfy specific methods
43-
# - Could these have a narrower annotation?
4454
class PolarsDataFrame:
45-
clone: Any
46-
collect: Any
47-
drop_nulls: Any
48-
estimated_size: Any
49-
filter: Any
50-
gather_every: Any
51-
item: Any
52-
iter_rows: Any
53-
is_unique: Any
54-
join: Any
55-
join_asof: Any
56-
rename: Any
57-
row: Any
58-
rows: Any
59-
sample: Any
60-
select: Any
61-
sort: Any
62-
to_arrow: Any
63-
to_numpy: Any
64-
to_pandas: Any
65-
unique: Any
66-
with_columns: Any
67-
write_csv: Any
68-
write_parquet: Any
55+
clone: Method[Self]
56+
collect: Method[CompliantDataFrame[Any, Any]]
57+
drop_nulls: Method[Self]
58+
estimated_size: Method[int | float]
59+
filter: Method[Self]
60+
gather_every: Method[Self]
61+
item: Method[Any]
62+
iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
63+
is_unique: Method[PolarsSeries]
64+
join: Method[Self]
65+
join_asof: Method[Self]
66+
rename: Method[Self]
67+
row: Method[tuple[Any, ...]]
68+
rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
69+
sample: Method[Self]
70+
select: Method[Self]
71+
sort: Method[Self]
72+
to_arrow: Method[pa.Table]
73+
to_numpy: Method[_2DArray]
74+
to_pandas: Method[pd.DataFrame]
75+
unique: Method[Self]
76+
with_columns: Method[Self]
77+
# NOTE: `write_csv` requires an `@overload` for `str | None`
78+
# Can't do that here 😟
79+
write_csv: Method[Any]
80+
write_parquet: Method[None]
6981

7082
def __init__(
7183
self: Self,
@@ -249,7 +261,7 @@ def simple_select(self, *column_names: str) -> Self:
249261
return self._from_native_frame(self._native_frame.select(*column_names))
250262

251263
def aggregate(self: Self, *exprs: Any) -> Self:
252-
return self.select(*exprs) # type: ignore[no-any-return]
264+
return self.select(*exprs)
253265

254266
def get_column(self: Self, name: str) -> PolarsSeries:
255267
from narwhals._polars.series import PolarsSeries

0 commit comments

Comments
 (0)