Skip to content

Commit bbb520b

Browse files
authored
chore(typing): fix group_by (#2105)
- Think this is pretty straightforward - Tiny reduction in LOC is a bonus
1 parent 746d196 commit bbb520b

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

narwhals/group_by.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,25 @@
66
from typing import Iterable
77
from typing import Iterator
88
from typing import TypeVar
9-
from typing import cast
109

1110
from narwhals._expression_parsing import all_exprs_are_scalar_like
12-
from narwhals.dataframe import DataFrame
13-
from narwhals.dataframe import LazyFrame
1411
from narwhals.exceptions import InvalidOperationError
12+
from narwhals.typing import DataFrameT
1513
from narwhals.utils import flatten
1614
from narwhals.utils import tupleify
1715

1816
if TYPE_CHECKING:
1917
from typing_extensions import Self
2018

21-
from narwhals.dataframe import DataFrame
2219
from narwhals.dataframe import LazyFrame
2320
from narwhals.expr import Expr
2421

25-
DataFrameT = TypeVar("DataFrameT")
26-
LazyFrameT = TypeVar("LazyFrameT")
22+
LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]")
2723

2824

2925
class GroupBy(Generic[DataFrameT]):
3026
def __init__(self: Self, df: DataFrameT, *keys: str, drop_null_keys: bool) -> None:
31-
self._df = cast("DataFrame[Any]", df)
27+
self._df: DataFrameT = df
3228
self._keys = keys
3329
self._grouped = self._df._compliant_frame.group_by(
3430
*self._keys, drop_null_keys=drop_null_keys
@@ -131,20 +127,18 @@ def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFra
131127
for key, value in named_aggs.items()
132128
),
133129
)
134-
return self._df._from_compliant_dataframe( # type: ignore[return-value]
135-
self._grouped.agg(*compliant_aggs),
136-
)
130+
return self._df._from_compliant_dataframe(self._grouped.agg(*compliant_aggs))
137131

138132
def __iter__(self: Self) -> Iterator[tuple[Any, DataFrameT]]:
139-
yield from ( # type: ignore[misc]
133+
yield from (
140134
(tupleify(key), self._df._from_compliant_dataframe(df))
141135
for (key, df) in self._grouped.__iter__()
142136
)
143137

144138

145139
class LazyGroupBy(Generic[LazyFrameT]):
146140
def __init__(self: Self, df: LazyFrameT, *keys: str, drop_null_keys: bool) -> None:
147-
self._df = cast("LazyFrame[Any]", df)
141+
self._df: LazyFrameT = df
148142
self._keys = keys
149143
self._grouped = self._df._compliant_frame.group_by(
150144
*self._keys, drop_null_keys=drop_null_keys
@@ -231,6 +225,4 @@ def agg(self: Self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFra
231225
for key, value in named_aggs.items()
232226
),
233227
)
234-
return self._df._from_compliant_dataframe( # type: ignore[return-value]
235-
self._grouped.agg(*compliant_aggs),
236-
)
228+
return self._df._from_compliant_dataframe(self._grouped.agg(*compliant_aggs))

0 commit comments

Comments
 (0)