
Commit 7611bd4

type: Use SQLFrame instead of PySpark to type _spark_like internally (#2190)
Co-authored-by: dangotbanned <[email protected]>
1 parent ad5c2d7 commit 7611bd4

18 files changed: +214 −179 lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ jobs:
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
         # we are not testing pyspark on Windows here because it is very slow
-        run: uv pip install -e ".[tests, core, extra, dask, modin, sqlframe]" --system
+        run: uv pip install -e ".[tests, core, extra, dask, modin]" --system
       - name: show-deps
         run: uv pip freeze
       - name: Run pytest
@@ -83,7 +83,7 @@ jobs:
           cache-suffix: ${{ matrix.python-version }}
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
-        run: uv pip install -e ".[tests, core, extra, modin, dask, sqlframe]" --system
+        run: uv pip install -e ".[tests, core, extra, modin, dask]" --system
       - name: install pyspark
         run: uv pip install -e ".[pyspark]" --system
         # PySpark is not yet available on Python3.12+

.github/workflows/typing.yml

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ jobs:
   mypy:
     strategy:
       matrix:
-        python-version: ["3.11"]
+        python-version: ["3.12"]
         os: [ubuntu-latest]
     runs-on: ${{ matrix.os }}
     steps:
@@ -32,7 +32,7 @@ jobs:
         # TODO: add more dependencies/backends incrementally
         run: |
           source .venv/bin/activate
-          uv pip install -e ".[tests, typing, core, pyspark, sqlframe]"
+          uv pip install -e ".[typing, core, pyspark]"
       - name: show-deps
         run: |
           source .venv/bin/activate

CONTRIBUTING.md

Lines changed: 20 additions & 0 deletions
@@ -207,6 +207,26 @@ then their tests will run too.
 
 We can't currently test in CI against cuDF, but you can test it manually in Kaggle using GPUs. Please follow this [Kaggle notebook](https://www.kaggle.com/code/marcogorelli/testing-cudf-in-narwhals) to run the tests.
 
+### Static typing
+
+We run both `mypy` and `pyright` in CI. To run them locally, make sure to install
+
+```terminal
+uv pip install -U -e ".[typing]"
+```
+
+You can then run
+- `mypy narwhals tests`
+- `pyright narwhals tests`
+
+to verify type completeness / correctness.
+
+Note that:
+- In `_pandas_like`, we type all native objects as if they are pandas ones, though
+  in reality this package is shared between pandas, Modin, and cuDF.
+- In `_spark_like`, we type all native objects as if they are SQLFrame ones, though
+  in reality this package is shared between SQLFrame and PySpark.
+
 ### 8. Writing the doc(strings)
 
 If you are adding a new feature or changing an existing one, you should also update the documentation and the docstrings
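
The `_spark_like` convention added above can be pictured with a small, self-contained sketch (not taken from the repository; the function name is hypothetical): annotations refer to SQLFrame types, which only type checkers ever import, while the runtime object may just as well be a PySpark `DataFrame`.

```python
# Hypothetical illustration of the typing convention described in CONTRIBUTING.md:
# annotate with SQLFrame types, run against either SQLFrame or PySpark.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only evaluated by mypy/pyright, so PySpark-only users never import sqlframe at runtime.
    from sqlframe.base.dataframe import BaseDataFrame

    SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]


def column_names(native_df: SQLFrameDataFrame) -> list[str]:
    # `.columns` exists on both SQLFrame and PySpark DataFrames, so the body is
    # backend-agnostic even though the annotation mentions only SQLFrame.
    return list(native_df.columns)
```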

narwhals/_spark_like/dataframe.py

Lines changed: 51 additions & 75 deletions
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
 import warnings
-from importlib import import_module
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Iterator
 from typing import Literal
 from typing import Sequence
-from typing import cast
 
 from narwhals._spark_like.utils import evaluate_exprs
+from narwhals._spark_like.utils import import_functions
+from narwhals._spark_like.utils import import_native_dtypes
+from narwhals._spark_like.utils import import_window
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals.exceptions import InvalidOperationError
 from narwhals.typing import CompliantDataFrame
@@ -26,11 +27,10 @@
     from types import ModuleType
 
     import pyarrow as pa
-    from pyspark.sql import Column
-    from pyspark.sql import DataFrame
-    from pyspark.sql import Window
-    from pyspark.sql.session import SparkSession
-    from sqlframe.base.dataframe import BaseDataFrame as _SQLFrameDataFrame
+    from sqlframe.base.column import Column
+    from sqlframe.base.dataframe import BaseDataFrame
+    from sqlframe.base.session import _BaseSession
+    from sqlframe.base.window import Window
     from typing_extensions import Self
     from typing_extensions import TypeAlias
 
@@ -40,8 +40,8 @@
     from narwhals.dtypes import DType
     from narwhals.utils import Version
 
-    SQLFrameDataFrame: TypeAlias = _SQLFrameDataFrame[Any, Any, Any, Any, Any]
-    _NativeDataFrame: TypeAlias = "DataFrame | SQLFrameDataFrame"
+    SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]
+    SQLFrameSession = _BaseSession[Any, Any, Any, Any, Any, Any, Any]
 
 Incomplete: TypeAlias = Any  # pragma: no cover
 """Marker for working code that fails type checking."""
@@ -50,15 +50,15 @@
 class SparkLikeLazyFrame(CompliantLazyFrame):
     def __init__(
         self: Self,
-        native_dataframe: _NativeDataFrame,
+        native_dataframe: SQLFrameDataFrame,
         *,
         backend_version: tuple[int, ...],
         version: Version,
         implementation: Implementation,
         # Unused, just for compatibility. We only validate when collecting.
         validate_column_names: bool = False,
     ) -> None:
-        self._native_frame = native_dataframe
+        self._native_frame: SQLFrameDataFrame = native_dataframe
         self._backend_version = backend_version
         self._implementation = implementation
         self._version = version
@@ -68,58 +68,38 @@ def __init__(
     @property
     def _F(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
         if TYPE_CHECKING:
-            from pyspark.sql import functions
+            from sqlframe.base import functions
 
             return functions
-        if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.base.session import _BaseSession
-
-            return import_module(
-                f"sqlframe.{_BaseSession().execution_dialect_name}.functions"
-            )
-
-        from pyspark.sql import functions
-
-        return functions
+        else:
+            return import_functions(self._implementation)
 
     @property
     def _native_dtypes(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202
         if TYPE_CHECKING:
-            from pyspark.sql import types
+            from sqlframe.base import types
 
             return types
-
-        if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.base.session import _BaseSession
-
-            return import_module(
-                f"sqlframe.{_BaseSession().execution_dialect_name}.types"
-            )
-
-        from pyspark.sql import types
-
-        return types
+        else:
+            return import_native_dtypes(self._implementation)
 
     @property
     def _Window(self: Self) -> type[Window]:  # noqa: N802
-        if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.base.session import _BaseSession
-
-            _window = import_module(
-                f"sqlframe.{_BaseSession().execution_dialect_name}.window"
-            )
-            return _window.Window
-
-        from pyspark.sql import Window
+        if TYPE_CHECKING:
+            from sqlframe.base.window import Window
 
-        return Window
+            return Window
+        else:
+            return import_window(self._implementation)
 
     @property
-    def _session(self: Self) -> SparkSession:
+    def _session(self: Self) -> SQLFrameSession:
+        if TYPE_CHECKING:
+            return self._native_frame.session
         if self._implementation is Implementation.SQLFRAME:
-            return cast("SQLFrameDataFrame", self._native_frame).session
+            return self._native_frame.session
 
-        return cast("DataFrame", self._native_frame).sparkSession
+        return self._native_frame.sparkSession
 
     def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
         return self._implementation.to_native_namespace()
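The three `import_*` helpers referenced above live in `narwhals/_spark_like/utils.py`, whose diff is not part of this section. A plausible sketch of `import_functions`, reconstructed from the branches deleted here (the exact signature and error handling in the repository may differ):

```python
# Sketch only: mirrors the removed property bodies, not necessarily the shipped helper.
from importlib import import_module
from types import ModuleType

from narwhals.utils import Implementation


def import_functions(implementation: Implementation) -> ModuleType:
    if implementation is Implementation.PYSPARK:
        # PySpark keeps its column functions in one fixed module.
        from pyspark.sql import functions

        return functions
    # SQLFrame ships one functions module per execution dialect; the active
    # dialect is looked up from the base session.
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.functions")
```

`import_native_dtypes` and `import_window` would presumably follow the same dispatch, targeting the dialect's `types` and `window` modules.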
@@ -144,7 +124,7 @@ def _change_version(self: Self, version: Version) -> Self:
             implementation=self._implementation,
         )
 
-    def _from_native_frame(self: Self, df: DataFrame) -> Self:
+    def _from_native_frame(self: Self, df: SQLFrameDataFrame) -> Self:
         return self.__class__(
             df,
             backend_version=self._backend_version,
@@ -158,7 +138,7 @@ def _collect_to_arrow(self) -> pa.Table:
         ):
             import pyarrow as pa  # ignore-banned-import
 
-            native_frame = cast("DataFrame", self._native_frame)
+            native_frame = self._native_frame
             try:
                 return pa.Table.from_batches(native_frame._collect_as_arrow())
             except ValueError as exc:
@@ -174,13 +154,12 @@ def _collect_to_arrow(self) -> pa.Table:
                         try:
                             native_dtype = narwhals_to_native_dtype(value, self._version)
                         except Exception as exc:  # noqa: BLE001
-                            native_spark_dtype = native_frame.schema[key].dataType
+                            native_spark_dtype = native_frame.schema[key].dataType  # type: ignore[index]
                             # If we can't convert the type, just set it to `pa.null`, and warn.
                             # Avoid the warning if we're starting from PySpark's void type.
                             # We can avoid the check when we introduce `nw.Null` dtype.
-                            if not isinstance(
-                                native_spark_dtype, self._native_dtypes.NullType
-                            ):
+                            null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
+                            if not isinstance(native_spark_dtype, null_type):
                                 warnings.warn(
                                     f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
                                     stacklevel=find_stacklevel(),
@@ -192,9 +171,7 @@ def _collect_to_arrow(self) -> pa.Table:
                 else:  # pragma: no cover
                     raise
         else:
-            # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1969224309
-            to_arrow: Incomplete = self._native_frame.toArrow
-            return to_arrow()
+            return self._native_frame.toArrow()
 
     def _iter_columns(self) -> Iterator[Column]:
         for col in self.columns:
@@ -250,7 +227,7 @@ def collect(
             raise ValueError(msg)  # pragma: no cover
 
     def simple_select(self: Self, *column_names: str) -> Self:
-        return self._from_native_frame(self._native_frame.select(*column_names))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.select(*column_names))
 
     def aggregate(
         self: Self,
@@ -259,7 +236,7 @@ def aggregate(
         new_columns = evaluate_exprs(self, *exprs)
 
         new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
-        return self._from_native_frame(self._native_frame.agg(*new_columns_list))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.agg(*new_columns_list))
 
     def select(
         self: Self,
@@ -274,17 +251,17 @@ def select(
             return self._from_native_frame(spark_df)
 
         new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.select(*new_columns_list))
 
     def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
         new_columns = evaluate_exprs(self, *exprs)
-        return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))
 
     def filter(self: Self, predicate: SparkLikeExpr) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         condition = predicate._call(self)[0]
-        spark_df = self._native_frame.where(condition)  # pyright: ignore[reportArgumentType]
-        return self._from_native_frame(spark_df)  # pyright: ignore[reportArgumentType]
+        spark_df = self._native_frame.where(condition)
+        return self._from_native_frame(spark_df)
 
     @property
     def schema(self: Self) -> dict[str, DType]:
@@ -293,8 +270,7 @@ def schema(self: Self) -> dict[str, DType]:
             field.name: native_to_narwhals_dtype(
                 dtype=field.dataType,
                 version=self._version,
-                # NOTE: Unclear if this is an unsafe hash (https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1970074662)
-                spark_types=self._native_dtypes,  # pyright: ignore[reportArgumentType]
+                spark_types=self._native_dtypes,
             )
             for field in self._native_frame.schema
         }
@@ -307,10 +283,10 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
         )
-        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))
 
     def head(self: Self, n: int) -> Self:
-        return self._from_native_frame(self._native_frame.limit(num=n))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.limit(num=n))
 
     def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
         from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
@@ -340,18 +316,18 @@ def sort(
         )
 
         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
-        return self._from_native_frame(self._native_frame.sort(*sort_cols))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.sort(*sort_cols))
 
     def drop_nulls(self: Self, subset: list[str] | None) -> Self:
-        return self._from_native_frame(self._native_frame.dropna(subset=subset))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.dropna(subset=subset))
 
     def rename(self: Self, mapping: dict[str, str]) -> Self:
         rename_mapping = {
             colname: mapping.get(colname, colname) for colname in self.columns
         }
         return self._from_native_frame(
             self._native_frame.select(
-                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]  # pyright: ignore[reportArgumentType]
+                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
             )
         )
 
@@ -365,7 +341,7 @@ def unique(
             msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
             raise ValueError(msg)
         check_column_exists(self.columns, subset)
-        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
 
     def join(
         self: Self,
@@ -409,7 +385,7 @@ def join(
             ]
         )
         return self._from_native_frame(
-            self_native.join(other_native, on=left_on, how=how).select(col_order)  # pyright: ignore[reportArgumentType]
+            self_native.join(other_native, on=left_on, how=how).select(col_order)
         )
 
     def explode(self: Self, columns: list[str]) -> Self:
@@ -445,7 +421,7 @@ def explode(self: Self, columns: list[str]) -> Self:
                         else self._F.explode_outer(col_name).alias(col_name)
                         for col_name in column_names
                     ]
-                ),  # pyright: ignore[reportArgumentType]
+                )
             )
         elif self._implementation.is_sqlframe():
             # Not every sqlframe dialect supports `explode_outer` function
@@ -466,14 +442,14 @@ def null_condition(col_name: str) -> Column:
                         for col_name in column_names
                     ]
                 ).union(
-                    native_frame.filter(null_condition(columns[0])).select(  # pyright: ignore[reportArgumentType]
+                    native_frame.filter(null_condition(columns[0])).select(
                         *[
                             self._F.col(col_name).alias(col_name)
                             if col_name != columns[0]
                             else self._F.lit(None).alias(col_name)
                             for col_name in column_names
                         ]
-                    )  # pyright: ignore[reportArgumentType]
+                    )
                 ),
             )
         else:  # pragma: no cover
@@ -508,4 +484,4 @@ def unpivot(
         )
         if index is None:
             unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
-        return self._from_native_frame(unpivoted_native_frame)  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(unpivoted_native_frame)
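
Although every annotation in this file now refers to SQLFrame, the runtime behaviour is unchanged: the backend is still selected via `Implementation`, and only the types that checkers see have moved. A hypothetical smoke test through the public API, assuming `narwhals` and `sqlframe` (with DuckDB available) are installed:

```python
# Not part of the commit: a quick check that a SQLFrame-backed frame still
# round-trips through narwhals after this change.
import narwhals as nw
from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()  # SQLFrame session executing against DuckDB
native = session.createDataFrame([(1, "a"), (2, "b")], schema=["id", "label"])

lf = nw.from_native(native)  # narwhals LazyFrame backed by SparkLikeLazyFrame
result = lf.select(nw.col("id") + 1).to_native()
result.show()
```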
