
Commit 6a9c09f

chore(typing): make some more parts generic
1 parent: c4604bc

3 files changed: +49 −40 lines changed

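The commit threads a constrained TypeVar (`FrameT`) through the Spark-like backend so that expressions, frames, and group-bys are all parameterized by the same native frame type. Here is a minimal sketch of that pattern, assuming stand-in names (`PySparkDF`, `SQLFrameDF`, `LazyFrame`, `Expr` are illustrative, not the narwhals API):

```python
from __future__ import annotations

from typing import Generic, TypeVar

class PySparkDF: ...   # stand-in for pyspark.sql.DataFrame
class SQLFrameDF: ...  # stand-in for the SQLFrame DataFrame

# A *constrained* TypeVar: FrameT must resolve to exactly one of these types.
FrameT = TypeVar("FrameT", PySparkDF, SQLFrameDF)

class LazyFrame(Generic[FrameT]):
    def __init__(self, native: FrameT) -> None:
        self._native = native

class Expr(Generic[FrameT]):
    # An Expr[PySparkDF] can only be evaluated against a LazyFrame[PySparkDF];
    # mixing backends becomes a static type error instead of a runtime one.
    def __call__(self, df: LazyFrame[FrameT]) -> list[str]:
        return []  # placeholder body
```

With a constrained TypeVar, the checker solves `FrameT` as exactly one of the listed types, so an expression built for one backend cannot silently be evaluated against a frame from the other.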

narwhals/_spark_like/dataframe.py

Lines changed: 4 additions & 4 deletions

@@ -257,7 +257,7 @@ def simple_select(self: Self, *column_names: str) -> Self:
 
     def aggregate(
         self: Self,
-        *exprs: SparkLikeExpr,
+        *exprs: SparkLikeExpr[FrameT],
     ) -> Self:
         new_columns = evaluate_exprs(self, *exprs)
 
@@ -266,7 +266,7 @@ def aggregate(
 
     def select(
         self: Self,
-        *exprs: SparkLikeExpr,
+        *exprs: SparkLikeExpr[FrameT],
     ) -> Self:
         new_columns = evaluate_exprs(self, *exprs)
 
@@ -279,11 +279,11 @@ def select(
         new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
         return self._from_native_frame(self._native_frame.select(*new_columns_list))
 
-    def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
+    def with_columns(self: Self, *exprs: SparkLikeExpr[FrameT]) -> Self:
         new_columns = evaluate_exprs(self, *exprs)
         return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))
 
-    def filter(self: Self, predicate: SparkLikeExpr) -> Self:
+    def filter(self: Self, predicate: SparkLikeExpr[FrameT]) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         condition = predicate._call(self)[0]
         spark_df = self._native_frame.where(condition)
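
In isolation, the signature changes above mean a frame's `select`, `with_columns`, and `filter` only accept expressions parameterized with that frame's own `FrameT`. A hypothetical reduced example of the effect under a type checker (all names here are illustrative):

```python
from __future__ import annotations

from typing import Generic, TypeVar

class A: ...
class B: ...

T = TypeVar("T", A, B)

class Expr(Generic[T]): ...

class Frame(Generic[T]):
    def select(self, *exprs: Expr[T]) -> Frame[T]:
        # Every expression must share this frame's type parameter.
        return self

f: Frame[A] = Frame()
f.select(Expr[A]())    # accepted: expression and frame backends match
# f.select(Expr[B]())  # a type checker rejects this: Expr[B] is not Expr[A]
```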

narwhals/_spark_like/expr.py

Lines changed: 32 additions & 31 deletions

@@ -28,22 +28,23 @@
     from pyspark.sql import Window
     from typing_extensions import Self
 
+    from narwhals._spark_like.dataframe import FrameT
     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
     from narwhals._spark_like.namespace import SparkLikeNamespace
     from narwhals._spark_like.typing import WindowFunction
     from narwhals.dtypes import DType
     from narwhals.utils import Version
 
 
-class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]):  # type: ignore[type-var]  # (#2044)
+class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame[FrameT]", "Column"]):  # type: ignore[type-var]  # (#2044)
     _depth = 0  # Unused, just for compatibility with CompliantExpr
 
     def __init__(
         self: Self,
-        call: Callable[[SparkLikeLazyFrame], Sequence[Column]],
+        call: Callable[[SparkLikeLazyFrame[FrameT]], Sequence[Column]],
         *,
         function_name: str,
-        evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]],
+        evaluate_output_names: Callable[[SparkLikeLazyFrame[FrameT]], Sequence[str]],
         alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
         backend_version: tuple[int, ...],
         version: Version,
@@ -58,11 +59,11 @@ def __init__(
         self._implementation = implementation
         self._window_function: WindowFunction | None = None
 
-    def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]:
+    def __call__(self: Self, df: SparkLikeLazyFrame[FrameT]) -> Sequence[Column]:
         return self._call(df)
 
     def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
-        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
+        def func(df: SparkLikeLazyFrame[FrameT]) -> Sequence[Column]:
             if kind is ExprKind.AGGREGATION:
                 return [
                     result.over(df._Window().partitionBy(df._F.lit(1)))
@@ -144,15 +145,15 @@ def __narwhals_namespace__(self: Self) -> SparkLikeNamespace: # pragma: no cove
     @classmethod
     def from_column_names(
         cls: type[Self],
-        evaluate_column_names: Callable[[SparkLikeLazyFrame], Sequence[str]],
+        evaluate_column_names: Callable[[SparkLikeLazyFrame[FrameT]], Sequence[str]],
         /,
         *,
         function_name: str,
         implementation: Implementation,
         backend_version: tuple[int, ...],
         version: Version,
     ) -> Self:
-        def func(df: SparkLikeLazyFrame) -> list[Column]:
+        def func(df: SparkLikeLazyFrame[FrameT]) -> list[Column]:
             return [df._F.col(col_name) for col_name in evaluate_column_names(df)]
 
         return cls(
@@ -173,7 +174,7 @@ def from_column_indices(
         version: Version,
         implementation: Implementation,
     ) -> Self:
-        def func(df: SparkLikeLazyFrame) -> list[Column]:
+        def func(df: SparkLikeLazyFrame[FrameT]) -> list[Column]:
             columns = df.columns
             return [df._F.col(columns[i]) for i in column_indices]
 
@@ -193,7 +194,7 @@ def _from_call(
         expr_name: str,
         **expressifiable_args: Self | Any,
     ) -> Self:
-        def func(df: SparkLikeLazyFrame) -> list[Column]:
+        def func(df: SparkLikeLazyFrame[FrameT]) -> list[Column]:
             native_series_list = self._call(df)
             other_native_series = {
                 key: maybe_evaluate_expr(df, value)
@@ -230,104 +231,104 @@ def _with_window_function(
         result._window_function = window_function
         return result
 
-    def __eq__(self: Self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
+    def __eq__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:  # type: ignore[override]
         return self._from_call(
             lambda _input, other: _input.__eq__(other), "__eq__", other=other
         )
 
-    def __ne__(self: Self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
+    def __ne__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:  # type: ignore[override]
         return self._from_call(
             lambda _input, other: _input.__ne__(other), "__ne__", other=other
        )
 
-    def __add__(self: Self, other: SparkLikeExpr) -> Self:
+    def __add__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__add__(other), "__add__", other=other
         )
 
-    def __sub__(self: Self, other: SparkLikeExpr) -> Self:
+    def __sub__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__sub__(other), "__sub__", other=other
         )
 
-    def __rsub__(self: Self, other: SparkLikeExpr) -> Self:
+    def __rsub__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: other.__sub__(_input), "__rsub__", other=other
         ).alias("literal")
 
-    def __mul__(self: Self, other: SparkLikeExpr) -> Self:
+    def __mul__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__mul__(other), "__mul__", other=other
         )
 
-    def __truediv__(self: Self, other: SparkLikeExpr) -> Self:
+    def __truediv__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__truediv__(other), "__truediv__", other=other
         )
 
-    def __rtruediv__(self: Self, other: SparkLikeExpr) -> Self:
+    def __rtruediv__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: other.__truediv__(_input), "__rtruediv__", other=other
         ).alias("literal")
 
-    def __floordiv__(self: Self, other: SparkLikeExpr) -> Self:
+    def __floordiv__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         def _floordiv(_input: Column, other: Column) -> Column:
             return self._F.floor(_input / other)
 
         return self._from_call(_floordiv, "__floordiv__", other=other)
 
-    def __rfloordiv__(self: Self, other: SparkLikeExpr) -> Self:
+    def __rfloordiv__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         def _rfloordiv(_input: Column, other: Column) -> Column:
             return self._F.floor(other / _input)
 
         return self._from_call(_rfloordiv, "__rfloordiv__", other=other).alias("literal")
 
-    def __pow__(self: Self, other: SparkLikeExpr) -> Self:
+    def __pow__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__pow__(other), "__pow__", other=other
         )
 
-    def __rpow__(self: Self, other: SparkLikeExpr) -> Self:
+    def __rpow__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: other.__pow__(_input), "__rpow__", other=other
         ).alias("literal")
 
-    def __mod__(self: Self, other: SparkLikeExpr) -> Self:
+    def __mod__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__mod__(other), "__mod__", other=other
         )
 
-    def __rmod__(self: Self, other: SparkLikeExpr) -> Self:
+    def __rmod__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: other.__mod__(_input), "__rmod__", other=other
         ).alias("literal")
 
-    def __ge__(self: Self, other: SparkLikeExpr) -> Self:
+    def __ge__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__ge__(other), "__ge__", other=other
         )
 
-    def __gt__(self: Self, other: SparkLikeExpr) -> Self:
+    def __gt__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input > other, "__gt__", other=other
         )
 
-    def __le__(self: Self, other: SparkLikeExpr) -> Self:
+    def __le__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__le__(other), "__le__", other=other
         )
 
-    def __lt__(self: Self, other: SparkLikeExpr) -> Self:
+    def __lt__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__lt__(other), "__lt__", other=other
         )
 
-    def __and__(self: Self, other: SparkLikeExpr) -> Self:
+    def __and__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__and__(other), "__and__", other=other
         )
 
-    def __or__(self: Self, other: SparkLikeExpr) -> Self:
+    def __or__(self: Self, other: SparkLikeExpr[FrameT]) -> Self:
         return self._from_call(
             lambda _input, other: _input.__or__(other), "__or__", other=other
         )
@@ -524,14 +525,14 @@ def over(
         if (window_function := self._window_function) is not None:
             assert order_by is not None  # noqa: S101
 
-            def func(df: SparkLikeLazyFrame) -> list[Column]:
+            def func(df: SparkLikeLazyFrame[FrameT]) -> list[Column]:
                 return [
                     window_function(expr, partition_by, order_by)
                     for expr in self._call(df)
                 ]
         else:
 
-            def func(df: SparkLikeLazyFrame) -> list[Column]:
+            def func(df: SparkLikeLazyFrame[FrameT]) -> list[Column]:
                 return [
                     expr.over(self._Window.partitionBy(*partition_by))
                     for expr in self._call(df)
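
Nearly every edit in this file follows one mechanical rule: any closure or callback receiving the lazy frame is re-annotated from the bare `SparkLikeLazyFrame` to the parameterized `SparkLikeLazyFrame[FrameT]`. A reduced sketch, with illustrative names, of why the inner annotations matter:

```python
from __future__ import annotations

from typing import Callable, Generic, Sequence, TypeVar

class A: ...
class B: ...

T = TypeVar("T", A, B)

class Frame(Generic[T]): ...

class Expr(Generic[T]):
    def __init__(self, call: Callable[[Frame[T]], Sequence[str]]) -> None:
        self._call = call

    def alias(self, name: str) -> Expr[T]:
        def func(df: Frame[T]) -> Sequence[str]:
            # Annotating `df` with the bare `Frame` here would erase T at
            # this boundary and let an Expr[A] accept a Frame[B]; keeping
            # Frame[T] preserves the expression-to-backend link.
            return [name for _ in self._call(df)]

        return Expr(func)
```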

narwhals/_spark_like/group_by.py

Lines changed: 13 additions & 5 deletions

@@ -1,18 +1,25 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing import Generic
+from typing import TypeVar
+from typing import cast
 
 if TYPE_CHECKING:
     from typing_extensions import Self
 
+    from narwhals._spark_like.dataframe import DataFrame
     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
+    from narwhals._spark_like.dataframe import SQLFrameDataFrame
     from narwhals._spark_like.expr import SparkLikeExpr
 
+FrameT = TypeVar("FrameT", "DataFrame", "SQLFrameDataFrame")
 
-class SparkLikeLazyGroupBy:
+
+class SparkLikeLazyGroupBy(Generic[FrameT]):
     def __init__(
         self: Self,
-        compliant_frame: SparkLikeLazyFrame,
+        compliant_frame: SparkLikeLazyFrame[FrameT],
         keys: list[str],
         drop_null_keys: bool,  # noqa: FBT001
     ) -> None:
@@ -22,7 +29,7 @@ def __init__(
         self._compliant_frame = compliant_frame
         self._keys = keys
 
-    def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
+    def agg(self: Self, *exprs: SparkLikeExpr[FrameT]) -> SparkLikeLazyFrame[FrameT]:
         agg_columns = []
         df = self._compliant_frame
         for expr in exprs:
@@ -52,6 +59,7 @@ def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
         return self._compliant_frame._from_native_frame(
             self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
         )
-        return self._compliant_frame._from_native_frame(
-            self._compliant_frame._native_frame.groupBy(*self._keys).agg(*agg_columns)
+        native = self._compliant_frame._native_frame.groupBy(*self._keys).agg(
+            *agg_columns
         )
+        return self._compliant_frame._from_native_frame(cast("FrameT", native))
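
The final hunk splits the native `groupBy(...).agg(...)` call into its own statement and re-asserts the result's type with `typing.cast`: the native `agg` returns the backend's concrete frame type, which the checker evidently cannot connect back to `FrameT`, hence the cast. `cast` is purely static and free at runtime; a reduced illustration of the idiom:

```python
from typing import TypeVar, cast

T = TypeVar("T", int, str)

def roundtrip(value: T) -> T:
    # Pretend this came from an API whose declared return type loses the
    # TypeVar (as with the native groupBy(...).agg(...) call above).
    result: object = value
    # cast() performs no runtime check or conversion; it only re-labels the
    # static type so the checker accepts the value as T again.
    return cast("T", result)

assert roundtrip(3) == 3
assert roundtrip("x") == "x"
```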
