Skip to content

Commit cf82f73

Browse files
enh: reflect connection time zone in Datetime for pyspark, support convert_time_zone and replace_time_zone (#2592)
--------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5ebc246 commit cf82f73

File tree

6 files changed: +129 additions, -15 deletions

narwhals/_spark_like/dataframe.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,10 @@ def schema(self) -> dict[str, DType]:
281281
if self._cached_schema is None:
282282
self._cached_schema = {
283283
field.name: native_to_narwhals_dtype(
284-
dtype=field.dataType,
285-
version=self._version,
286-
spark_types=self._native_dtypes,
284+
field.dataType,
285+
self._version,
286+
self._native_dtypes,
287+
self.native.sparkSession,
287288
)
288289
for field in self.native.schema
289290
}

narwhals/_spark_like/expr_dt.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Sequence
44

55
from narwhals._duration import parse_interval_string
6-
from narwhals._spark_like.utils import UNITS_DICT, strptime_to_pyspark_format
6+
from narwhals._spark_like.utils import (
7+
UNITS_DICT,
8+
fetch_session_time_zone,
9+
strptime_to_pyspark_format,
10+
)
711

812
if TYPE_CHECKING:
913
from sqlframe.base.column import Column
1014

15+
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
1116
from narwhals._spark_like.expr import SparkLikeExpr
1217

1318

@@ -115,14 +120,40 @@ def _truncate(expr: Column) -> Column:
115120

116121
return self._compliant_expr._with_callable(_truncate)
117122

118-
def replace_time_zone(self, time_zone: str | None) -> SparkLikeExpr:
123+
def _no_op_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
124+
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
125+
native_series_list = self._compliant_expr(df)
126+
conn_time_zone = fetch_session_time_zone(df.native.sparkSession)
127+
if conn_time_zone != time_zone:
128+
msg = (
129+
"PySpark stores the time zone in the session, rather than in the "
130+
f"data type, so changing the timezone to anything other than {conn_time_zone} "
131+
" (the current session time zone) is not supported."
132+
)
133+
raise NotImplementedError(msg)
134+
return native_series_list
135+
136+
return self._compliant_expr.__class__(
137+
func,
138+
evaluate_output_names=self._compliant_expr._evaluate_output_names,
139+
alias_output_names=self._compliant_expr._alias_output_names,
140+
backend_version=self._compliant_expr._backend_version,
141+
version=self._compliant_expr._version,
142+
implementation=self._compliant_expr._implementation,
143+
)
144+
145+
def convert_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
146+
return self._no_op_time_zone(time_zone)
147+
148+
def replace_time_zone(
149+
self, time_zone: str | None
150+
) -> SparkLikeExpr: # pragma: no cover
119151
if time_zone is None:
120152
return self._compliant_expr._with_callable(
121153
lambda _input: _input.cast("timestamp_ntz")
122154
)
123-
else: # pragma: no cover
124-
msg = "`replace_time_zone` with non-null `time_zone` not yet implemented for spark-like"
125-
raise NotImplementedError(msg)
155+
else:
156+
return self._no_op_time_zone(time_zone)
126157

127158
def _format_iso_week_with_day(self, _input: Column) -> Column:
128159
"""Format datetime as ISO week string with day."""

narwhals/_spark_like/utils.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import lru_cache
34
from importlib import import_module
45
from typing import TYPE_CHECKING, Any, Sequence, overload
56

@@ -11,6 +12,7 @@
1112

1213
import sqlframe.base.types as sqlframe_types
1314
from sqlframe.base.column import Column
15+
from sqlframe.base.session import _BaseSession as Session
1416
from typing_extensions import TypeAlias
1517

1618
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
@@ -19,6 +21,7 @@
1921
from narwhals.utils import Version
2022

2123
_NativeDType: TypeAlias = sqlframe_types.DataType
24+
SparkSession = Session[Any, Any, Any, Any, Any, Any, Any]
2225

2326
UNITS_DICT = {
2427
"y": "year",
@@ -75,7 +78,7 @@ def __init__(self, expr: Column, partition_by: Sequence[str | Column]) -> None:
7578

7679
# NOTE: don't lru_cache this as `ModuleType` isn't hashable
7780
def native_to_narwhals_dtype( # noqa: C901, PLR0912
78-
dtype: _NativeDType, version: Version, spark_types: ModuleType
81+
dtype: _NativeDType, version: Version, spark_types: ModuleType, session: SparkSession
7982
) -> DType:
8083
dtypes = version.dtypes
8184
if TYPE_CHECKING:
@@ -105,16 +108,14 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
105108
# TODO(marco): cover this
106109
return dtypes.Datetime() # pragma: no cover
107110
if isinstance(dtype, native.TimestampType):
108-
# TODO(marco): is UTC correct, or should we be getting the connection timezone?
109-
# https://github.com/narwhals-dev/narwhals/issues/2165
110-
return dtypes.Datetime(time_zone="UTC")
111+
return dtypes.Datetime(time_zone=fetch_session_time_zone(session))
111112
if isinstance(dtype, native.DecimalType):
112113
# TODO(marco): cover this
113114
return dtypes.Decimal() # pragma: no cover
114115
if isinstance(dtype, native.ArrayType):
115116
return dtypes.List(
116117
inner=native_to_narwhals_dtype(
117-
dtype.elementType, version=version, spark_types=spark_types
118+
dtype.elementType, version, spark_types, session
118119
)
119120
)
120121
if isinstance(dtype, native.StructType):
@@ -123,7 +124,7 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
123124
dtypes.Field(
124125
name=field.name,
125126
dtype=native_to_narwhals_dtype(
126-
field.dataType, version=version, spark_types=spark_types
127+
field.dataType, version, spark_types, session
127128
),
128129
)
129130
for field in dtype
@@ -134,6 +135,16 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
134135
return dtypes.Unknown() # pragma: no cover
135136

136137

138+
@lru_cache(maxsize=4)
139+
def fetch_session_time_zone(session: SparkSession) -> str:
140+
# Timezone can't be changed in PySpark session, so this can be cached.
141+
try:
142+
return session.conf.get("spark.sql.session.timeZone") # type: ignore[attr-defined]
143+
except Exception: # noqa: BLE001
144+
# https://github.com/eakmanrq/sqlframe/issues/406
145+
return "<unknown>"
146+
147+
137148
def narwhals_to_native_dtype( # noqa: C901, PLR0912
138149
dtype: DType | type[DType], version: Version, spark_types: ModuleType
139150
) -> _NativeDType:

tests/dtypes_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,3 +482,26 @@ def test_datetime_w_tz_duckdb() -> None:
482482
result = df.collect_schema()
483483
assert result["a"] == nw.Datetime("us", "Asia/Kathmandu")
484484
assert result["b"] == nw.List(nw.List(nw.Datetime("us", "Asia/Kathmandu")))
485+
486+
487+
def test_datetime_w_tz_pyspark(constructor: Constructor) -> None: # pragma: no cover
488+
if "pyspark" not in str(constructor):
489+
pytest.skip()
490+
pytest.importorskip("pyspark")
491+
pytest.importorskip("zoneinfo")
492+
from pyspark.sql import SparkSession
493+
494+
session = SparkSession.builder.config(
495+
"spark.sql.session.timeZone", "UTC"
496+
).getOrCreate()
497+
498+
df = nw.from_native(
499+
session.createDataFrame([(datetime(2020, 1, 1, tzinfo=timezone.utc),)], ["a"])
500+
)
501+
result = df.collect_schema()
502+
assert result["a"] == nw.Datetime("us", "UTC")
503+
df = nw.from_native(
504+
session.createDataFrame([([datetime(2020, 1, 1, tzinfo=timezone.utc)],)], ["a"])
505+
)
506+
result = df.collect_schema()
507+
assert result["a"] == nw.List(nw.Datetime("us", "UTC"))

tests/expr_and_series/dt/convert_time_zone_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,27 @@ def test_convert_time_zone_to_connection_tz_duckdb() -> None:
153153
result = nw.from_native(rel).with_columns(
154154
nw.col("a").dt.convert_time_zone("Asia/Kathmandu")
155155
)
156+
157+
158+
def test_convert_time_zone_to_connection_tz_pyspark(
159+
constructor: Constructor,
160+
) -> None: # pragma: no cover
161+
if "pyspark" not in str(constructor):
162+
pytest.skip()
163+
pytest.importorskip("pyspark")
164+
pytest.importorskip("zoneinfo")
165+
from pyspark.sql import SparkSession
166+
167+
session = SparkSession.builder.config(
168+
"spark.sql.session.timeZone", "UTC"
169+
).getOrCreate()
170+
df = nw.from_native(
171+
session.createDataFrame([(datetime(2020, 1, 1, tzinfo=timezone.utc),)], ["a"])
172+
)
173+
result = nw.from_native(df).with_columns(nw.col("a").dt.convert_time_zone("UTC"))
174+
expected = {"a": [datetime(2020, 1, 1, tzinfo=timezone.utc)]}
175+
assert_equal_data(result, expected)
176+
with pytest.raises(NotImplementedError):
177+
result = nw.from_native(df).with_columns(
178+
nw.col("a").dt.convert_time_zone("Asia/Kathmandu")
179+
)

tests/expr_and_series/dt/replace_time_zone_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,27 @@ def test_replace_time_zone_to_connection_tz_duckdb() -> None:
144144
result = nw.from_native(rel).with_columns(
145145
nw.col("a").dt.replace_time_zone("Asia/Kathmandu")
146146
)
147+
148+
149+
def test_replace_time_zone_to_connection_tz_pyspark(
150+
constructor: Constructor,
151+
) -> None: # pragma: no cover
152+
if "pyspark" not in str(constructor):
153+
pytest.skip()
154+
pytest.importorskip("pyspark")
155+
pytest.importorskip("zoneinfo")
156+
from pyspark.sql import SparkSession
157+
158+
session = SparkSession.builder.config(
159+
"spark.sql.session.timeZone", "UTC"
160+
).getOrCreate()
161+
df = nw.from_native(
162+
session.createDataFrame([(datetime(2020, 1, 1, tzinfo=timezone.utc),)], ["a"])
163+
)
164+
result = nw.from_native(df).with_columns(nw.col("a").dt.replace_time_zone("UTC"))
165+
expected = {"a": [datetime(2020, 1, 1, tzinfo=timezone.utc)]}
166+
assert_equal_data(result, expected)
167+
with pytest.raises(NotImplementedError):
168+
result = nw.from_native(df).with_columns(
169+
nw.col("a").dt.replace_time_zone("Asia/Kathmandu")
170+
)

Comments (0 commit comments)