 from types import ModuleType

 import pyspark.sql.types as pyspark_types
+import sqlframe.base.types as sqlframe_types
 from pyspark.sql import Column
+from typing_extensions import TypeAlias

 from narwhals._spark_like.dataframe import SparkLikeLazyFrame
 from narwhals._spark_like.expr import SparkLikeExpr
 from narwhals.dtypes import DType
 from narwhals.utils import Version

+_NativeDType: TypeAlias = "pyspark_types.DataType | sqlframe_types.DataType"
+

 # NOTE: don't lru_cache this as `ModuleType` isn't hashable
 def native_to_narwhals_dtype(
-    dtype: pyspark_types.DataType,
-    version: Version,
-    spark_types: ModuleType,
+    dtype: _NativeDType, version: Version, spark_types: ModuleType
 ) -> DType:  # pragma: no cover
     dtypes = import_dtypes_module(version=version)
+    if TYPE_CHECKING:
+        native = pyspark_types
+    else:
+        native = spark_types

-    if isinstance(dtype, spark_types.DoubleType):
+    if isinstance(dtype, native.DoubleType):
         return dtypes.Float64()
-    if isinstance(dtype, spark_types.FloatType):
+    if isinstance(dtype, native.FloatType):
         return dtypes.Float32()
-    if isinstance(dtype, spark_types.LongType):
+    if isinstance(dtype, native.LongType):
         return dtypes.Int64()
-    if isinstance(dtype, spark_types.IntegerType):
+    if isinstance(dtype, native.IntegerType):
         return dtypes.Int32()
-    if isinstance(dtype, spark_types.ShortType):
+    if isinstance(dtype, native.ShortType):
         return dtypes.Int16()
-    if isinstance(dtype, spark_types.ByteType):
+    if isinstance(dtype, native.ByteType):
         return dtypes.Int8()
-    if isinstance(
-        dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
-    ):
+    if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
         return dtypes.String()
-    if isinstance(dtype, spark_types.BooleanType):
+    if isinstance(dtype, native.BooleanType):
         return dtypes.Boolean()
-    if isinstance(dtype, spark_types.DateType):
+    if isinstance(dtype, native.DateType):
         return dtypes.Date()
-    if isinstance(dtype, spark_types.TimestampNTZType):
+    if isinstance(dtype, native.TimestampNTZType):
         return dtypes.Datetime()
-    if isinstance(dtype, spark_types.TimestampType):
+    if isinstance(dtype, native.TimestampType):
         return dtypes.Datetime(time_zone="UTC")
-    if isinstance(dtype, spark_types.DecimalType):
+    if isinstance(dtype, native.DecimalType):
         return dtypes.Decimal()
-    if isinstance(dtype, spark_types.ArrayType):
+    if isinstance(dtype, native.ArrayType):
         return dtypes.List(
             inner=native_to_narwhals_dtype(
                 dtype.elementType, version=version, spark_types=spark_types
             )
         )
-    if isinstance(dtype, spark_types.StructType):
+    if isinstance(dtype, native.StructType):
         return dtypes.Struct(
             fields=[
                 dtypes.Field(
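The `TYPE_CHECKING` branch above deserves a note: at runtime `native` is whichever module the caller injected (`pyspark.sql.types` or `sqlframe.base.types`), while static type checkers always analyse it as `pyspark.sql.types`, so lookups like `native.DoubleType` are verified against real stubs instead of an opaque `ModuleType`. A minimal standalone sketch of the pattern, assuming a local PySpark install (`describe_dtype` is illustrative, not part of this commit):

from types import ModuleType
from typing import TYPE_CHECKING

import pyspark.sql.types as pyspark_types


def describe_dtype(dtype: object, spark_types: ModuleType) -> str:
    # The type checker sees `native = pyspark_types`; at runtime it is
    # whatever module was injected, which only needs the same attribute names.
    if TYPE_CHECKING:
        native = pyspark_types
    else:
        native = spark_types
    if isinstance(dtype, native.LongType):
        return "Int64"
    if isinstance(dtype, native.DoubleType):
        return "Float64"
    return "unknown"


print(describe_dtype(pyspark_types.LongType(), pyspark_types))  # -> Int64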
@@ -78,48 +82,50 @@ def narwhals_to_native_dtype(
     dtype: DType | type[DType], version: Version, spark_types: ModuleType
 ) -> pyspark_types.DataType:
     dtypes = import_dtypes_module(version)
+    if TYPE_CHECKING:
+        native = pyspark_types
+    else:
+        native = spark_types

     if isinstance_or_issubclass(dtype, dtypes.Float64):
-        return spark_types.DoubleType()
+        return native.DoubleType()
     if isinstance_or_issubclass(dtype, dtypes.Float32):
-        return spark_types.FloatType()
+        return native.FloatType()
     if isinstance_or_issubclass(dtype, dtypes.Int64):
-        return spark_types.LongType()
+        return native.LongType()
     if isinstance_or_issubclass(dtype, dtypes.Int32):
-        return spark_types.IntegerType()
+        return native.IntegerType()
     if isinstance_or_issubclass(dtype, dtypes.Int16):
-        return spark_types.ShortType()
+        return native.ShortType()
     if isinstance_or_issubclass(dtype, dtypes.Int8):
-        return spark_types.ByteType()
+        return native.ByteType()
     if isinstance_or_issubclass(dtype, dtypes.String):
-        return spark_types.StringType()
+        return native.StringType()
     if isinstance_or_issubclass(dtype, dtypes.Boolean):
-        return spark_types.BooleanType()
+        return native.BooleanType()
     if isinstance_or_issubclass(dtype, dtypes.Date):
-        return spark_types.DateType()
+        return native.DateType()
     if isinstance_or_issubclass(dtype, dtypes.Datetime):
         dt_time_zone = dtype.time_zone
         if dt_time_zone is None:
-            return spark_types.TimestampNTZType()
+            return native.TimestampNTZType()
         if dt_time_zone != "UTC":  # pragma: no cover
             msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
             raise ValueError(msg)
-        return spark_types.TimestampType()
+        return native.TimestampType()
     if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
-        return spark_types.ArrayType(
+        return native.ArrayType(
             elementType=narwhals_to_native_dtype(
-                dtype.inner, version=version, spark_types=spark_types
+                dtype.inner, version=version, spark_types=native
             )
         )
     if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
-        return spark_types.StructType(
+        return native.StructType(
             fields=[
-                spark_types.StructField(
+                native.StructField(
                     name=field.name,
                     dataType=narwhals_to_native_dtype(
-                        field.dtype,
-                        version=version,
-                        spark_types=spark_types,
+                        field.dtype, version=version, spark_types=native
                     ),
                 )
                 for field in dtype.fields
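Note that the recursive calls in this function now forward `spark_types=native`; at runtime `native` is the injected module itself, so behaviour is unchanged, but the nested `ArrayType`/`StructType` construction also gets checked against the PySpark stubs. For reference, a plain-PySpark sketch of the kind of nested type these branches build (independent of this commit):

import pyspark.sql.types as T

# Struct with a list-of-int64 field, as the ArrayType/StructType branches
# above would build it for a narwhals Struct({"scores": List(Int64)}).
schema = T.StructType(
    fields=[
        T.StructField(
            name="scores",
            dataType=T.ArrayType(elementType=T.LongType()),
        )
    ]
)
print(schema.simpleString())  # -> struct<scores:array<bigint>>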
@@ -147,7 +153,7 @@ def narwhals_to_native_dtype(
 def evaluate_exprs(
     df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
 ) -> list[tuple[str, Column]]:
-    native_results: list[tuple[str, list[Column]]] = []
+    native_results: list[tuple[str, Column]] = []

     for expr in exprs:
         native_series_list = expr._call(df)
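The annotation fix in `evaluate_exprs` records that each pair carries a single `Column`, not a list of them. A hedged consumption sketch, assuming PySpark; the `pairs` literal and the `select` call are illustrative, not taken from this commit:

from pyspark.sql import Column, SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 2.0)], ["a", "b"])

# Pairs shaped like evaluate_exprs' return value: one Column per output name.
pairs: list[tuple[str, Column]] = [("a_plus_one", F.col("a") + 1)]
df.select(*(col.alias(name) for name, col in pairs)).show()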