
Commit a43aca0

chore: Finalize support for SQLFrame (#2038)
* WIP
* WIP
* ruff
* getting there
* duplicate column names check for pyspark
* xfail unary with 2 elements
* mypy, missing deps
* rm direct pyspark import
* ignore numpy deprecation warning
* mypy
* rm comment regarding the reason for not enabling sqlframe
1 parent f34ec91 commit a43aca0
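For orientation, here is a minimal end-to-end sketch of what finalized SQLFrame support means for users. The `DuckDBSession` import and the PySpark-style `createDataFrame` call are assumptions based on sqlframe's PySpark-compatible API; the narwhals side follows its usual `from_native` entry point and is not shown in this diff:

```python
# Hedged sketch, not part of this commit: run a narwhals expression against a
# sqlframe DataFrame (DuckDB dialect assumed).
import narwhals as nw
from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()
native_df = session.createDataFrame([(1, "a"), (2, "b")], schema=["x", "y"])

lf = nw.from_native(native_df)  # sqlframe frames are handled as lazy frames
result = lf.with_columns((nw.col("x") + 1).alias("x_plus_one"))
result.to_native().show()
```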

16 files changed (+230 −76 lines changed)


.github/workflows/pytest.yml

Lines changed: 3 additions & 3 deletions
@@ -53,12 +53,12 @@ jobs:
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
         # we are not testing pyspark on Windows here because it is very slow
-        run: uv pip install -e ".[tests, core, extra, dask, modin]" --system
+        run: uv pip install -e ".[tests, core, extra, dask, modin, sqlframe]" --system
       - name: show-deps
         run: uv pip freeze
       - name: Run pytest
         run: |
-          pytest tests --cov=narwhals --cov=tests --runslow --cov-fail-under=95 --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,modin[pyarrow],polars[eager],polars[lazy],dask,duckdb
+          pytest tests --cov=narwhals --cov=tests --runslow --cov-fail-under=95 --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,modin[pyarrow],polars[eager],polars[lazy],dask,duckdb,sqlframe

   pytest-full-coverage:
     strategy:
@@ -83,7 +83,7 @@ jobs:
           cache-suffix: ${{ matrix.python-version }}
           cache-dependency-glob: "pyproject.toml"
       - name: install-reqs
-        run: uv pip install -e ".[tests, core, extra, modin, dask]" --system
+        run: uv pip install -e ".[tests, core, extra, modin, dask, sqlframe]" --system
       - name: install pyspark
         run: uv pip install -e ".[pyspark]" --system
         # PySpark is not yet available on Python3.12+

narwhals/_spark_like/dataframe.py

Lines changed: 114 additions & 32 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import warnings
+from importlib import import_module
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Literal
@@ -13,6 +14,7 @@
 from narwhals.typing import CompliantLazyFrame
 from narwhals.utils import Implementation
 from narwhals.utils import check_column_exists
+from narwhals.utils import check_column_names_are_unique
 from narwhals.utils import find_stacklevel
 from narwhals.utils import import_dtypes_module
 from narwhals.utils import parse_columns_to_drop
@@ -23,6 +25,7 @@
     from types import ModuleType

     import pyarrow as pa
+    from pyspark.sql import Column
     from pyspark.sql import DataFrame
     from typing_extensions import Self

@@ -41,7 +44,10 @@ def __init__(
         backend_version: tuple[int, ...],
         version: Version,
         implementation: Implementation,
+        validate_column_names: bool,
     ) -> None:
+        if validate_column_names:
+            check_column_names_are_unique(native_dataframe.columns)
         self._native_frame = native_dataframe
         self._backend_version = backend_version
         self._implementation = implementation
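This is the "duplicate column names check" mentioned in the commit message: construction now optionally validates that the native frame's column names are unique. A rough sketch of the guard's intent follows; the concrete exception type raised by `check_column_names_are_unique` does not appear in this diff, so it is left unspecified:

```python
# Rough sketch only: the helper is expected to pass silently for unique names
# and raise for duplicates; the exact exception type is not shown in this diff.
from narwhals.utils import check_column_names_are_unique

check_column_names_are_unique(["a", "b", "c"])  # unique names: no error
try:
    check_column_names_are_unique(["a", "b", "a"])  # "a" appears twice
except Exception as exc:
    print(f"duplicate column names rejected: {exc}")
```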
@@ -51,33 +57,50 @@ def __init__(
     @property
     def _F(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.duckdb import functions
+            from sqlframe.base.session import _BaseSession
+
+            return import_module(
+                f"sqlframe.{_BaseSession().execution_dialect_name}.functions"
+            )

-            return functions
         from pyspark.sql import functions

         return functions

     @property
     def _native_dtypes(self: Self) -> Any:
         if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.duckdb import types
+            from sqlframe.base.session import _BaseSession
+
+            return import_module(
+                f"sqlframe.{_BaseSession().execution_dialect_name}.types"
+            )

-            return types
         from pyspark.sql import types

         return types

     @property
     def _Window(self: Self) -> Any:  # noqa: N802
         if self._implementation is Implementation.SQLFRAME:
-            from sqlframe.duckdb import Window
+            from sqlframe.base.session import _BaseSession
+
+            _window = import_module(
+                f"sqlframe.{_BaseSession().execution_dialect_name}.window"
+            )
+            return _window.Window

-            return Window
         from pyspark.sql import Window

         return Window

+    @property
+    def _session(self: Self) -> Any:
+        if self._implementation is Implementation.SQLFRAME:
+            return self._native_frame.session
+
+        return self._native_frame.sparkSession
+
     def __native_namespace__(self: Self) -> ModuleType:  # pragma: no cover
         return self._implementation.to_native_namespace()

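The three properties above drop the hard-coded `sqlframe.duckdb` imports in favour of a lookup keyed on the active session's dialect, and the new `_session` property papers over the `session` vs `sparkSession` attribute difference between sqlframe and PySpark. A standalone sketch of the dialect lookup, assuming (as the diff does) that `_BaseSession().execution_dialect_name` returns a dialect string such as "duckdb":

```python
# Sketch of the dialect-aware module resolution used by _F, _native_dtypes and
# _Window above. Assumes sqlframe exposes per-dialect modules named
# sqlframe.<dialect>.functions / .types / .window, as the diff relies on.
from importlib import import_module

from sqlframe.base.session import _BaseSession

dialect = _BaseSession().execution_dialect_name  # e.g. "duckdb"
functions = import_module(f"sqlframe.{dialect}.functions")
types = import_module(f"sqlframe.{dialect}.types")
Window = import_module(f"sqlframe.{dialect}.window").Window
```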
@@ -99,14 +122,18 @@ def _change_version(self: Self, version: Version) -> Self:
             backend_version=self._backend_version,
             version=version,
             implementation=self._implementation,
+            validate_column_names=False,
         )

-    def _from_native_frame(self: Self, df: DataFrame) -> Self:
+    def _from_native_frame(
+        self: Self, df: DataFrame, *, validate_column_names: bool = True
+    ) -> Self:
         return self.__class__(
             df,
             backend_version=self._backend_version,
             version=self._version,
             implementation=self._implementation,
+            validate_column_names=validate_column_names,
         )

     def _collect_to_arrow(self) -> pa.Table:
@@ -205,7 +232,9 @@ def collect(
             raise ValueError(msg)  # pragma: no cover

     def simple_select(self: Self, *column_names: str) -> Self:
-        return self._from_native_frame(self._native_frame.select(*column_names))
+        return self._from_native_frame(
+            self._native_frame.select(*column_names), validate_column_names=False
+        )

     def aggregate(
         self: Self,
@@ -214,7 +243,9 @@ def aggregate(
         new_columns = parse_exprs(self, *exprs)

         new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
-        return self._from_native_frame(self._native_frame.agg(*new_columns_list))
+        return self._from_native_frame(
+            self._native_frame.agg(*new_columns_list), validate_column_names=False
+        )

     def select(
         self: Self,
@@ -224,17 +255,18 @@ def select(

         if not new_columns:
             # return empty dataframe, like Polars does
-            spark_session = self._native_frame.sparkSession
-            spark_df = spark_session.createDataFrame(
+            spark_df = self._session.createDataFrame(
                 [], self._native_dtypes.StructType([])
             )

-            return self._from_native_frame(spark_df)
+            return self._from_native_frame(spark_df, validate_column_names=False)

         new_columns_list = [
             col.alias(col_name) for (col_name, col) in new_columns.items()
         ]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))
+        return self._from_native_frame(
+            self._native_frame.select(*new_columns_list), validate_column_names=False
+        )

     def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
         new_columns = parse_exprs(self, *exprs)
@@ -244,7 +276,7 @@ def filter(self: Self, predicate: SparkLikeExpr) -> Self:
         # `[0]` is safe as the predicate's expression only returns a single column
         condition = predicate._call(self)[0]
         spark_df = self._native_frame.where(condition)
-        return self._from_native_frame(spark_df)
+        return self._from_native_frame(spark_df, validate_column_names=False)

     @property
     def schema(self: Self) -> dict[str, DType]:
@@ -264,13 +296,13 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
         )
-        return self._from_native_frame(self._native_frame.drop(*columns_to_drop))
+        return self._from_native_frame(
+            self._native_frame.drop(*columns_to_drop), validate_column_names=False
+        )

     def head(self: Self, n: int) -> Self:
-        spark_session = self._native_frame.sparkSession
-
         return self._from_native_frame(
-            spark_session.createDataFrame(self._native_frame.take(num=n))
+            self._native_frame.limit(num=n), validate_column_names=False
         )

     def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
@@ -301,10 +333,14 @@ def sort(
         )

         sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
-        return self._from_native_frame(self._native_frame.sort(*sort_cols))
+        return self._from_native_frame(
+            self._native_frame.sort(*sort_cols), validate_column_names=False
+        )

     def drop_nulls(self: Self, subset: list[str] | None) -> Self:
-        return self._from_native_frame(self._native_frame.dropna(subset=subset))
+        return self._from_native_frame(
+            self._native_frame.dropna(subset=subset), validate_column_names=False
+        )

     def rename(self: Self, mapping: dict[str, str]) -> Self:
         rename_mapping = {
326362
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
327363
raise ValueError(msg)
328364
check_column_exists(self.columns, subset)
329-
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
365+
return self._from_native_frame(
366+
self._native_frame.dropDuplicates(subset=subset), validate_column_names=False
367+
)
330368

331369
def join(
332370
self: Self,
@@ -357,7 +395,7 @@ def join(
                 for colname in list(set(right_columns).difference(set(right_on or [])))
             },
         }
-        other = other_native.select(
+        other_native = other_native.select(
             [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
         )

@@ -375,7 +413,7 @@ def join(
             ]
         )
         return self._from_native_frame(
-            self_native.join(other, on=left_on, how=how).select(col_order)
+            self_native.join(other_native, on=left_on, how=how).select(col_order)
         )

     def explode(self: Self, columns: list[str]) -> Self:
@@ -402,16 +440,51 @@ def explode(self: Self, columns: list[str]) -> Self:
             )
             raise NotImplementedError(msg)

-        return self._from_native_frame(
-            native_frame.select(
-                *[
-                    self._F.col(col_name).alias(col_name)
-                    if col_name != columns[0]
-                    else self._F.explode_outer(col_name).alias(col_name)
-                    for col_name in column_names
-                ]
+        if self._implementation.is_pyspark():
+            return self._from_native_frame(
+                native_frame.select(
+                    *[
+                        self._F.col(col_name).alias(col_name)
+                        if col_name != columns[0]
+                        else self._F.explode_outer(col_name).alias(col_name)
+                        for col_name in column_names
+                    ]
+                ),
+                validate_column_names=False,
             )
-        )
+        elif self._implementation.is_sqlframe():
+            # Not every sqlframe dialect supports `explode_outer` function
+            # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289)
+            # therefore we simply explode the array column which will ignore nulls and
+            # zero sized arrays, and append these specific condition with nulls (to
+            # match polars behavior).
+
+            def null_condition(col_name: str) -> Column:
+                return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)
+
+            return self._from_native_frame(
+                native_frame.select(
+                    *[
+                        self._F.col(col_name).alias(col_name)
+                        if col_name != columns[0]
+                        else self._F.explode(col_name).alias(col_name)
+                        for col_name in column_names
+                    ]
+                ).union(
+                    native_frame.filter(null_condition(columns[0])).select(
+                        *[
+                            self._F.col(col_name).alias(col_name)
+                            if col_name != columns[0]
+                            else self._F.lit(None).alias(col_name)
+                            for col_name in column_names
+                        ]
+                    )
+                ),
+                validate_column_names=False,
+            )
+        else:  # pragma: no cover
+            msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues"
+            raise AssertionError(msg)

     def unpivot(
         self: Self,
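The sqlframe branch above works around dialects that lack `explode_outer`: plain `explode` drops rows whose array is null or empty, so those rows are filtered out separately and unioned back with a null value to match Polars' explode semantics. A simplified sketch of the same idea, assuming a sqlframe DataFrame `df` with an id column "idx" and an array column "vals" is already at hand:

```python
# Simplified sketch of the fallback; `df` is assumed to exist (see lead-in).
from sqlframe.duckdb import functions as F

# explode() keeps only rows whose array is non-null and non-empty.
exploded = df.select("idx", F.explode("vals").alias("vals"))

# Rows that explode() dropped: null arrays and zero-length arrays.
null_or_empty = F.isnull("vals") | (F.array_size("vals") == 0)
kept_as_null = df.filter(null_or_empty).select("idx", F.lit(None).alias("vals"))

# Union the two parts to keep those rows with a null value, as Polars does.
result = exploded.union(kept_as_null)
```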
@@ -420,6 +493,15 @@ def unpivot(
         variable_name: str,
         value_name: str,
     ) -> Self:
+        if self._implementation.is_sqlframe():
+            if variable_name == "":
+                msg = "`variable_name` cannot be empty string for sqlframe backend."
+                raise NotImplementedError(msg)
+
+            if value_name == "":
+                msg = "`value_name` cannot be empty string for sqlframe backend."
+                raise NotImplementedError(msg)
+
         ids = tuple(self.columns) if index is None else tuple(index)
         values = (
             tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)
