@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
 
@@ -37,6 +36,7 @@
     cast,
     overload,
     Iterable,
+    Mapping,
     TYPE_CHECKING,
     ClassVar,
 )
@@ -407,7 +407,10 @@ def clearProgressHandlers(self) -> None:
     clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__
 
     def _inferSchemaFromList(
-        self, data: Iterable[Any], names: Optional[List[str]] = None
+        self,
+        data: Iterable[Any],
+        names: Optional[List[str]],
+        configs: Mapping[str, Optional[str]],
     ) -> StructType:
         """
         Infer schema from list of Row, dict, or tuple.
@@ -422,12 +425,12 @@ def _inferSchemaFromList(
             infer_dict_as_struct,
             infer_array_from_first_element,
             infer_map_from_first_pair,
-            prefer_timestamp_ntz,
-        ) = self._client.get_configs(
-            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
-            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
-            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
-            "spark.sql.timestampType",
+            prefer_timestamp,
+        ) = (
+            configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
+            configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
+            configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
+            configs["spark.sql.timestampType"],
         )
         return functools.reduce(
             _merge_type,
@@ -438,7 +441,7 @@ def _inferSchemaFromList(
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     infer_array_from_first_element=(infer_array_from_first_element == "true"),
                     infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
-                    prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
+                    prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
                 )
                 for row in data
             ),
@@ -508,8 +511,21 @@ def createDataFrame(
                 messageParameters={},
             )
 
+        # Get all related configs in a batch
+        configs = self._client.get_config_dict(
+            "spark.sql.timestampType",
+            "spark.sql.session.timeZone",
+            "spark.sql.session.localRelationCacheThreshold",
+            "spark.sql.execution.pandas.convertToArrowArraySafely",
+            "spark.sql.execution.pandas.inferPandasDictAsMap",
+            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
+            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+        )
+        timezone = configs["spark.sql.session.timeZone"]
+        prefer_timestamp = configs["spark.sql.timestampType"]
+
         _table: Optional[pa.Table] = None
-        timezone: Optional[str] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -519,8 +535,7 @@ def createDataFrame(
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
                 infer_pandas_dict_as_map = (
-                    str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
-                    == "true"
+                    configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
                 )
                 if infer_pandas_dict_as_map:
                     struct = StructType()
@@ -572,9 +587,7 @@ def createDataFrame(
             ]
             arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
 
-            timezone, safecheck = self._client.get_configs(
-                "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
-            )
+            safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
 
@@ -596,10 +609,6 @@ def createDataFrame(
             ).cast(arrow_schema)
 
         elif isinstance(data, pa.Table):
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
-            (timezone,) = self._client.get_configs("spark.sql.session.timeZone")
-
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = data.column_names
@@ -609,7 +618,9 @@ def createDataFrame(
                 _num_cols = len(_cols)
 
             if not isinstance(schema, StructType):
-                schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
+                schema = from_arrow_schema(
+                    data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ"
+                )
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, timezone)
@@ -671,7 +682,7 @@ def createDataFrame(
                 if not isinstance(_schema, StructType):
                     _schema = StructType().add("value", _schema)
             else:
-                _schema = self._inferSchemaFromList(_data, _cols, configs)
+                _schema = self._inferSchemaFromList(_data, _cols, configs)
 
             if _cols is not None and cast(int, _num_cols) < len(_cols):
                 _num_cols = len(_cols)
@@ -706,9 +717,9 @@ def createDataFrame(
         else:
             local_relation = LocalRelation(_table)
 
-        cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+        cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
         plan: LogicalPlan = local_relation
-        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+        if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
             plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
         df = DataFrame(plan, self)
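
The gist of the change: instead of one `get_configs` round trip per call site (plus the separate `self.conf.get(...)` and `is_timestamp_ntz_preferred()` lookups), `createDataFrame` now fetches every config it might need in a single `get_config_dict` call and threads the resulting mapping into `_inferSchemaFromList`. Below is a minimal sketch of that pattern, with a hypothetical `FakeClient` standing in for `SparkConnectClient`; the config key names are the real ones from the diff, but the canned values are made up for illustration:

from typing import Mapping, Optional


class FakeClient:
    """Stand-in for SparkConnectClient with canned config values."""

    _store = {
        "spark.sql.timestampType": "TIMESTAMP_LTZ",
        "spark.sql.session.timeZone": "UTC",
        "spark.sql.session.localRelationCacheThreshold": "67108864",
    }

    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
        # A real client would resolve all keys in one config RPC;
        # here we read them from the local dict in a single pass.
        return {key: self._store.get(key) for key in keys}


configs = FakeClient().get_config_dict(
    "spark.sql.timestampType",
    "spark.sql.session.timeZone",
    "spark.sql.session.localRelationCacheThreshold",
)
# Downstream code indexes the mapping instead of issuing more round trips.
prefer_timestamp_ntz = configs["spark.sql.timestampType"] == "TIMESTAMP_NTZ"
cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
print(prefer_timestamp_ntz, cache_threshold)  # False 67108864

Passing the mapping into `_inferSchemaFromList` (rather than having it fetch configs itself) keeps schema inference free of network calls, which matters because it runs per `createDataFrame` invocation.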