Skip to content

Commit 3c58b0e

Browse files
test: Simplify read_scan_test, spark session (#3024)
Co-authored-by: FBruzzesi <[email protected]>
1 parent 68d762a commit 3c58b0e

File tree

7 files changed: 134 additions (+134) and 213 deletions (-213).

tests/conftest.py

Lines changed: 3 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111

1212
from narwhals._utils import Implementation, generate_temporary_column_name
13-
from tests.utils import PANDAS_VERSION
13+
from tests.utils import PANDAS_VERSION, pyspark_session, sqlframe_session
1414

1515
if TYPE_CHECKING:
1616
from collections.abc import Sequence
@@ -168,35 +168,13 @@ def pyspark_lazy_constructor() -> Callable[[Data], PySparkDataFrame]: # pragma:
168168
import warnings
169169
from atexit import register
170170

171-
is_spark_connect = bool(os.environ.get("SPARK_CONNECT", None))
172-
173-
if TYPE_CHECKING:
174-
from pyspark.sql import SparkSession
175-
elif is_spark_connect:
176-
from pyspark.sql.connect.session import SparkSession
177-
else:
178-
from pyspark.sql import SparkSession
179-
180171
with warnings.catch_warnings():
181172
# The spark session seems to trigger a polars warning.
182173
# Polars is imported in the tests, but not used in the spark operations
183174
warnings.filterwarnings(
184175
"ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning
185176
)
186-
builder = cast("SparkSession.Builder", SparkSession.builder).appName("unit-tests")
187-
188-
session = (
189-
(
190-
builder.remote(f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}")
191-
if is_spark_connect
192-
else builder.master("local[1]").config("spark.ui.enabled", "false")
193-
)
194-
.config("spark.default.parallelism", "1")
195-
.config("spark.sql.shuffle.partitions", "2")
196-
# common timezone for all tests environments
197-
.config("spark.sql.session.timeZone", "UTC")
198-
.getOrCreate()
199-
)
177+
session = pyspark_session()
200178

201179
register(session.stop)
202180

@@ -216,9 +194,7 @@ def _constructor(obj: Data) -> PySparkDataFrame:
216194

217195

218196
def sqlframe_pyspark_lazy_constructor(obj: Data) -> SQLFrameDataFrame: # pragma: no cover
219-
from sqlframe.duckdb import DuckDBSession
220-
221-
session = DuckDBSession()
197+
session = sqlframe_session()
222198
return session.createDataFrame([*zip(*obj.values())], schema=[*obj.keys()])
223199

224200

tests/dtypes_test.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
import narwhals as nw
1313
from narwhals.exceptions import PerformanceWarning
14-
from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION
14+
from tests.utils import PANDAS_VERSION, POLARS_VERSION, PYARROW_VERSION, pyspark_session
1515

1616
if TYPE_CHECKING:
1717
from collections.abc import Iterable
@@ -505,15 +505,9 @@ def test_datetime_w_tz_duckdb() -> None:
505505
assert result["b"] == nw.List(nw.List(nw.Datetime("us", "Asia/Kathmandu")))
506506

507507

508-
def test_datetime_w_tz_pyspark(constructor: Constructor) -> None: # pragma: no cover
509-
if "pyspark" not in str(constructor) or "sqlframe" in str(constructor):
510-
pytest.skip()
508+
def test_datetime_w_tz_pyspark() -> None: # pragma: no cover
511509
pytest.importorskip("pyspark")
512-
from pyspark.sql import SparkSession
513-
514-
session = SparkSession.builder.config(
515-
"spark.sql.session.timeZone", "UTC"
516-
).getOrCreate()
510+
session = pyspark_session()
517511

518512
df = nw.from_native(
519513
session.createDataFrame([(datetime(2020, 1, 1, tzinfo=timezone.utc),)], ["a"])

tests/expr_and_series/dt/convert_time_zone_test.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
Constructor,
1414
assert_equal_data,
1515
is_windows,
16+
pyspark_session,
1617
)
1718

1819
if TYPE_CHECKING:
@@ -153,17 +154,10 @@ def test_convert_time_zone_to_connection_tz_duckdb() -> None:
153154
)
154155

155156

156-
def test_convert_time_zone_to_connection_tz_pyspark(
157-
constructor: Constructor,
158-
) -> None: # pragma: no cover
159-
if "pyspark" not in str(constructor) or "sqlframe" in str(constructor):
160-
pytest.skip()
157+
def test_convert_time_zone_to_connection_tz_pyspark() -> None: # pragma: no cover
161158
pytest.importorskip("pyspark")
162-
from pyspark.sql import SparkSession
163159

164-
session = SparkSession.builder.config(
165-
"spark.sql.session.timeZone", "UTC"
166-
).getOrCreate()
160+
session = pyspark_session()
167161
df = nw.from_native(
168162
session.createDataFrame([(datetime(2020, 1, 1, tzinfo=timezone.utc),)], ["a"])
169163
)

tests/expr_and_series/dt/replace_time_zone_test.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,13 @@
77
import pytest
88

99
import narwhals as nw
10-
from tests.utils import PANDAS_VERSION, Constructor, assert_equal_data, is_windows
10+
from tests.utils import (
11+
PANDAS_VERSION,
12+
Constructor,
13+
assert_equal_data,
14+
is_windows,
15+
pyspark_session,
16+
)
1117

1218
if TYPE_CHECKING:
1319
from tests.utils import ConstructorEager
@@ -136,17 +142,10 @@ def test_replace_time_zone_to_connection_tz_duckdb() -> None:
136142
)
137143

138144

139-
def test_replace_time_zone_to_connection_tz_pyspark(
140-
constructor: Constructor,
141-
) -> None: # pragma: no cover
142-
if "pyspark" not in str(constructor) or "sqlframe" in str(constructor):
143-
pytest.skip()
145+
def test_replace_time_zone_to_connection_tz_pyspark() -> None: # pragma: no cover
144146
pytest.importorskip("pyspark")
145-
from pyspark.sql import SparkSession
146147

147-
session = SparkSession.builder.config(
148-
"spark.sql.session.timeZone", "UTC"
149-
).getOrCreate()
148+
session = pyspark_session()
150149
df = nw.from_native(
151150
session.createDataFrame([(datetime(2020, 1, 1, tzinfo=timezone.utc),)], ["a"])
152151
)

0 commit comments

Comments
 (0)