|
6 | 6 | from typing import Iterable |
7 | 7 | from typing import Iterator |
8 | 8 | from typing import TypeVar |
9 | | -from typing import cast |
10 | 9 |
|
11 | 10 | from narwhals._expression_parsing import all_exprs_are_scalar_like |
12 | | -from narwhals.dataframe import DataFrame |
13 | | -from narwhals.dataframe import LazyFrame |
14 | 11 | from narwhals.exceptions import InvalidOperationError |
| 12 | +from narwhals.typing import DataFrameT |
15 | 13 | from narwhals.utils import flatten |
16 | 14 | from narwhals.utils import tupleify |
17 | 15 |
|
18 | 16 | if TYPE_CHECKING: |
19 | 17 | from typing_extensions import Self |
20 | 18 |
|
21 | | - from narwhals.dataframe import DataFrame |
22 | 19 | from narwhals.dataframe import LazyFrame |
23 | 20 | from narwhals.expr import Expr |
24 | 21 |
|
25 | | -DataFrameT = TypeVar("DataFrameT") |
26 | | -LazyFrameT = TypeVar("LazyFrameT") |
| 22 | +LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]") |
27 | 23 |
|
28 | 24 |
|
29 | 25 | class GroupBy(Generic[DataFrameT]): |
30 | 26 | def __init__(self: Self, df: DataFrameT, *keys: str, drop_null_keys: bool) -> None: |
31 | | - self._df = cast("DataFrame[Any]", df) |
| 27 | + self._df: DataFrameT = df |
32 | 28 | self._keys = keys |
33 | 29 | self._grouped = self._df._compliant_frame.group_by( |
34 | 30 | *self._keys, drop_null_keys=drop_null_keys |
@@ -131,20 +127,18 @@ def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFra |
131 | 127 | for key, value in named_aggs.items() |
132 | 128 | ), |
133 | 129 | ) |
134 | | - return self._df._from_compliant_dataframe( # type: ignore[return-value] |
135 | | - self._grouped.agg(*compliant_aggs), |
136 | | - ) |
| 130 | + return self._df._from_compliant_dataframe(self._grouped.agg(*compliant_aggs)) |
137 | 131 |
|
138 | 132 | def __iter__(self: Self) -> Iterator[tuple[Any, DataFrameT]]: |
139 | | - yield from ( # type: ignore[misc] |
| 133 | + yield from ( |
140 | 134 | (tupleify(key), self._df._from_compliant_dataframe(df)) |
141 | 135 | for (key, df) in self._grouped.__iter__() |
142 | 136 | ) |
143 | 137 |
|
144 | 138 |
|
145 | 139 | class LazyGroupBy(Generic[LazyFrameT]): |
146 | 140 | def __init__(self: Self, df: LazyFrameT, *keys: str, drop_null_keys: bool) -> None: |
147 | | - self._df = cast("LazyFrame[Any]", df) |
| 141 | + self._df: LazyFrameT = df |
148 | 142 | self._keys = keys |
149 | 143 | self._grouped = self._df._compliant_frame.group_by( |
150 | 144 | *self._keys, drop_null_keys=drop_null_keys |
@@ -231,6 +225,4 @@ def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFra |
231 | 225 | for key, value in named_aggs.items() |
232 | 226 | ), |
233 | 227 | ) |
234 | | - return self._df._from_compliant_dataframe( # type: ignore[return-value] |
235 | | - self._grouped.agg(*compliant_aggs), |
236 | | - ) |
| 228 | + return self._df._from_compliant_dataframe(self._grouped.agg(*compliant_aggs)) |
0 commit comments