Skip to content

Commit a52e348

Browse files
committed
Improve Python UDTF arrow serializer performance
1 parent 208a7ee commit a52e348

File tree

5 files changed

+115
-214
lines changed

5 files changed

+115
-214
lines changed

python/pyspark/errors/error-conditions.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@
980980
},
981981
"UDTF_ARROW_TYPE_CAST_ERROR": {
982982
"message": [
983-
"Cannot convert the output value of the column '<col_name>' with type '<col_type>' to the specified return type of the column: '<arrow_type>'. Please check if the data types match and try again."
983+
"Cannot convert the output value of the input '<data>' with type '<schema>' to the specified return type of the column: '<arrow_schema>'. Please check if the data types match and try again."
984984
]
985985
},
986986
"UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD": {

python/pyspark/sql/conversion.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool)
345345
if isinstance(item, dict):
346346
for i, col in enumerate(column_names):
347347
pylist[i].append(column_convs[i](item.get(col)))
348+
elif item is None:
349+
for i, col in enumerate(column_names):
350+
pylist[i].append(None)
348351
else:
349352
if len(item) != len(column_names):
350353
raise PySparkValueError(

python/pyspark/sql/pandas/serializers.py

Lines changed: 16 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,12 @@ def wrap_and_init_stream():
161161
assert isinstance(batch, pa.RecordBatch)
162162

163163
# Wrap the root struct
164-
struct = pa.StructArray.from_arrays(
165-
batch.columns, fields=pa.struct(list(batch.schema))
166-
)
164+
if len(batch.columns) == 0:
165+
struct = pa.array([{}] * batch.num_rows)
166+
else:
167+
struct = pa.StructArray.from_arrays(
168+
batch.columns, fields=pa.struct(list(batch.schema))
169+
)
167170
batch = pa.RecordBatch.from_arrays([struct], ["_0"])
168171

169172
# Write the first record batch with initialization.
@@ -175,6 +178,16 @@ def wrap_and_init_stream():
175178
return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)
176179

177180

181+
class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer):
182+
"""
183+
Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch
184+
for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`.
185+
"""
186+
187+
def load_stream(self, stream):
188+
return super(ArrowStreamUDFSerializer, self).load_stream(stream)
189+
190+
178191
class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
179192
"""
180193
Serializes pyarrow.RecordBatch data with Arrow streaming format.
@@ -566,151 +579,6 @@ def __repr__(self):
566579
return "ArrowStreamPandasUDFSerializer"
567580

568581

569-
class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
570-
"""
571-
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
572-
"""
573-
574-
def __init__(self, timezone, safecheck):
575-
super(ArrowStreamPandasUDTFSerializer, self).__init__(
576-
timezone=timezone,
577-
safecheck=safecheck,
578-
# The output pandas DataFrame's columns are unnamed.
579-
assign_cols_by_name=False,
580-
# Set to 'False' to avoid converting struct type inputs into a pandas DataFrame.
581-
df_for_struct=False,
582-
# Defines how struct type inputs are converted. If set to "row", struct type inputs
583-
# are converted into Rows. Without this setting, a struct type input would be treated
584-
# as a dictionary. For example, for named_struct('name', 'Alice', 'age', 1),
585-
# if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 1}
586-
# if struct_in_pandas="row", it becomes Row(name="Alice", age=1)
587-
struct_in_pandas="row",
588-
# When dealing with array type inputs, Arrow converts them into numpy.ndarrays.
589-
# To ensure consistency across regular and arrow-optimized UDTFs, we further
590-
# convert these numpy.ndarrays into Python lists.
591-
ndarray_as_list=True,
592-
# Enables explicit casting for mismatched return types of Arrow Python UDTFs.
593-
arrow_cast=True,
594-
)
595-
self._converter_map = dict()
596-
597-
def _create_batch(self, series):
598-
"""
599-
Create an Arrow record batch from the given pandas.Series pandas.DataFrame
600-
or list of Series or DataFrame, with optional type.
601-
602-
Parameters
603-
----------
604-
series : pandas.Series or pandas.DataFrame or list
605-
A single series or dataframe, list of series or dataframe,
606-
or list of (series or dataframe, arrow_type)
607-
608-
Returns
609-
-------
610-
pyarrow.RecordBatch
611-
Arrow RecordBatch
612-
"""
613-
import pandas as pd
614-
import pyarrow as pa
615-
616-
# Make input conform to [(series1, type1), (series2, type2), ...]
617-
if not isinstance(series, (list, tuple)) or (
618-
len(series) == 2 and isinstance(series[1], pa.DataType)
619-
):
620-
series = [series]
621-
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
622-
623-
arrs = []
624-
for s, t in series:
625-
if not isinstance(s, pd.DataFrame):
626-
raise PySparkValueError(
627-
"Output of an arrow-optimized Python UDTFs expects "
628-
f"a pandas.DataFrame but got: {type(s)}"
629-
)
630-
631-
arrs.append(self._create_struct_array(s, t))
632-
633-
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
634-
635-
def _get_or_create_converter_from_pandas(self, dt):
636-
if dt not in self._converter_map:
637-
conv = _create_converter_from_pandas(
638-
dt,
639-
timezone=self._timezone,
640-
error_on_duplicated_field_names=False,
641-
ignore_unexpected_complex_type_values=True,
642-
)
643-
self._converter_map[dt] = conv
644-
return self._converter_map[dt]
645-
646-
def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
647-
"""
648-
Override the `_create_array` method in the superclass to create an Arrow Array
649-
from a given pandas.Series and an arrow type. The difference here is that we always
650-
use arrow cast when creating the arrow array. Also, the error messages are specific
651-
to arrow-optimized Python UDTFs.
652-
653-
Parameters
654-
----------
655-
series : pandas.Series
656-
A single series
657-
arrow_type : pyarrow.DataType, optional
658-
If None, pyarrow's inferred type will be used
659-
spark_type : DataType, optional
660-
If None, spark type converted from arrow_type will be used
661-
arrow_cast: bool, optional
662-
Whether to apply Arrow casting when the user-specified return type mismatches the
663-
actual return values.
664-
665-
Returns
666-
-------
667-
pyarrow.Array
668-
"""
669-
import pyarrow as pa
670-
import pandas as pd
671-
672-
if isinstance(series.dtype, pd.CategoricalDtype):
673-
series = series.astype(series.dtypes.categories.dtype)
674-
675-
if arrow_type is not None:
676-
dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
677-
conv = self._get_or_create_converter_from_pandas(dt)
678-
series = conv(series)
679-
680-
if hasattr(series.array, "__arrow_array__"):
681-
mask = None
682-
else:
683-
mask = series.isnull()
684-
685-
try:
686-
try:
687-
return pa.Array.from_pandas(
688-
series, mask=mask, type=arrow_type, safe=self._safecheck
689-
)
690-
except pa.lib.ArrowException:
691-
if arrow_cast:
692-
return pa.Array.from_pandas(series, mask=mask).cast(
693-
target_type=arrow_type, safe=self._safecheck
694-
)
695-
else:
696-
raise
697-
except pa.lib.ArrowException:
698-
# Display the most user-friendly error messages instead of showing
699-
# arrow's error message. This also works better with Spark Connect
700-
# where the exception messages are by default truncated.
701-
raise PySparkRuntimeError(
702-
errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
703-
messageParameters={
704-
"col_name": series.name,
705-
"col_type": str(series.dtype),
706-
"arrow_type": arrow_type,
707-
},
708-
) from None
709-
710-
def __repr__(self):
711-
return "ArrowStreamPandasUDTFSerializer"
712-
713-
714582
class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
715583
"""
716584
Serializes pyarrow.RecordBatch data with Arrow streaming format.

0 commit comments

Comments (0)