Skip to content

Commit be08b41

Browse files
BryanCutlerHyukjinKwon
authored andcommitted
[SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality
## What changes were proposed in this pull request? This change is a cleanup and consolidation of 3 areas related to Pandas UDFs: 1) `ArrowStreamPandasSerializer` now inherits from `ArrowStreamSerializer` and uses the base class `dump_stream`, `load_stream` to create Arrow reader/writer and send Arrow record batches. `ArrowStreamPandasSerializer` makes the conversions to/from Pandas and converts to Arrow record batch iterators. This change removed duplicated creation of Arrow readers/writers. 2) `createDataFrame` with Arrow now uses `ArrowStreamPandasSerializer` instead of doing its own conversions from Pandas to Arrow and sending record batches through `ArrowStreamSerializer`. 3) Grouped Map UDFs now reuse existing logic in `ArrowStreamPandasSerializer` to send Pandas DataFrame results as a `StructType` instead of separating each column from the DataFrame. This makes the code a little more consistent with the Python worker, but does require that the returned StructType column is flattened out in `FlatMapGroupsInPandasExec` in Scala. ## How was this patch tested? Existing tests and ran tests with pyarrow 0.12.0 Closes apache#24095 from BryanCutler/arrow-refactor-cleanup-UDFs. Authored-by: Bryan Cutler <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent b1857a4 commit be08b41

File tree

5 files changed

+161
-135
lines changed

5 files changed

+161
-135
lines changed

python/pyspark/serializers.py

Lines changed: 119 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -245,92 +245,13 @@ def __repr__(self):
245245
return "ArrowStreamSerializer"
246246

247247

248-
def _create_batch(series, timezone, safecheck, assign_cols_by_name):
248+
class ArrowStreamPandasSerializer(ArrowStreamSerializer):
249249
"""
250-
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
250+
Serializes Pandas.Series as Arrow data with Arrow streaming format.
251251
252-
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
253252
:param timezone: A timezone to respect when handling timestamp values
254-
:return: Arrow RecordBatch
255-
"""
256-
import decimal
257-
from distutils.version import LooseVersion
258-
import pandas as pd
259-
import pyarrow as pa
260-
from pyspark.sql.types import _check_series_convert_timestamps_internal
261-
# Make input conform to [(series1, type1), (series2, type2), ...]
262-
if not isinstance(series, (list, tuple)) or \
263-
(len(series) == 2 and isinstance(series[1], pa.DataType)):
264-
series = [series]
265-
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
266-
267-
def create_array(s, t):
268-
mask = s.isnull()
269-
# Ensure timestamp series are in expected form for Spark internal representation
270-
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
271-
if t is not None and pa.types.is_timestamp(t):
272-
s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
273-
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
274-
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
275-
elif t is not None and pa.types.is_string(t) and sys.version < '3':
276-
# TODO: need decode before converting to Arrow in Python 2
277-
# TODO: don't need as of Arrow 0.9.1
278-
return pa.Array.from_pandas(s.apply(
279-
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
280-
elif t is not None and pa.types.is_decimal(t) and \
281-
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
282-
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
283-
return pa.Array.from_pandas(s.apply(
284-
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
285-
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
286-
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
287-
return pa.Array.from_pandas(s, mask=mask, type=t)
288-
289-
try:
290-
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
291-
except pa.ArrowException as e:
292-
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
293-
"Array (%s). It can be caused by overflows or other unsafe " + \
294-
"conversions warned by Arrow. Arrow safe type check can be " + \
295-
"disabled by using SQL config " + \
296-
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
297-
raise RuntimeError(error_msg % (s.dtype, t), e)
298-
return array
299-
300-
arrs = []
301-
for s, t in series:
302-
if t is not None and pa.types.is_struct(t):
303-
if not isinstance(s, pd.DataFrame):
304-
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
305-
"but got: %s" % str(type(s)))
306-
307-
# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
308-
if len(s) == 0 and len(s.columns) == 0:
309-
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
310-
# Assign result columns by schema name if user labeled with strings
311-
elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
312-
arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
313-
# Assign result columns by position
314-
else:
315-
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
316-
for i, field in enumerate(t)]
317-
318-
struct_arrs, struct_names = zip(*arrs_names)
319-
320-
# TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
321-
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
322-
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
323-
else:
324-
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
325-
else:
326-
arrs.append(create_array(s, t))
327-
328-
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
329-
330-
331-
class ArrowStreamPandasSerializer(Serializer):
332-
"""
333-
Serializes Pandas.Series as Arrow data with Arrow streaming format.
253+
:param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
254+
:param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
334255
"""
335256

336257
def __init__(self, timezone, safecheck, assign_cols_by_name):
@@ -347,39 +268,138 @@ def arrow_to_pandas(self, arrow_column):
347268
s = _check_series_localize_timestamps(s, self._timezone)
348269
return s
349270

271+
def _create_batch(self, series):
272+
"""
273+
Create an Arrow record batch from the given pandas.Series or list of Series,
274+
with optional type.
275+
276+
:param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
277+
:return: Arrow RecordBatch
278+
"""
279+
import decimal
280+
from distutils.version import LooseVersion
281+
import pandas as pd
282+
import pyarrow as pa
283+
from pyspark.sql.types import _check_series_convert_timestamps_internal
284+
# Make input conform to [(series1, type1), (series2, type2), ...]
285+
if not isinstance(series, (list, tuple)) or \
286+
(len(series) == 2 and isinstance(series[1], pa.DataType)):
287+
series = [series]
288+
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
289+
290+
def create_array(s, t):
291+
mask = s.isnull()
292+
# Ensure timestamp series are in expected form for Spark internal representation
293+
# TODO: maybe don't need None check anymore as of Arrow 0.9.1
294+
if t is not None and pa.types.is_timestamp(t):
295+
s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
296+
# TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
297+
return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
298+
elif t is not None and pa.types.is_string(t) and sys.version < '3':
299+
# TODO: need decode before converting to Arrow in Python 2
300+
# TODO: don't need as of Arrow 0.9.1
301+
return pa.Array.from_pandas(s.apply(
302+
lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
303+
elif t is not None and pa.types.is_decimal(t) and \
304+
LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
305+
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
306+
return pa.Array.from_pandas(s.apply(
307+
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
308+
elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
309+
# TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
310+
return pa.Array.from_pandas(s, mask=mask, type=t)
311+
312+
try:
313+
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
314+
except pa.ArrowException as e:
315+
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
316+
"Array (%s). It can be caused by overflows or other unsafe " + \
317+
"conversions warned by Arrow. Arrow safe type check can be " + \
318+
"disabled by using SQL config " + \
319+
"`spark.sql.execution.pandas.arrowSafeTypeConversion`."
320+
raise RuntimeError(error_msg % (s.dtype, t), e)
321+
return array
322+
323+
arrs = []
324+
for s, t in series:
325+
if t is not None and pa.types.is_struct(t):
326+
if not isinstance(s, pd.DataFrame):
327+
raise ValueError("A field of type StructType expects a pandas.DataFrame, "
328+
"but got: %s" % str(type(s)))
329+
330+
# Input partition and result pandas.DataFrame empty, make empty Arrays with struct
331+
if len(s) == 0 and len(s.columns) == 0:
332+
arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
333+
# Assign result columns by schema name if user labeled with strings
334+
elif self._assign_cols_by_name and any(isinstance(name, basestring)
335+
for name in s.columns):
336+
arrs_names = [(create_array(s[field.name], field.type), field.name)
337+
for field in t]
338+
# Assign result columns by position
339+
else:
340+
arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
341+
for i, field in enumerate(t)]
342+
343+
struct_arrs, struct_names = zip(*arrs_names)
344+
345+
# TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
346+
if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
347+
arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
348+
else:
349+
arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
350+
else:
351+
arrs.append(create_array(s, t))
352+
353+
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
354+
350355
def dump_stream(self, iterator, stream):
351356
"""
352357
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
353358
a list of series accompanied by an optional pyarrow type to coerce the data to.
354359
"""
355-
import pyarrow as pa
356-
writer = None
357-
try:
358-
for series in iterator:
359-
batch = _create_batch(series, self._timezone, self._safecheck,
360-
self._assign_cols_by_name)
361-
if writer is None:
362-
write_int(SpecialLengths.START_ARROW_STREAM, stream)
363-
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
364-
writer.write_batch(batch)
365-
finally:
366-
if writer is not None:
367-
writer.close()
360+
batches = (self._create_batch(series) for series in iterator)
361+
super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
368362

369363
def load_stream(self, stream):
370364
"""
371365
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
372366
"""
367+
batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
373368
import pyarrow as pa
374-
reader = pa.ipc.open_stream(stream)
375-
376-
for batch in reader:
369+
for batch in batches:
377370
yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
378371

379372
def __repr__(self):
380373
return "ArrowStreamPandasSerializer"
381374

382375

376+
class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
377+
"""
378+
Serializer used by Python worker to evaluate Pandas UDFs
379+
"""
380+
381+
def dump_stream(self, iterator, stream):
382+
"""
383+
Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
384+
This should be sent after creating the first record batch so in case of an error, it can
385+
be sent back to the JVM before the Arrow stream starts.
386+
"""
387+
388+
def init_stream_yield_batches():
389+
should_write_start_length = True
390+
for series in iterator:
391+
batch = self._create_batch(series)
392+
if should_write_start_length:
393+
write_int(SpecialLengths.START_ARROW_STREAM, stream)
394+
should_write_start_length = False
395+
yield batch
396+
397+
return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
398+
399+
def __repr__(self):
400+
return "ArrowStreamPandasUDFSerializer"
401+
402+
383403
class BatchedSerializer(Serializer):
384404

385405
"""

python/pyspark/sql/session.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -530,15 +530,29 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
530530
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
531531
data types will be used to coerce the data in Pandas to Arrow conversion.
532532
"""
533-
from pyspark.serializers import ArrowStreamSerializer, _create_batch
534-
from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
533+
from distutils.version import LooseVersion
534+
from pyspark.serializers import ArrowStreamPandasSerializer
535+
from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
535536
from pyspark.sql.utils import require_minimum_pandas_version, \
536537
require_minimum_pyarrow_version
537538

538539
require_minimum_pandas_version()
539540
require_minimum_pyarrow_version()
540541

541542
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
543+
import pyarrow as pa
544+
545+
# Create the Spark schema from list of names passed in with Arrow types
546+
if isinstance(schema, (list, tuple)):
547+
if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
548+
temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
549+
arrow_schema = temp_batch.schema
550+
else:
551+
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
552+
struct = StructType()
553+
for name, field in zip(schema, arrow_schema):
554+
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
555+
schema = struct
542556

543557
# Determine arrow types to coerce data when creating batches
544558
if isinstance(schema, StructType):
@@ -555,32 +569,24 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
555569
step = -(-len(pdf) // self.sparkContext.defaultParallelism) # round int up
556570
pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
557571

558-
# Create Arrow record batches
559-
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
560-
col_by_name = True # col by name only applies to StructType columns, can't happen here
561-
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
562-
timezone, safecheck, col_by_name)
563-
for pdf_slice in pdf_slices]
564-
565-
# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
566-
if isinstance(schema, (list, tuple)):
567-
struct = from_arrow_schema(batches[0].schema)
568-
for i, name in enumerate(schema):
569-
struct.fields[i].name = name
570-
struct.names[i] = name
571-
schema = struct
572+
# Create list of Arrow (columns, type) for serializer dump_stream
573+
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
574+
for pdf_slice in pdf_slices]
572575

573576
jsqlContext = self._wrapped._jsqlContext
574577

578+
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
579+
col_by_name = True # col by name only applies to StructType columns, can't happen here
580+
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
581+
575582
def reader_func(temp_filename):
576583
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
577584

578585
def create_RDD_server():
579586
return self._jvm.ArrowRDDServer(jsqlContext)
580587

581588
# Create Spark DataFrame from Arrow stream file, using one batch per partition
582-
jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
583-
create_RDD_server)
589+
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
584590
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
585591
df = DataFrame(jdf, self._wrapped)
586592
df._schema = schema

python/pyspark/worker.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from pyspark.rdd import PythonEvalType
3939
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
4040
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
41-
BatchedSerializer, ArrowStreamPandasSerializer
41+
BatchedSerializer, ArrowStreamPandasUDFSerializer
4242
from pyspark.sql.types import to_arrow_type, StructType
4343
from pyspark.util import _get_argspec, fail_on_stopiteration
4444
from pyspark import shuffle
@@ -103,10 +103,7 @@ def verify_result_length(*a):
103103
return lambda *a: (verify_result_length(*a), arrow_return_type)
104104

105105

106-
def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
107-
assign_cols_by_name = runner_conf.get(
108-
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")
109-
assign_cols_by_name = assign_cols_by_name.lower() == "true"
106+
def wrap_grouped_map_pandas_udf(f, return_type, argspec):
110107

111108
def wrapped(key_series, value_series):
112109
import pandas as pd
@@ -125,15 +122,9 @@ def wrapped(key_series, value_series):
125122
"Number of columns of the returned pandas.DataFrame "
126123
"doesn't match specified schema. "
127124
"Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
125+
return result
128126

129-
# Assign result columns by schema name if user labeled with strings, else use position
130-
if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns):
131-
return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type]
132-
else:
133-
return [(result[result.columns[i]], to_arrow_type(field.dataType))
134-
for i, field in enumerate(return_type)]
135-
136-
return wrapped
127+
return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
137128

138129

139130
def wrap_grouped_agg_pandas_udf(f, return_type):
@@ -227,7 +218,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
227218
return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
228219
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
229220
argspec = _get_argspec(row_func) # signature was lost when wrapping it
230-
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
221+
return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
231222
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
232223
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
233224
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -257,12 +248,12 @@ def read_udfs(pickleSer, infile, eval_type):
257248
timezone = runner_conf.get("spark.sql.session.timeZone", None)
258249
safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
259250
"false").lower() == 'true'
260-
# NOTE: this is duplicated from wrap_grouped_map_pandas_udf
251+
# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
261252
assign_cols_by_name = runner_conf.get(
262253
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
263254
.lower() == "true"
264255

265-
ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
256+
ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name)
266257
else:
267258
ser = BatchedSerializer(PickleSerializer(), 100)
268259

0 commit comments

Comments
 (0)