|
4 | 4 | from typing import Any |
5 | 5 | from typing import Iterator |
6 | 6 | from typing import Literal |
| 7 | +from typing import Mapping |
7 | 8 | from typing import Sequence |
8 | 9 | from typing import overload |
9 | 10 |
|
|
23 | 24 |
|
24 | 25 | if TYPE_CHECKING: |
25 | 26 | from types import ModuleType |
| 27 | + from typing import Callable |
26 | 28 | from typing import TypeVar |
27 | 29 |
|
| 30 | + import pandas as pd |
| 31 | + import pyarrow as pa |
28 | 32 | from typing_extensions import Self |
| 33 | + from typing_extensions import TypeAlias |
29 | 34 |
|
30 | 35 | from narwhals._polars.group_by import PolarsGroupBy |
31 | 36 | from narwhals._polars.group_by import PolarsLazyGroupBy |
|
37 | 42 | from narwhals.utils import Version |
38 | 43 |
|
39 | 44 | T = TypeVar("T") |
| 45 | + R = TypeVar("R") |
| 46 | + |
| 47 | +Method: TypeAlias = "Callable[..., R]" |
| 48 | +"""Generic alias representing all methods implemented via `__getattr__`. |
| 49 | +
|
| 50 | +Where `R` is the return type. |
| 51 | +""" |
40 | 52 |
|
41 | 53 |
|
42 | | -# TODO @dangotbanned: Want to tell the type checker that `__getattr__` will satisfy specific methods |
43 | | -# - Could these have a narrower annotation? |
44 | 54 | class PolarsDataFrame: |
45 | | - clone: Any |
46 | | - collect: Any |
47 | | - drop_nulls: Any |
48 | | - estimated_size: Any |
49 | | - filter: Any |
50 | | - gather_every: Any |
51 | | - item: Any |
52 | | - iter_rows: Any |
53 | | - is_unique: Any |
54 | | - join: Any |
55 | | - join_asof: Any |
56 | | - rename: Any |
57 | | - row: Any |
58 | | - rows: Any |
59 | | - sample: Any |
60 | | - select: Any |
61 | | - sort: Any |
62 | | - to_arrow: Any |
63 | | - to_numpy: Any |
64 | | - to_pandas: Any |
65 | | - unique: Any |
66 | | - with_columns: Any |
67 | | - write_csv: Any |
68 | | - write_parquet: Any |
| 55 | + clone: Method[Self] |
| 56 | + collect: Method[CompliantDataFrame[Any, Any]] |
| 57 | + drop_nulls: Method[Self] |
| 58 | + estimated_size: Method[int | float] |
| 59 | + filter: Method[Self] |
| 60 | + gather_every: Method[Self] |
| 61 | + item: Method[Any] |
| 62 | + iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]] |
| 63 | + is_unique: Method[PolarsSeries] |
| 64 | + join: Method[Self] |
| 65 | + join_asof: Method[Self] |
| 66 | + rename: Method[Self] |
| 67 | + row: Method[tuple[Any, ...]] |
| 68 | + rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]] |
| 69 | + sample: Method[Self] |
| 70 | + select: Method[Self] |
| 71 | + sort: Method[Self] |
| 72 | + to_arrow: Method[pa.Table] |
| 73 | + to_numpy: Method[_2DArray] |
| 74 | + to_pandas: Method[pd.DataFrame] |
| 75 | + unique: Method[Self] |
| 76 | + with_columns: Method[Self] |
| 77 | + # NOTE: `write_csv` requires an `@overload` for `str | None` |
| 78 | + # Can't do that here 😟 |
| 79 | + write_csv: Method[Any] |
| 80 | + write_parquet: Method[None] |
69 | 81 |
|
70 | 82 | def __init__( |
71 | 83 | self: Self, |
@@ -249,7 +261,7 @@ def simple_select(self, *column_names: str) -> Self: |
249 | 261 | return self._from_native_frame(self._native_frame.select(*column_names)) |
250 | 262 |
|
251 | 263 | def aggregate(self: Self, *exprs: Any) -> Self: |
252 | | - return self.select(*exprs) # type: ignore[no-any-return] |
| 264 | + return self.select(*exprs) |
253 | 265 |
|
254 | 266 | def get_column(self: Self, name: str) -> PolarsSeries: |
255 | 267 | from narwhals._polars.series import PolarsSeries |
|
0 commit comments