@@ -1,6 +1,7 @@
 from __future__ import annotations

 import warnings
+from importlib import import_module
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Literal
@@ -13,6 +14,7 @@
 from narwhals.typing import CompliantLazyFrame
 from narwhals.utils import Implementation
 from narwhals.utils import check_column_exists
+from narwhals.utils import check_column_names_are_unique
 from narwhals.utils import find_stacklevel
 from narwhals.utils import import_dtypes_module
 from narwhals.utils import parse_columns_to_drop
@@ -23,6 +25,7 @@
     from types import ModuleType

     import pyarrow as pa
+    from pyspark.sql import Column
     from pyspark.sql import DataFrame
     from typing_extensions import Self
@@ -41,7 +44,10 @@ def __init__(
         backend_version: tuple[int, ...],
         version: Version,
         implementation: Implementation,
+        validate_column_names: bool,
     ) -> None:
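+        # Duplicate column names would make later name-based lookups ambiguous,
+        # so fail fast on construction; internal callers that provably cannot
+        # introduce duplicates pass `validate_column_names=False` to skip this.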
+        if validate_column_names:
+            check_column_names_are_unique(native_dataframe.columns)
         self._native_frame = native_dataframe
         self._backend_version = backend_version
         self._implementation = implementation
@@ -51,33 +57,50 @@ def __init__(
     @property
     def _F(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.duckdb import functions
+            from sqlframe.base.session import _BaseSession
+
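+            # sqlframe ships one functions module per dialect (e.g.
+            # `sqlframe.duckdb.functions`), so resolve the module for the
+            # session's active dialect instead of hard-coding DuckDB.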
+            return import_module(
+                f"sqlframe.{_BaseSession().execution_dialect_name}.functions"
+            )

-            return functions
         from pyspark.sql import functions

         return functions

     @property
     def _native_dtypes(self: Self) -> Any:
         if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.duckdb import types
+            from sqlframe.base.session import _BaseSession
+
+            return import_module(
+                f"sqlframe.{_BaseSession().execution_dialect_name}.types"
+            )

-            return types
         from pyspark.sql import types

         return types

     @property
     def _Window(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.duckdb import Window
+            from sqlframe.base.session import _BaseSession
+
+            _window = import_module(
+                f"sqlframe.{_BaseSession().execution_dialect_name}.window"
+            )
+            return _window.Window

-            return Window
         from pyspark.sql import Window

         return Window

+    @property
+    def _session(self: Self) -> Any:
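+        # sqlframe exposes the owning session as `DataFrame.session`, whereas
+        # PySpark exposes it as `DataFrame.sparkSession`.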
+        if self._implementation is Implementation.SQLFRAME:
+            return self._native_frame.session
+
+        return self._native_frame.sparkSession
+
     def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
         return self._implementation.to_native_namespace()

@@ -99,14 +122,18 @@ def _change_version(self: Self, version: Version) -> Self:
             backend_version=self._backend_version,
             version=version,
             implementation=self._implementation,
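+            # Column names were already validated when `self` was constructed.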
+            validate_column_names=False,
         )

-    def _from_native_frame(self: Self, df: DataFrame) -> Self:
+    def _from_native_frame(
+        self: Self, df: DataFrame, *, validate_column_names: bool = True
+    ) -> Self:
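+        # Validate by default: operations such as `rename` and `join` can
+        # introduce duplicate column names. Call sites that provably cannot do
+        # so opt out via `validate_column_names=False`.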
         return self.__class__(
             df,
             backend_version=self._backend_version,
             version=self._version,
             implementation=self._implementation,
+            validate_column_names=validate_column_names,
         )

     def _collect_to_arrow(self) -> pa.Table:
@@ -205,7 +232,9 @@ def collect(
         raise ValueError(msg)  # pragma: no cover

     def simple_select(self: Self, *column_names: str) -> Self:
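+        # Selecting a subset of already-validated, distinct column names cannot
+        # introduce duplicates, so re-validation is skipped.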
-        return self._from_native_frame(self._native_frame.select(*column_names))
+        return self._from_native_frame(
+            self._native_frame.select(*column_names), validate_column_names=False
+        )

     def aggregate(
         self: Self,
@@ -214,7 +243,9 @@ def aggregate(
         new_columns = parse_exprs(self, *exprs)

         new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
-        return self._from_native_frame(self._native_frame.agg(*new_columns_list))
+        return self._from_native_frame(
+            self._native_frame.agg(*new_columns_list), validate_column_names=False
+        )

     def select(
         self: Self,
@@ -224,17 +255,18 @@ def select(

         if not new_columns:
             # return empty dataframe, like Polars does
-            spark_session = self._native_frame.sparkSession
-            spark_df = spark_session.createDataFrame(
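+            # `self._session` resolves the session object for both PySpark and
+            # sqlframe (see the `_session` property above).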
+            spark_df = self._session.createDataFrame(
                 [], self._native_dtypes.StructType([])
             )

-            return self._from_native_frame(spark_df)
+            return self._from_native_frame(spark_df, validate_column_names=False)

         new_columns_list = [
             col.alias(col_name) for (col_name, col) in new_columns.items()
         ]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))
+        return self._from_native_frame(
+            self._native_frame.select(*new_columns_list), validate_column_names=False
+        )

     def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
         new_columns = parse_exprs(self, *exprs)
@@ -244,7 +276,7 @@ def filter(self: Self, predicate: SparkLikeExpr) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         condition = predicate._call(self)[0]
         spark_df = self._native_frame.where(condition)
-        return self._from_native_frame(spark_df)
+        return self._from_native_frame(spark_df, validate_column_names=False)

     @property
     def schema(self: Self) -> dict[str, DType]:
@@ -264,13 +296,13 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
         )
-        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))
+        return self._from_native_frame(
+            self._native_frame.drop(*columns_to_drop), validate_column_names=False
+        )

     def head(self: Self, n: int) -> Self:
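+        # `DataFrame.limit` stays in the lazy plan; the previous implementation
+        # eagerly collected `n` rows to the driver via `take`.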
-        spark_session = self._native_frame.sparkSession
-
         return self._from_native_frame(
-            spark_session.createDataFrame(self._native_frame.take(num=n))
+            self._native_frame.limit(num=n), validate_column_names=False
         )

     def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
@@ -301,10 +333,14 @@ def sort(
         )

         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
-        return self._from_native_frame(self._native_frame.sort(*sort_cols))
+        return self._from_native_frame(
+            self._native_frame.sort(*sort_cols), validate_column_names=False
+        )

     def drop_nulls(self: Self, subset: list[str] | None) -> Self:
-        return self._from_native_frame(self._native_frame.dropna(subset=subset))
+        return self._from_native_frame(
+            self._native_frame.dropna(subset=subset), validate_column_names=False
+        )
308344
309345 def rename (self : Self , mapping : dict [str , str ]) -> Self :
310346 rename_mapping = {
@@ -326,7 +362,9 @@ def unique(
             msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
             raise ValueError(msg)
         check_column_exists(self.columns, subset)
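+        # `dropDuplicates` keeps an arbitrary row per duplicate group, which is
+        # exactly the `keep="any"` contract.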
-        return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
+        return self._from_native_frame(
+            self._native_frame.dropDuplicates(subset=subset), validate_column_names=False
+        )

     def join(
         self: Self,
@@ -357,7 +395,7 @@ def join(
                 for colname in list(set(right_columns).difference(set(right_on or [])))
             },
         }
-        other = other_native.select(
+        other_native = other_native.select(
             [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
         )

@@ -375,7 +413,7 @@ def join(
             ]
         )
         return self._from_native_frame(
-            self_native.join(other, on=left_on, how=how).select(col_order)
+            self_native.join(other_native, on=left_on, how=how).select(col_order)
         )

     def explode(self: Self, columns: list[str]) -> Self:
@@ -402,16 +440,51 @@ def explode(self: Self, columns: list[str]) -> Self:
             )
             raise NotImplementedError(msg)

-        return self._from_native_frame(
-            native_frame.select(
-                *[
-                    self._F.col(col_name).alias(col_name)
-                    if col_name != columns[0]
-                    else self._F.explode_outer(col_name).alias(col_name)
-                    for col_name in column_names
-                ]
+        if self._implementation.is_pyspark():
+            return self._from_native_frame(
+                native_frame.select(
+                    *[
+                        self._F.col(col_name).alias(col_name)
+                        if col_name != columns[0]
+                        else self._F.explode_outer(col_name).alias(col_name)
+                        for col_name in column_names
+                    ]
+                ),
+                validate_column_names=False,
             )
-        )
+        elif self._implementation.is_sqlframe():
+            # Not every sqlframe dialect supports the `explode_outer` function
+            # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289),
+            # so we simply explode the array column, which drops nulls and
+            # zero-sized arrays, and then append rows for those cases as nulls
+            # (to match Polars' behaviour).
+
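+            # For example, exploding [[1, 2]], [] and NULL should produce the
+            # rows 1, 2, NULL, NULL: the plain `explode` yields only 1 and 2,
+            # and the union below restores one NULL row per empty/NULL array.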
+            def null_condition(col_name: str) -> Column:
+                return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)
+
+            return self._from_native_frame(
+                native_frame.select(
+                    *[
+                        self._F.col(col_name).alias(col_name)
+                        if col_name != columns[0]
+                        else self._F.explode(col_name).alias(col_name)
+                        for col_name in column_names
+                    ]
+                ).union(
+                    native_frame.filter(null_condition(columns[0])).select(
+                        *[
+                            self._F.col(col_name).alias(col_name)
+                            if col_name != columns[0]
+                            else self._F.lit(None).alias(col_name)
+                            for col_name in column_names
+                        ]
+                    )
+                ),
+                validate_column_names=False,
+            )
+        else:  # pragma: no cover
+            msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues"
+            raise AssertionError(msg)

     def unpivot(
         self: Self,
@@ -420,6 +493,15 @@ def unpivot(
         variable_name: str,
         value_name: str,
     ) -> Self:
+        if self._implementation.is_sqlframe():
+            if variable_name == "":
+                msg = "`variable_name` cannot be empty string for sqlframe backend."
+                raise NotImplementedError(msg)
+
+            if value_name == "":
+                msg = "`value_name` cannot be empty string for sqlframe backend."
+                raise NotImplementedError(msg)
+
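+        # (sqlframe compiles `unpivot` to SQL, where an empty column alias
+        # would not be a valid identifier.)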
         ids = tuple(self.columns) if index is None else tuple(index)
         values = (
             tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)