11from __future__ import annotations
22
3+ from functools import lru_cache
34from importlib import import_module
45from typing import TYPE_CHECKING , Any , Sequence , overload
56
1112
1213 import sqlframe .base .types as sqlframe_types
1314 from sqlframe .base .column import Column
15+ from sqlframe .base .session import _BaseSession as Session
1416 from typing_extensions import TypeAlias
1517
1618 from narwhals ._spark_like .dataframe import SparkLikeLazyFrame
1921 from narwhals .utils import Version
2022
2123 _NativeDType : TypeAlias = sqlframe_types .DataType
24+ SparkSession = Session [Any , Any , Any , Any , Any , Any , Any ]
2225
2326UNITS_DICT = {
2427 "y" : "year" ,
@@ -75,7 +78,7 @@ def __init__(self, expr: Column, partition_by: Sequence[str | Column]) -> None:
7578
7679# NOTE: don't lru_cache this as `ModuleType` isn't hashable
7780def native_to_narwhals_dtype ( # noqa: C901, PLR0912
78- dtype : _NativeDType , version : Version , spark_types : ModuleType
81+ dtype : _NativeDType , version : Version , spark_types : ModuleType , session : SparkSession
7982) -> DType :
8083 dtypes = version .dtypes
8184 if TYPE_CHECKING :
@@ -105,16 +108,14 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
105108 # TODO(marco): cover this
106109 return dtypes .Datetime () # pragma: no cover
107110 if isinstance (dtype , native .TimestampType ):
108- # TODO(marco): is UTC correct, or should we be getting the connection timezone?
109- # https://github.com/narwhals-dev/narwhals/issues/2165
110- return dtypes .Datetime (time_zone = "UTC" )
111+ return dtypes .Datetime (time_zone = fetch_session_time_zone (session ))
111112 if isinstance (dtype , native .DecimalType ):
112113 # TODO(marco): cover this
113114 return dtypes .Decimal () # pragma: no cover
114115 if isinstance (dtype , native .ArrayType ):
115116 return dtypes .List (
116117 inner = native_to_narwhals_dtype (
117- dtype .elementType , version = version , spark_types = spark_types
118+ dtype .elementType , version , spark_types , session
118119 )
119120 )
120121 if isinstance (dtype , native .StructType ):
@@ -123,7 +124,7 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
123124 dtypes .Field (
124125 name = field .name ,
125126 dtype = native_to_narwhals_dtype (
126- field .dataType , version = version , spark_types = spark_types
127+ field .dataType , version , spark_types , session
127128 ),
128129 )
129130 for field in dtype
@@ -134,6 +135,16 @@ def native_to_narwhals_dtype( # noqa: C901, PLR0912
134135 return dtypes .Unknown () # pragma: no cover
135136
136137
@lru_cache(maxsize=4)
def fetch_session_time_zone(session: SparkSession) -> str:
    """Return the session-level time zone configured for *session*.

    The result is cached per session object: PySpark does not permit
    changing the time zone of an existing session, so one lookup per
    session suffices.  On backends where the config lookup fails
    (see the sqlframe issue below), the sentinel string ``"<unknown>"``
    is returned instead of raising.
    """
    try:
        tz = session.conf.get("spark.sql.session.timeZone")  # type: ignore[attr-defined]
    except Exception:  # noqa: BLE001
        # Best-effort fallback; some sqlframe sessions cannot answer this:
        # https://github.com/eakmanrq/sqlframe/issues/406
        return "<unknown>"
    return tz
146+
147+
137148def narwhals_to_native_dtype ( # noqa: C901, PLR0912
138149 dtype : DType | type [DType ], version : Version , spark_types : ModuleType
139150) -> _NativeDType :
0 commit comments