
Commit b91d407

Yicong-Huang authored and Hyukjin Kwon committed
[SPARK-55336][PYTHON] Let createDF use create_batch logic for decoupling
### What changes were proposed in this pull request?

This PR duplicates the pandas-to-Arrow batch conversion logic of `ArrowStreamPandasSerializer` into two standalone functions, decoupling the DataFrame creation paths from the serializer:

- `create_arrow_array_from_pandas()` - converts a pandas Series to an Arrow Array
- `create_arrow_batch_from_pandas()` - converts a list of (series, spark_type) tuples to an Arrow RecordBatch

Both `_create_from_pandas_with_arrow` (classic Spark) and `createDataFrame` (Spark Connect) now use these standalone functions directly with `ArrowStreamSerializer`, instead of depending on `ArrowStreamPandasSerializer`.

### Why are the changes needed?

For better decoupling. Previously, `createDataFrame` had to instantiate `ArrowStreamPandasSerializer` just to call its `_create_batch` method.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #54111 from Yicong-Huang/SPARK-55336/refactor/factor-out-create-batch-logic.

Authored-by: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 455ea6c commit b91d407
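
For context, a minimal sketch (not part of this commit) of how the new standalone helper is meant to be called, assuming a PySpark build that includes this change plus pandas and pyarrow installed; the DataFrame `pdf`, the column pairing, and the printed output are illustrative:

```python
# Hedged sketch: pair each pandas column with a Spark type (or None to let
# pyarrow infer the type) and get back a single Arrow RecordBatch.
import pandas as pd
from pyspark.sql.types import IntegerType, StringType
from pyspark.sql.pandas.conversion import create_arrow_batch_from_pandas

pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})

batch = create_arrow_batch_from_pandas(
    [(pdf["id"], IntegerType()), (pdf["name"], StringType())],
    timezone="UTC",
    safecheck=True,
)
print(batch.num_rows)  # 3
print(batch.schema)    # columns are named _0, _1 by the helper
```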

2 files changed: +136 -20 lines


python/pyspark/sql/connect/session.py (4 additions, 9 deletions)

```diff
@@ -69,7 +69,7 @@
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.connect.streaming.readwriter import DataStreamReader
 from pyspark.sql.connect.streaming.query import StreamingQueryManager
-from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
+from pyspark.sql.pandas.conversion import create_arrow_batch_from_pandas
 from pyspark.sql.pandas.types import (
     to_arrow_schema,
     _deduplicate_field_names,
@@ -621,17 +621,12 @@ def createDataFrame(

             safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]

-            ser = ArrowStreamPandasSerializer(
-                timezone=cast(str, timezone),
-                safecheck=safecheck == "true",
-                int_to_decimal_coercion_enabled=False,
-                prefers_large_types=prefers_large_types,
-            )
-
             _table = pa.Table.from_batches(
                 [
-                    ser._create_batch(
+                    create_arrow_batch_from_pandas(
                         [(c, st) for (_, c), st in zip(data.items(), spark_types)],
+                        timezone=cast(str, timezone),
+                        safecheck=safecheck == "true",
                         prefers_large_types=prefers_large_types,
                     )
                 ]
```
python/pyspark/sql/pandas/conversion.py (132 additions, 11 deletions)

```diff
@@ -18,9 +18,11 @@
 from typing import (
     Any,
     Callable,
+    Iterable,
     List,
     Optional,
     Sequence,
+    Tuple,
     Union,
     cast,
     no_type_check,
@@ -48,13 +50,132 @@

 if TYPE_CHECKING:
     import numpy as np
+    import pandas as pd
     import pyarrow as pa
     from py4j.java_gateway import JavaObject

     from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
     from pyspark.sql import DataFrame


+def create_arrow_array_from_pandas(
+    series: "pd.Series",
+    spark_type: Optional[DataType],
+    *,
+    timezone: Optional[str] = None,
+    safecheck: bool = False,
+    prefers_large_types: bool = False,
+) -> "pa.Array":
+    """
+    Create an Arrow Array from the given pandas.Series and Spark type.
+
+    Parameters
+    ----------
+    series : pandas.Series
+        A single series
+    spark_type : DataType, optional
+        The Spark return type. If None, pyarrow's inferred type will be used.
+    timezone : str, optional
+        The timezone to use for timestamp conversions.
+    safecheck : bool, optional
+        Whether to enable safe type checking during conversion.
+    prefers_large_types : bool, optional
+        Whether to prefer large Arrow types (e.g., large_string instead of string).
+
+    Returns
+    -------
+    pyarrow.Array
+    """
+    import pyarrow as pa
+    import pandas as pd
+    from pyspark.sql.pandas.types import to_arrow_type, _create_converter_from_pandas
+
+    if isinstance(series.dtype, pd.CategoricalDtype):
+        series = series.astype(series.dtype.categories.dtype)
+
+    # Derive arrow_type from spark_type
+    arrow_type = (
+        to_arrow_type(spark_type, timezone=timezone, prefers_large_types=prefers_large_types)
+        if spark_type is not None
+        else None
+    )
+
+    if spark_type is not None:
+        conv = _create_converter_from_pandas(
+            spark_type,
+            timezone=timezone,
+            error_on_duplicated_field_names=False,
+        )
+        series = conv(series)
+
+    if hasattr(series.array, "__arrow_array__"):
+        mask = None
+    else:
+        mask = series.isnull()
+    try:
+        return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=safecheck)
+    except TypeError as e:
+        error_msg = (
+            "Exception thrown when converting pandas.Series (%s) "
+            "with name '%s' to Arrow Array (%s)."
+        )
+        raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
+    except ValueError as e:
+        error_msg = (
+            "Exception thrown when converting pandas.Series (%s) "
+            "with name '%s' to Arrow Array (%s)."
+        )
+        if safecheck:
+            error_msg = error_msg + (
+                " It can be caused by overflows or other "
+                "unsafe conversions warned by Arrow. Arrow safe type check "
+                "can be disabled by using SQL config "
+                "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+            )
+        raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e
+
+
+def create_arrow_batch_from_pandas(
+    series_with_types: Iterable[Tuple["pd.Series", Optional[DataType]]],
+    *,
+    timezone: Optional[str] = None,
+    safecheck: bool = False,
+    prefers_large_types: bool = False,
+) -> "pa.RecordBatch":
+    """
+    Create an Arrow record batch from the given iterable of (series, spark_type) tuples.
+
+    Parameters
+    ----------
+    series_with_types : iterable
+        Iterable of (series, spark_type) tuples.
+    timezone : str, optional
+        The timezone to use for timestamp conversions.
+    safecheck : bool, optional
+        Whether to enable safe type checking during conversion.
+    prefers_large_types : bool, optional
+        Whether to prefer large Arrow types (e.g., large_string instead of string).
+
+    Returns
+    -------
+    pyarrow.RecordBatch
+        Arrow RecordBatch
+    """
+    import pyarrow as pa
+
+    arrs = [
+        create_arrow_array_from_pandas(
+            s,
+            spark_type,
+            timezone=timezone,
+            safecheck=safecheck,
+            prefers_large_types=prefers_large_types,
+        )
+        for s, spark_type in series_with_types
+    ]
+    return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
+
+
 def _convert_arrow_table_to_pandas(
     arrow_table: "pa.Table",
     schema: "StructType",
```
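
A hedged, illustrative example of the lower-level per-column helper on its own; the overflow value, variable names, and printed output are assumptions based on the error handling shown above, not output copied from the PR:

```python
# Hedged sketch: create_arrow_array_from_pandas converts one pandas.Series per
# Spark type; with safecheck=True, an unsafe cast (here int64 -> int32) is
# re-raised as PySparkValueError with a hint about the
# spark.sql.execution.pandas.convertToArrowArraySafely config.
import pandas as pd
from pyspark.errors import PySparkValueError
from pyspark.sql.types import IntegerType
from pyspark.sql.pandas.conversion import create_arrow_array_from_pandas

ok = create_arrow_array_from_pandas(pd.Series([1, 2, 3]), IntegerType(), safecheck=True)
print(ok.type)  # int32

try:
    create_arrow_array_from_pandas(pd.Series([2**40]), IntegerType(), safecheck=True)
except PySparkValueError as exc:
    print("unsafe conversion rejected:", exc)
```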
```diff
@@ -807,7 +928,7 @@ def _create_from_pandas_with_arrow(

         assert isinstance(self, SparkSession)

-        from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
+        from pyspark.sql.pandas.serializers import ArrowStreamSerializer
         from pyspark.sql.types import TimestampType
         from pyspark.sql.pandas.types import (
             from_arrow_type,
@@ -877,20 +998,20 @@ def _create_from_pandas_with_arrow(
         step = step if step > 0 else len(pdf)
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))

-        # Create list of (columns, spark_type) for serializer dump_stream
-        arrow_data = [
-            [(c, t) for (_, c), t in zip(pdf_slice.items(), spark_types)]
+        # Create Arrow batches directly using the standalone function
+        arrow_batches = [
+            create_arrow_batch_from_pandas(
+                [(c, t) for (_, c), t in zip(pdf_slice.items(), spark_types)],
+                timezone=timezone,
+                safecheck=safecheck,
+                prefers_large_types=prefers_large_var_types,
+            )
             for pdf_slice in pdf_slices
         ]

         jsparkSession = self._jsparkSession

-        ser = ArrowStreamPandasSerializer(
-            timezone=timezone,
-            safecheck=safecheck,
-            int_to_decimal_coercion_enabled=False,
-            prefers_large_types=prefers_large_var_types,
-        )
+        ser = ArrowStreamSerializer()

         @no_type_check
         def reader_func(temp_filename):
@@ -901,7 +1022,7 @@ def create_iter_server():
             return self._jvm.ArrowIteratorServer()

         # Create Spark DataFrame from Arrow stream file, using one batch per partition
-        jiter = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_iter_server)
+        jiter = self._sc._serialize_to_jvm(arrow_batches, ser, reader_func, create_iter_server)
         assert self._jvm is not None
         jdf = self._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
         df = DataFrame(jdf, self)
```
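
To illustrate the design choice on the classic path, a hedged round-trip sketch under the assumption that `ArrowStreamSerializer.dump_stream`/`load_stream` simply write and read Arrow record batches over a stream; the in-memory buffer and batch contents below are illustrative, not code from the PR:

```python
# Hedged sketch: once batches are prebuilt by create_arrow_batch_from_pandas,
# ArrowStreamSerializer only has to stream Arrow record batches; no
# timezone/safecheck/coercion state lives in the serializer anymore.
import io
import pandas as pd
from pyspark.sql.types import LongType
from pyspark.sql.pandas.serializers import ArrowStreamSerializer
from pyspark.sql.pandas.conversion import create_arrow_batch_from_pandas

batches = [
    create_arrow_batch_from_pandas([(pd.Series([i, i + 1]), LongType())])
    for i in (0, 10)
]

buf = io.BytesIO()
ser = ArrowStreamSerializer()
ser.dump_stream(batches, buf)

buf.seek(0)
for batch in ser.load_stream(buf):
    print(batch.num_rows, batch.to_pydict())
```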
