
Commit ad82505

[SPARK-50324][PYTHON][CONNECT] Make createDataFrame trigger Config RPC at most once
### What changes were proposed in this pull request?

Fetch all of the configs that `createDataFrame` needs in a single batched Config request.

### Why are the changes needed?

`createDataFrame` depends on many related configs, and they are currently fetched one by one (or group by group) in different branches:

1. In some branches no Config RPC is triggered at all, e.g. https://github.com/apache/spark/blob/26330355836f5b2dad9b7bd4c72d9830c7ce6788/python/pyspark/sql/connect/session.py#L502-L509
2. In other branches multiple Config RPCs are issued for different configs, e.g. https://github.com/apache/spark/blob/26330355836f5b2dad9b7bd4c72d9830c7ce6788/python/pyspark/sql/connect/session.py#L599-L601

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48856 from zhengruifeng/lazy_config.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
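For illustration, here is a minimal, self-contained sketch of the pattern this commit adopts: fetch every needed config key in one batched lookup instead of issuing a lookup per code branch. The `ToyConfigClient` class, its `rpc_count` counter, and the hard-coded config values are hypothetical stand-ins, not the Spark Connect client API.

```python
# Hypothetical toy client that counts round trips, to show why batching
# config lookups matters. Not the real Spark Connect client.
from typing import Dict, Mapping, Optional, Tuple


class ToyConfigClient:
    def __init__(self, server_configs: Dict[str, str]) -> None:
        self._server_configs = server_configs
        self.rpc_count = 0  # each method call below simulates one Config RPC

    def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
        self.rpc_count += 1  # one round trip per call
        return tuple(self._server_configs.get(k) for k in keys)

    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
        self.rpc_count += 1  # still a single round trip for the whole batch
        return {k: self._server_configs.get(k) for k in keys}


client = ToyConfigClient(
    {"spark.sql.session.timeZone": "UTC", "spark.sql.timestampType": "TIMESTAMP_LTZ"}
)

# Before: separate calls in different branches -> up to N RPCs.
(tz,) = client.get_configs("spark.sql.session.timeZone")
(ts,) = client.get_configs("spark.sql.timestampType")
assert client.rpc_count == 2

# After: one batched call up front -> at most one RPC, looked up by key later.
client.rpc_count = 0
configs = client.get_config_dict("spark.sql.session.timeZone", "spark.sql.timestampType")
assert client.rpc_count == 1
assert configs["spark.sql.session.timeZone"] == "UTC"
```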
Parent commit: 77055b8

2 files changed: +39 -23 lines

python/pyspark/sql/connect/client/core.py

Lines changed: 5 additions & 0 deletions
@@ -43,6 +43,7 @@
     Dict,
     Set,
     NoReturn,
+    Mapping,
     cast,
     TYPE_CHECKING,
     Type,
@@ -1576,6 +1577,10 @@ def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
         configs = dict(self.config(op).pairs)
         return tuple(configs.get(key) for key in keys)
 
+    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
+        op = pb2.ConfigRequest.Operation(get=pb2.ConfigRequest.Get(keys=keys))
+        return dict(self.config(op).pairs)
+
     def get_config_with_defaults(
         self, *pairs: Tuple[str, Optional[str]]
     ) -> Tuple[Optional[str], ...]:
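A hedged usage sketch of the new helper next to the existing `get_configs`: it assumes a reachable Spark Connect server (the `sc://localhost` endpoint is a placeholder) and goes through the session's private `_client` attribute, so treat it as illustration rather than a public API recipe.

```python
# Sketch only: assumes a running Spark Connect server at the placeholder
# endpoint below, and relies on the private session attribute `_client`.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Existing form: positional tuple, order must match the keys passed in.
timezone, timestamp_type = spark._client.get_configs(
    "spark.sql.session.timeZone", "spark.sql.timestampType"
)

# New form added here: same single Config RPC, but callers look values up
# by key, so one mapping can be threaded through several helpers.
configs = spark._client.get_config_dict(
    "spark.sql.session.timeZone", "spark.sql.timestampType"
)
assert configs["spark.sql.session.timeZone"] == timezone
```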

python/pyspark/sql/connect/session.py

Lines changed: 34 additions & 23 deletions
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 from pyspark.sql.connect.utils import check_dependencies
-from pyspark.sql.utils import is_timestamp_ntz_preferred
 
 check_dependencies(__name__)
 
@@ -37,6 +36,7 @@
     cast,
     overload,
     Iterable,
+    Mapping,
     TYPE_CHECKING,
     ClassVar,
 )
@@ -407,7 +407,10 @@ def clearProgressHandlers(self) -> None:
     clearProgressHandlers.__doc__ = PySparkSession.clearProgressHandlers.__doc__
 
     def _inferSchemaFromList(
-        self, data: Iterable[Any], names: Optional[List[str]] = None
+        self,
+        data: Iterable[Any],
+        names: Optional[List[str]],
+        configs: Mapping[str, Optional[str]],
     ) -> StructType:
         """
         Infer schema from list of Row, dict, or tuple.
@@ -422,12 +425,12 @@ def _inferSchemaFromList(
             infer_dict_as_struct,
             infer_array_from_first_element,
             infer_map_from_first_pair,
-            prefer_timestamp_ntz,
-        ) = self._client.get_configs(
-            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
-            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
-            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
-            "spark.sql.timestampType",
+            prefer_timestamp,
+        ) = (
+            configs["spark.sql.pyspark.inferNestedDictAsStruct.enabled"],
+            configs["spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled"],
+            configs["spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled"],
+            configs["spark.sql.timestampType"],
         )
         return functools.reduce(
             _merge_type,
@@ -438,7 +441,7 @@ def _inferSchemaFromList(
                     infer_dict_as_struct=(infer_dict_as_struct == "true"),
                     infer_array_from_first_element=(infer_array_from_first_element == "true"),
                     infer_map_from_first_pair=(infer_map_from_first_pair == "true"),
-                    prefer_timestamp_ntz=(prefer_timestamp_ntz == "TIMESTAMP_NTZ"),
+                    prefer_timestamp_ntz=(prefer_timestamp == "TIMESTAMP_NTZ"),
                 )
                 for row in data
             ),
@@ -508,8 +511,21 @@ def createDataFrame(
                 messageParameters={},
             )
 
+        # Get all related configs in a batch
+        configs = self._client.get_config_dict(
+            "spark.sql.timestampType",
+            "spark.sql.session.timeZone",
+            "spark.sql.session.localRelationCacheThreshold",
+            "spark.sql.execution.pandas.convertToArrowArraySafely",
+            "spark.sql.execution.pandas.inferPandasDictAsMap",
+            "spark.sql.pyspark.inferNestedDictAsStruct.enabled",
+            "spark.sql.pyspark.legacy.inferArrayTypeFromFirstElement.enabled",
+            "spark.sql.pyspark.legacy.inferMapTypeFromFirstPair.enabled",
+        )
+        timezone = configs["spark.sql.session.timeZone"]
+        prefer_timestamp = configs["spark.sql.timestampType"]
+
         _table: Optional[pa.Table] = None
-        timezone: Optional[str] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -519,8 +535,7 @@ def createDataFrame(
             if schema is None:
                 _cols = [str(x) if not isinstance(x, str) else x for x in data.columns]
                 infer_pandas_dict_as_map = (
-                    str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower()
-                    == "true"
+                    configs["spark.sql.execution.pandas.inferPandasDictAsMap"] == "true"
                 )
                 if infer_pandas_dict_as_map:
                     struct = StructType()
@@ -572,9 +587,7 @@ def createDataFrame(
             ]
             arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
 
-            timezone, safecheck = self._client.get_configs(
-                "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
-            )
+            safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
 
@@ -596,10 +609,6 @@ def createDataFrame(
             ).cast(arrow_schema)
 
         elif isinstance(data, pa.Table):
-            prefer_timestamp_ntz = is_timestamp_ntz_preferred()
-
-            (timezone,) = self._client.get_configs("spark.sql.session.timeZone")
-
             # If no schema supplied by user then get the names of columns only
             if schema is None:
                 _cols = data.column_names
@@ -609,7 +618,9 @@ def createDataFrame(
                 _num_cols = len(_cols)
 
             if not isinstance(schema, StructType):
-                schema = from_arrow_schema(data.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)
+                schema = from_arrow_schema(
+                    data.schema, prefer_timestamp_ntz=prefer_timestamp == "TIMESTAMP_NTZ"
+                )
 
             _table = (
                 _check_arrow_table_timestamps_localize(data, schema, True, timezone)
@@ -671,7 +682,7 @@ def createDataFrame(
             if not isinstance(_schema, StructType):
                 _schema = StructType().add("value", _schema)
             else:
-                _schema = self._inferSchemaFromList(_data, _cols)
+                _schema = self._inferSchemaFromList(_data, _cols, configs)
 
             if _cols is not None and cast(int, _num_cols) < len(_cols):
                 _num_cols = len(_cols)
@@ -706,9 +717,9 @@ def createDataFrame(
         else:
            local_relation = LocalRelation(_table)
 
-        cache_threshold = self._client.get_configs("spark.sql.session.localRelationCacheThreshold")
+        cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
         plan: LogicalPlan = local_relation
-        if cache_threshold[0] is not None and int(cache_threshold[0]) <= _table.nbytes:
+        if cache_threshold is not None and int(cache_threshold) <= _table.nbytes:
             plan = CachedLocalRelation(self._cache_local_relation(local_relation))
 
         df = DataFrame(plan, self)
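A small consumer-side sketch of how the batched mapping is read in the diff above: Config values come back as strings (or `None`), so flags are derived by string comparison, and the dict form drops the old `cache_threshold[0]` tuple indexing. The config values and table size below are hard-coded placeholders, not values taken from a real session.

```python
# Hard-coded stand-in for the result of client.get_config_dict(...).
from typing import Mapping, Optional

configs: Mapping[str, Optional[str]] = {
    "spark.sql.timestampType": "TIMESTAMP_NTZ",
    "spark.sql.execution.pandas.convertToArrowArraySafely": "false",
    "spark.sql.session.localRelationCacheThreshold": "67108864",
}

# String-typed values are turned into flags by comparison, as in the diff.
prefer_timestamp_ntz = configs["spark.sql.timestampType"] == "TIMESTAMP_NTZ"
safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"] == "true"

# The dict lookup returns the value directly; only a None check is needed
# before converting to int (previously: cache_threshold[0]).
cache_threshold = configs["spark.sql.session.localRelationCacheThreshold"]
table_nbytes = 128 * 1024 * 1024  # hypothetical local relation size
if cache_threshold is not None and int(cache_threshold) <= table_nbytes:
    print("local relation would be cached")
```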
