
Commit 5383756

chore(typing): enable typing checks for pyspark (#2051)

Co-authored-by: Dan Redding <[email protected]>

1 parent 72dea2a commit 5383756

8 files changed: +93 -69 lines changed

.github/workflows/typing.yml

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ jobs:
         # TODO: add more dependencies/backends incrementally
         run: |
           source .venv/bin/activate
-          uv pip install -e ".[tests, typing, core]"
+          uv pip install -e ".[tests, typing, core, pyspark, sqlframe]"
       - name: show-deps
         run: |
           source .venv/bin/activate

narwhals/_spark_like/dataframe.py

Lines changed: 38 additions & 21 deletions

@@ -6,6 +6,7 @@
 from typing import Any
 from typing import Literal
 from typing import Sequence
+from typing import cast

 from narwhals._spark_like.utils import evaluate_exprs
 from narwhals._spark_like.utils import native_to_narwhals_dtype
@@ -26,19 +27,29 @@
     import pyarrow as pa
     from pyspark.sql import Column
     from pyspark.sql import DataFrame
+    from pyspark.sql import Window
+    from pyspark.sql.session import SparkSession
+    from sqlframe.base.dataframe import BaseDataFrame as _SQLFrameDataFrame
     from typing_extensions import Self
+    from typing_extensions import TypeAlias

     from narwhals._spark_like.expr import SparkLikeExpr
     from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
     from narwhals._spark_like.namespace import SparkLikeNamespace
     from narwhals.dtypes import DType
     from narwhals.utils import Version

+    SQLFrameDataFrame: TypeAlias = _SQLFrameDataFrame[Any, Any, Any, Any, Any]
+    _NativeDataFrame: TypeAlias = "DataFrame | SQLFrameDataFrame"
+
+Incomplete: TypeAlias = Any  # pragma: no cover
+"""Marker for working code that fails type checking."""
+

 class SparkLikeLazyFrame(CompliantLazyFrame):
     def __init__(
         self: Self,
-        native_dataframe: DataFrame,
+        native_dataframe: _NativeDataFrame,
         *,
         backend_version: tuple[int, ...],
         version: Version,
@@ -54,7 +65,11 @@ def __init__(
         validate_backend_version(self._implementation, self._backend_version)

     @property
-    def _F(self: Self) -> Any:  # noqa: N802
+    def _F(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
+        if TYPE_CHECKING:
+            from pyspark.sql import functions
+
+            return functions
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.base.session import _BaseSession

@@ -67,7 +82,12 @@ def _F(self: Self) -> Any:  # noqa: N802
         return functions

     @property
-    def _native_dtypes(self: Self) -> Any:
+    def _native_dtypes(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202
+        if TYPE_CHECKING:
+            from pyspark.sql import types
+
+            return types
+
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.base.session import _BaseSession

@@ -80,7 +100,7 @@ def _native_dtypes(self: Self) -> Any:
         return types

     @property
-    def _Window(self: Self) -> Any:  # noqa: N802
+    def _Window(self: Self) -> type[Window]:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
             from sqlframe.base.session import _BaseSession

@@ -94,11 +114,11 @@ def _Window(self: Self) -> Any:  # noqa: N802
         return Window

     @property
-    def _session(self: Self) -> Any:
+    def _session(self: Self) -> SparkSession:
         if self._implementation is Implementation.SQLFRAME:
-            return self._native_frame.session
+            return cast("SQLFrameDataFrame", self._native_frame).session

-        return self._native_frame.sparkSession
+        return cast("DataFrame", self._native_frame).sparkSession

     def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
         return self._implementation.to_native_namespace()
@@ -137,10 +157,9 @@ def _collect_to_arrow(self) -> pa.Table:
         ):
             import pyarrow as pa  # ignore-banned-import

+            native_frame = cast("DataFrame", self._native_frame)
             try:
-                native_pyarrow_frame = pa.Table.from_batches(
-                    self._native_frame._collect_as_arrow()
-                )
+                return pa.Table.from_batches(native_frame._collect_as_arrow())
             except ValueError as exc:
                 if "at least one RecordBatch" in str(exc):
                     # Empty dataframe
@@ -154,7 +173,7 @@ def _collect_to_arrow(self) -> pa.Table:
                         try:
                             native_dtype = narwhals_to_native_dtype(value, self._version)
                         except Exception as exc:  # noqa: BLE001
-                            native_spark_dtype = self._native_frame.schema[key].dataType
+                            native_spark_dtype = native_frame.schema[key].dataType
                             # If we can't convert the type, just set it to `pa.null`, and warn.
                             # Avoid the warning if we're starting from PySpark's void type.
                             # We can avoid the check when we introduce `nw.Null` dtype.
@@ -168,14 +187,13 @@ def _collect_to_arrow(self) -> pa.Table:
                                 schema.append((key, pa.null()))
                             else:
                                 schema.append((key, native_dtype))
-                    native_pyarrow_frame = pa.Table.from_pydict(
-                        data, schema=pa.schema(schema)
-                    )
+                    return pa.Table.from_pydict(data, schema=pa.schema(schema))
                 else:  # pragma: no cover
                     raise
         else:
-            native_pyarrow_frame = self._native_frame.toArrow()
-        return native_pyarrow_frame
+            # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1969224309
+            to_arrow: Incomplete = self._native_frame.toArrow
+            return to_arrow()

     @property
     def columns(self: Self) -> list[str]:
@@ -246,10 +264,8 @@ def select(

         if not new_columns:
             # return empty dataframe, like Polars does
-            spark_df = self._session.createDataFrame(
-                [], self._native_dtypes.StructType([])
-            )
-
+            schema = self._native_dtypes.StructType([])
+            spark_df = self._session.createDataFrame([], schema)
             return self._from_native_frame(spark_df)

         new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
@@ -272,7 +288,8 @@ def schema(self: Self) -> dict[str, DType]:
             field.name: native_to_narwhals_dtype(
                 dtype=field.dataType,
                 version=self._version,
-                spark_types=self._native_dtypes,
+                # NOTE: Unclear if this is an unsafe hash (https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1970074662)
+                spark_types=self._native_dtypes,  # pyright: ignore[reportArgumentType]
             )
             for field in self._native_frame.schema
         }
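
Note: the recurring trick in `_F` and `_native_dtypes` above is worth spelling out. At runtime the property returns whichever functions/types module the active backend needs, but an early `if TYPE_CHECKING: return ...` branch means the type checker only ever sees the PySpark module, so attribute access through the property is fully checked. A minimal sketch of the pattern, assuming a hypothetical `_functions_module` attribute that is not part of narwhals:

    from __future__ import annotations

    from importlib import import_module
    from typing import TYPE_CHECKING


    class Frame:
        """Sketch only: one property, statically typed as pyspark's module."""

        # Hypothetical attribute naming the runtime backend's functions module.
        _functions_module = "pyspark.sql.functions"

        @property
        def _F(self):  # type: ignore[no-untyped-def]  # noqa: N802
            if TYPE_CHECKING:
                # Only the type checker takes this branch, so `frame._F.col(...)`
                # is checked against the real pyspark.sql.functions signatures.
                from pyspark.sql import functions

                return functions
            # At runtime, return whichever module the backend actually uses.
            return import_module(self._functions_module)

The `type: ignore[no-untyped-def]` silences mypy's complaint about the deliberately unannotated return, while checkers still infer the module type from the TYPE_CHECKING branch.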

narwhals/_spark_like/expr.py

Lines changed: 2 additions & 2 deletions

@@ -31,7 +31,7 @@
     from narwhals.utils import Version


-class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]):
+class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]):  # type: ignore[type-var] # (#2044)
     _depth = 0  # Unused, just for compatibility with CompliantExpr

     def __init__(
@@ -301,7 +301,7 @@ def __or__(self: Self, other: SparkLikeExpr) -> Self:
         )

     def __invert__(self: Self) -> Self:
-        invert = cast("Callable[..., SparkLikeExpr]", operator.invert)
+        invert = cast("Callable[..., Column]", operator.invert)
         return self._from_call(invert, "__invert__")

     def abs(self: Self) -> Self:
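
Note: the `__invert__` fix narrows the cast to what `_from_call` actually receives — the callable is applied to native `Column` objects, so describing it as returning `SparkLikeExpr` was wrong. Roughly the shape of the pattern, where `apply_unary` is an illustrative stand-in and not the narwhals API:

    from __future__ import annotations

    import operator
    from typing import TYPE_CHECKING, Callable, cast

    if TYPE_CHECKING:
        from pyspark.sql import Column


    def apply_unary(op: Callable[..., Column], col: Column) -> Column:
        # Illustrative stand-in for `_from_call`: it hands native columns
        # to `op`, so `op` must be typed as returning a native Column.
        return op(col)


    # `operator.invert` is generic over its operand; the cast pins down the
    # column-level signature this code path relies on.
    invert = cast("Callable[..., Column]", operator.invert)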

narwhals/_spark_like/namespace.py

Lines changed: 3 additions & 2 deletions

@@ -8,6 +8,7 @@
 from typing import Iterable
 from typing import Literal
 from typing import Sequence
+from typing import cast

 from narwhals._expression_parsing import combine_alias_output_names
 from narwhals._expression_parsing import combine_evaluate_output_names
@@ -29,7 +30,7 @@
     from narwhals.utils import Version


-class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]):
+class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]):  # type: ignore[type-var] # (#2044)
     def __init__(
         self: Self,
         *,
@@ -222,7 +223,7 @@ def concat(
         *,
         how: Literal["horizontal", "vertical", "diagonal"],
     ) -> SparkLikeLazyFrame:
-        dfs: list[DataFrame] = [item._native_frame for item in items]
+        dfs = cast("Sequence[DataFrame]", [item._native_frame for item in items])
         if how == "horizontal":
             msg = (
                 "Horizontal concatenation is not supported for LazyFrame backed by "

narwhals/_spark_like/utils.py

Lines changed: 44 additions & 38 deletions

@@ -11,55 +11,59 @@
     from types import ModuleType

     import pyspark.sql.types as pyspark_types
+    import sqlframe.base.types as sqlframe_types
     from pyspark.sql import Column
+    from typing_extensions import TypeAlias

     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
     from narwhals._spark_like.expr import SparkLikeExpr
     from narwhals.dtypes import DType
     from narwhals.utils import Version

+    _NativeDType: TypeAlias = "pyspark_types.DataType | sqlframe_types.DataType"
+

 # NOTE: don't lru_cache this as `ModuleType` isn't hashable
 def native_to_narwhals_dtype(
-    dtype: pyspark_types.DataType,
-    version: Version,
-    spark_types: ModuleType,
+    dtype: _NativeDType, version: Version, spark_types: ModuleType
 ) -> DType:  # pragma: no cover
     dtypes = import_dtypes_module(version=version)
+    if TYPE_CHECKING:
+        native = pyspark_types
+    else:
+        native = spark_types

-    if isinstance(dtype, spark_types.DoubleType):
+    if isinstance(dtype, native.DoubleType):
         return dtypes.Float64()
-    if isinstance(dtype, spark_types.FloatType):
+    if isinstance(dtype, native.FloatType):
         return dtypes.Float32()
-    if isinstance(dtype, spark_types.LongType):
+    if isinstance(dtype, native.LongType):
         return dtypes.Int64()
-    if isinstance(dtype, spark_types.IntegerType):
+    if isinstance(dtype, native.IntegerType):
         return dtypes.Int32()
-    if isinstance(dtype, spark_types.ShortType):
+    if isinstance(dtype, native.ShortType):
         return dtypes.Int16()
-    if isinstance(dtype, spark_types.ByteType):
+    if isinstance(dtype, native.ByteType):
         return dtypes.Int8()
-    if isinstance(
-        dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
-    ):
+    if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
         return dtypes.String()
-    if isinstance(dtype, spark_types.BooleanType):
+    if isinstance(dtype, native.BooleanType):
         return dtypes.Boolean()
-    if isinstance(dtype, spark_types.DateType):
+    if isinstance(dtype, native.DateType):
         return dtypes.Date()
-    if isinstance(dtype, spark_types.TimestampNTZType):
+    if isinstance(dtype, native.TimestampNTZType):
         return dtypes.Datetime()
-    if isinstance(dtype, spark_types.TimestampType):
+    if isinstance(dtype, native.TimestampType):
         return dtypes.Datetime(time_zone="UTC")
-    if isinstance(dtype, spark_types.DecimalType):
+    if isinstance(dtype, native.DecimalType):
         return dtypes.Decimal()
-    if isinstance(dtype, spark_types.ArrayType):
+    if isinstance(dtype, native.ArrayType):
         return dtypes.List(
             inner=native_to_narwhals_dtype(
                 dtype.elementType, version=version, spark_types=spark_types
             )
         )
-    if isinstance(dtype, spark_types.StructType):
+    if isinstance(dtype, native.StructType):
         return dtypes.Struct(
             fields=[
                 dtypes.Field(
@@ -78,48 +82,50 @@ def narwhals_to_native_dtype(
     dtype: DType | type[DType], version: Version, spark_types: ModuleType
 ) -> pyspark_types.DataType:
     dtypes = import_dtypes_module(version)
+    if TYPE_CHECKING:
+        native = pyspark_types
+    else:
+        native = spark_types

     if isinstance_or_issubclass(dtype, dtypes.Float64):
-        return spark_types.DoubleType()
+        return native.DoubleType()
     if isinstance_or_issubclass(dtype, dtypes.Float32):
-        return spark_types.FloatType()
+        return native.FloatType()
     if isinstance_or_issubclass(dtype, dtypes.Int64):
-        return spark_types.LongType()
+        return native.LongType()
     if isinstance_or_issubclass(dtype, dtypes.Int32):
-        return spark_types.IntegerType()
+        return native.IntegerType()
     if isinstance_or_issubclass(dtype, dtypes.Int16):
-        return spark_types.ShortType()
+        return native.ShortType()
     if isinstance_or_issubclass(dtype, dtypes.Int8):
-        return spark_types.ByteType()
+        return native.ByteType()
     if isinstance_or_issubclass(dtype, dtypes.String):
-        return spark_types.StringType()
+        return native.StringType()
     if isinstance_or_issubclass(dtype, dtypes.Boolean):
-        return spark_types.BooleanType()
+        return native.BooleanType()
     if isinstance_or_issubclass(dtype, dtypes.Date):
-        return spark_types.DateType()
+        return native.DateType()
     if isinstance_or_issubclass(dtype, dtypes.Datetime):
         dt_time_zone = dtype.time_zone
         if dt_time_zone is None:
-            return spark_types.TimestampNTZType()
+            return native.TimestampNTZType()
         if dt_time_zone != "UTC":  # pragma: no cover
             msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
             raise ValueError(msg)
-        return spark_types.TimestampType()
+        return native.TimestampType()
     if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
-        return spark_types.ArrayType(
+        return native.ArrayType(
             elementType=narwhals_to_native_dtype(
-                dtype.inner, version=version, spark_types=spark_types
+                dtype.inner, version=version, spark_types=native
             )
         )
     if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
-        return spark_types.StructType(
+        return native.StructType(
             fields=[
-                spark_types.StructField(
+                native.StructField(
                     name=field.name,
                     dataType=narwhals_to_native_dtype(
-                        field.dtype,
-                        version=version,
-                        spark_types=spark_types,
+                        field.dtype, version=version, spark_types=native
                     ),
                 )
                 for field in dtype.fields
@@ -147,7 +153,7 @@ def narwhals_to_native_dtype(
 def evaluate_exprs(
     df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
 ) -> list[tuple[str, Column]]:
-    native_results: list[tuple[str, list[Column]]] = []
+    native_results: list[tuple[str, Column]] = []

     for expr in exprs:
         native_series_list = expr._call(df)
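
Note: the `if TYPE_CHECKING: native = pyspark_types / else: native = spark_types` dance gives both worlds — checkers resolve every `native.XxxType` attribute against the real `pyspark.sql.types` stubs, while at runtime `native` is whichever module the caller passed in (PySpark's or SQLFrame's, which mirror each other). A self-contained sketch, with an illustrative `to_native_bool` helper that is not part of narwhals:

    from __future__ import annotations

    from types import ModuleType
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import pyspark.sql.types as pyspark_types


    def to_native_bool(spark_types: ModuleType) -> object:
        # The checker binds `native` to the concrete pyspark module, so
        # `native.BooleanType` is verified against the real stubs ...
        if TYPE_CHECKING:
            native = pyspark_types
        else:
            # ... while at runtime we trust the caller's module, which
            # mirrors pyspark's surface (e.g. sqlframe.base.types).
            native = spark_types
        return native.BooleanType()

Calling `to_native_bool(pyspark.sql.types)` or `to_native_bool(sqlframe.base.types)` would return that module's `BooleanType` instance, with no pyspark import required unless the caller already made one.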

narwhals/dependencies.py

Lines changed: 2 additions & 2 deletions

@@ -18,11 +18,11 @@
     import polars as pl
     import pyarrow as pa
     import pyspark.sql as pyspark_sql
-    import sqlframe
     from typing_extensions import TypeGuard
     from typing_extensions import TypeIs

     from narwhals._arrow.typing import ArrowChunkedArray
+    from narwhals._spark_like.dataframe import SQLFrameDataFrame
     from narwhals.dataframe import DataFrame
     from narwhals.dataframe import LazyFrame
     from narwhals.series import Series
@@ -231,7 +231,7 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:
     )


-def is_sqlframe_dataframe(df: Any) -> TypeIs[sqlframe.base.dataframe.BaseDataFrame]:
+def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]:
     """Check whether `df` is a SQLFrame DataFrame without importing SQLFrame."""
     return bool(
         (sqlframe := get_sqlframe()) is not None
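
Note: `is_sqlframe_dataframe` keeps its lazy-import behaviour but now narrows to the module-level `SQLFrameDataFrame` alias instead of the unparameterized `BaseDataFrame`. The general `TypeIs` pattern it relies on looks roughly like this — shown here for PySpark, with an illustrative `get_pyspark_sql` helper; the real narwhals guards differ in detail:

    from __future__ import annotations

    import sys
    from typing import TYPE_CHECKING, Any

    from typing_extensions import TypeIs

    if TYPE_CHECKING:
        from pyspark.sql import DataFrame


    def get_pyspark_sql() -> Any:
        # Return the module only if the user already imported it; never
        # trigger the (expensive) import ourselves.
        return sys.modules.get("pyspark.sql")


    def is_pyspark_dataframe(df: Any) -> TypeIs[DataFrame]:
        # A True result narrows `df` to DataFrame for the type checker.
        return bool(
            (pyspark_sql := get_pyspark_sql()) is not None
            and isinstance(df, pyspark_sql.DataFrame)
        )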
