@@ -1,15 +1,16 @@
 from __future__ import annotations

 import warnings
-from importlib import import_module
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Iterator
 from typing import Literal
 from typing import Sequence
-from typing import cast

 from narwhals._spark_like.utils import evaluate_exprs
+from narwhals._spark_like.utils import import_functions
+from narwhals._spark_like.utils import import_native_dtypes
+from narwhals._spark_like.utils import import_window
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals.exceptions import InvalidOperationError
 from narwhals.typing import CompliantDataFrame
@@ -26,11 +27,10 @@
     from types import ModuleType

     import pyarrow as pa
-    from pyspark.sql import Column
-    from pyspark.sql import DataFrame
-    from pyspark.sql import Window
-    from pyspark.sql.session import SparkSession
-    from sqlframe.base.dataframe import BaseDataFrame as _SQLFrameDataFrame
+    from sqlframe.base.column import Column
+    from sqlframe.base.dataframe import BaseDataFrame
+    from sqlframe.base.session import _BaseSession
+    from sqlframe.base.window import Window
     from typing_extensions import Self
     from typing_extensions import TypeAlias
@@ -40,8 +40,8 @@
     from narwhals.dtypes import DType
     from narwhals.utils import Version

-    SQLFrameDataFrame: TypeAlias = _SQLFrameDataFrame[Any, Any, Any, Any, Any]
-    _NativeDataFrame: TypeAlias = "DataFrame | SQLFrameDataFrame"
+    SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]
+    SQLFrameSession = _BaseSession[Any, Any, Any, Any, Any, Any, Any]

 Incomplete: TypeAlias = Any  # pragma: no cover
 """Marker for working code that fails type checking."""
@@ -50,15 +50,15 @@
 class SparkLikeLazyFrame(CompliantLazyFrame):
     def __init__(
         self: Self,
-        native_dataframe: _NativeDataFrame,
+        native_dataframe: SQLFrameDataFrame,
         *,
         backend_version: tuple[int, ...],
         version: Version,
         implementation: Implementation,
         # Unused, just for compatibility. We only validate when collecting.
         validate_column_names: bool = False,
     ) -> None:
-        self._native_frame = native_dataframe
+        self._native_frame: SQLFrameDataFrame = native_dataframe
         self._backend_version = backend_version
         self._implementation = implementation
         self._version = version
@@ -68,58 +68,38 @@ def __init__(
     @property
     def _F(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
         if TYPE_CHECKING:
-            from pyspark.sql import functions
+            from sqlframe.base import functions

             return functions
-        if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.base.session import _BaseSession
-
-            return import_module(
-                f"sqlframe.{_BaseSession().execution_dialect_name}.functions"
-            )
-
-        from pyspark.sql import functions
-
-        return functions
+        else:
+            return import_functions(self._implementation)

     @property
     def _native_dtypes(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202
         if TYPE_CHECKING:
-            from pyspark.sql import types
+            from sqlframe.base import types

             return types
-
-        if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.base.session import _BaseSession
-
-            return import_module(
-                f"sqlframe.{_BaseSession().execution_dialect_name}.types"
-            )
-
-        from pyspark.sql import types
-
-        return types
+        else:
+            return import_native_dtypes(self._implementation)

     @property
     def _Window(self: Self) -> type[Window]:  # noqa: N802
-        if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.base.session import _BaseSession
-
-            _window = import_module(
-                f"sqlframe.{_BaseSession().execution_dialect_name}.window"
-            )
-            return _window.Window
-
-        from pyspark.sql import Window
+        if TYPE_CHECKING:
+            from sqlframe.base.window import Window

-        return Window
+            return Window
+        else:
+            return import_window(self._implementation)

     @property
-    def _session(self: Self) -> SparkSession:
+    def _session(self: Self) -> SQLFrameSession:
+        if TYPE_CHECKING:
+            return self._native_frame.session
         if self._implementation is Implementation.SQLFRAME:
-            return cast("SQLFrameDataFrame", self._native_frame).session
+            return self._native_frame.session

-        return cast("DataFrame", self._native_frame).sparkSession
+        return self._native_frame.sparkSession

     def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
         return self._implementation.to_native_namespace()
@@ -144,7 +124,7 @@ def _change_version(self: Self, version: Version) -> Self:
             implementation=self._implementation,
         )

-    def _from_native_frame(self: Self, df: DataFrame) -> Self:
+    def _from_native_frame(self: Self, df: SQLFrameDataFrame) -> Self:
         return self.__class__(
             df,
             backend_version=self._backend_version,
@@ -158,7 +138,7 @@ def _collect_to_arrow(self) -> pa.Table:
         ):
             import pyarrow as pa  # ignore-banned-import

-            native_frame = cast("DataFrame", self._native_frame)
+            native_frame = self._native_frame
             try:
                 return pa.Table.from_batches(native_frame._collect_as_arrow())
             except ValueError as exc:
@@ -174,13 +154,12 @@ def _collect_to_arrow(self) -> pa.Table:
                         try:
                             native_dtype = narwhals_to_native_dtype(value, self._version)
                         except Exception as exc:  # noqa: BLE001
-                            native_spark_dtype = native_frame.schema[key].dataType
+                            native_spark_dtype = native_frame.schema[key].dataType  # type: ignore[index]
                             # If we can't convert the type, just set it to `pa.null`, and warn.
                             # Avoid the warning if we're starting from PySpark's void type.
                             # We can avoid the check when we introduce `nw.Null` dtype.
-                            if not isinstance(
-                                native_spark_dtype, self._native_dtypes.NullType
-                            ):
+                            null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
+                            if not isinstance(native_spark_dtype, null_type):
                                 warnings.warn(
                                     f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
                                     stacklevel=find_stacklevel(),
@@ -192,9 +171,7 @@ def _collect_to_arrow(self) -> pa.Table:
                 else:  # pragma: no cover
                     raise
         else:
-            # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1969224309
-            to_arrow: Incomplete = self._native_frame.toArrow
-            return to_arrow()
+            return self._native_frame.toArrow()

     def _iter_columns(self) -> Iterator[Column]:
         for col in self.columns:
@@ -250,7 +227,7 @@ def collect(
         raise ValueError(msg)  # pragma: no cover

     def simple_select(self: Self, *column_names: str) -> Self:
-        return self._from_native_frame(self._native_frame.select(*column_names))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.select(*column_names))

     def aggregate(
         self: Self,
@@ -259,7 +236,7 @@ def aggregate(
         new_columns = evaluate_exprs(self, *exprs)

         new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
-        return self._from_native_frame(self._native_frame.agg(*new_columns_list))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.agg(*new_columns_list))

     def select(
         self: Self,
@@ -274,17 +251,17 @@ def select(
             return self._from_native_frame(spark_df)

         new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.select(*new_columns_list))

     def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
         new_columns = evaluate_exprs(self, *exprs)
-        return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.withColumns(dict(new_columns)))

     def filter(self: Self, predicate: SparkLikeExpr) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         condition = predicate._call(self)[0]
-        spark_df = self._native_frame.where(condition)  # pyright: ignore[reportArgumentType]
-        return self._from_native_frame(spark_df)  # pyright: ignore[reportArgumentType]
+        spark_df = self._native_frame.where(condition)
+        return self._from_native_frame(spark_df)

     @property
     def schema(self: Self) -> dict[str, DType]:
@@ -293,8 +270,7 @@ def schema(self: Self) -> dict[str, DType]:
             field.name: native_to_narwhals_dtype(
                 dtype=field.dataType,
                 version=self._version,
-                # NOTE: Unclear if this is an unsafe hash (https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1970074662)
-                spark_types=self._native_dtypes,  # pyright: ignore[reportArgumentType]
+                spark_types=self._native_dtypes,
             )
             for field in self._native_frame.schema
         }
@@ -307,10 +283,10 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
         )
-        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))

     def head(self: Self, n: int) -> Self:
-        return self._from_native_frame(self._native_frame.limit(num=n))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.limit(num=n))

     def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
         from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
@@ -340,18 +316,18 @@ def sort(
         )

         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
-        return self._from_native_frame(self._native_frame.sort(*sort_cols))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.sort(*sort_cols))

     def drop_nulls(self: Self, subset: list[str] | None) -> Self:
-        return self._from_native_frame(self._native_frame.dropna(subset=subset))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.dropna(subset=subset))

     def rename(self: Self, mapping: dict[str, str]) -> Self:
         rename_mapping = {
             colname: mapping.get(colname, colname) for colname in self.columns
         }
         return self._from_native_frame(
             self._native_frame.select(
-                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]  # pyright: ignore[reportArgumentType]
+                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
             )
         )

@@ -365,7 +341,7 @@ def unique(
             msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
             raise ValueError(msg)
         check_column_exists(self.columns, subset)
-        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))

     def join(
         self: Self,
@@ -409,7 +385,7 @@ def join(
             ]
         )
         return self._from_native_frame(
-            self_native.join(other_native, on=left_on, how=how).select(col_order)  # pyright: ignore[reportArgumentType]
+            self_native.join(other_native, on=left_on, how=how).select(col_order)
         )

     def explode(self: Self, columns: list[str]) -> Self:
@@ -445,7 +421,7 @@ def explode(self: Self, columns: list[str]) -> Self:
                         else self._F.explode_outer(col_name).alias(col_name)
                         for col_name in column_names
                     ]
-                ),  # pyright: ignore[reportArgumentType]
+                )
             )
         elif self._implementation.is_sqlframe():
             # Not every sqlframe dialect supports `explode_outer` function
@@ -466,14 +442,14 @@ def null_condition(col_name: str) -> Column:
                             for col_name in column_names
                         ]
                     ).union(
-                        native_frame.filter(null_condition(columns[0])).select(  # pyright: ignore[reportArgumentType]
+                        native_frame.filter(null_condition(columns[0])).select(
                             *[
                                 self._F.col(col_name).alias(col_name)
                                 if col_name != columns[0]
                                 else self._F.lit(None).alias(col_name)
                                 for col_name in column_names
                             ]
-                        )  # pyright: ignore[reportArgumentType]
+                        )
                     ),
                 )
             else:  # pragma: no cover
@@ -508,4 +484,4 @@ def unpivot(
         )
         if index is None:
             unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
-        return self._from_native_frame(unpivoted_native_frame)  # pyright: ignore[reportArgumentType]
+        return self._from_native_frame(unpivoted_native_frame)
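
For context, the three `import_*` helpers this diff starts importing from `narwhals._spark_like.utils` centralize the dialect dispatch that the property bodies above used to perform inline. Below is a minimal sketch of what they plausibly look like, reconstructed from the removed inline logic; the helper names and call sites come from the diff itself, but the bodies are an assumption and may differ from the real implementations.

```python
# Hedged reconstruction of the helpers added in narwhals._spark_like.utils.
# Bodies are inferred from the inline logic this refactor removed.
from __future__ import annotations

from importlib import import_module
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from types import ModuleType

    from narwhals.utils import Implementation


def _sqlframe_submodule(name: str) -> ModuleType:
    # sqlframe ships one sub-package per execution dialect (duckdb, bigquery,
    # ...); the active dialect is recorded on the base session.
    from sqlframe.base.session import _BaseSession

    return import_module(f"sqlframe.{_BaseSession().execution_dialect_name}.{name}")


def import_functions(implementation: Implementation) -> ModuleType:
    """Return the `functions` namespace for the active backend."""
    if implementation.is_sqlframe():
        return _sqlframe_submodule("functions")
    from pyspark.sql import functions

    return functions


def import_native_dtypes(implementation: Implementation) -> ModuleType:
    """Return the native `types` module for the active backend."""
    if implementation.is_sqlframe():
        return _sqlframe_submodule("types")
    from pyspark.sql import types

    return types


def import_window(implementation: Implementation) -> type[Any]:
    """Return the `Window` class for the active backend."""
    if implementation.is_sqlframe():
        return _sqlframe_submodule("window").Window
    from pyspark.sql import Window

    return Window
```

With these in place, each property reduces to a `TYPE_CHECKING` branch (so static type checkers always see the sqlframe stubs) plus a single runtime dispatch call, and the `cast`/`import_module` plumbing disappears from the frame class.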