
Commit 594be7a

ueshin authored and BryanCutler committed
[SPARK-27240][PYTHON] Use pandas DataFrame for struct type argument in Scalar Pandas UDF.
## What changes were proposed in this pull request?

We now support returning a pandas DataFrame for struct type in a Scalar Pandas UDF. If we chain another Pandas UDF after a Scalar Pandas UDF that returns a pandas DataFrame, the argument of the chained UDF will be a pandas DataFrame, but currently we don't support a pandas DataFrame as an argument of a Scalar Pandas UDF. That means there is an inconsistency between the chained UDF and the single UDF. We should support taking a pandas DataFrame for a struct type argument in a Scalar Pandas UDF to be consistent. Currently pyarrow >= 0.11 is supported.

## How was this patch tested?

Modified and added some tests.

Closes apache#24177 from ueshin/issues/SPARK-27240/structtype_argument.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
1 parent 8bc304f
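For illustration, here is a minimal sketch of the behavior this change enables, mirroring the tests added below. It assumes a running SparkSession bound to `spark`, pyarrow >= 0.10.0, and Python 3 (so `str` stands in for `unicode`); the UDF names are illustrative, not part of the patch.

```python
import pandas as pd
from pyspark.sql.functions import col, pandas_udf, struct
from pyspark.sql.types import LongType, StringType, StructField, StructType

return_type = StructType([
    StructField('id', LongType()),
    StructField('str', StringType())])

@pandas_udf(return_type)
def make_struct(id):
    # A Scalar Pandas UDF may return a pandas DataFrame for a struct type.
    return pd.DataFrame({'id': id, 'str': id.apply(str)})

# With this change, a struct type argument also arrives as a pandas
# DataFrame, so an identity UDF over the struct column is well-defined.
identity = pandas_udf(lambda pdf: pdf, return_type)

df = spark.range(10)
# Chained: the DataFrame returned by make_struct feeds identity directly.
df.select(identity(make_struct(col('id')))).show()
# Single: a struct column built with struct() is passed the same way.
df.select(identity(struct(col('id'), col('id').cast('string').alias('str')))).show()
```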

4 files changed: +72 −6 lines changed

python/pyspark/serializers.py

Lines changed: 24 additions & 5 deletions
@@ -260,11 +260,10 @@ def __init__(self, timezone, safecheck, assign_cols_by_name):
         self._safecheck = safecheck
         self._assign_cols_by_name = assign_cols_by_name

-    def arrow_to_pandas(self, arrow_column):
-        from pyspark.sql.types import from_arrow_type, \
-            _arrow_column_to_pandas, _check_series_localize_timestamps
+    def arrow_to_pandas(self, arrow_column, data_type):
+        from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps

-        s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type))
+        s = _arrow_column_to_pandas(arrow_column, data_type)
         s = _check_series_localize_timestamps(s, self._timezone)
         return s

@@ -366,8 +365,10 @@ def load_stream(self, stream):
         """
         batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         import pyarrow as pa
+        from pyspark.sql.types import from_arrow_type
         for batch in batches:
-            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+            yield [self.arrow_to_pandas(c, from_arrow_type(c.type))
+                   for c in pa.Table.from_batches([batch]).itercolumns()]

     def __repr__(self):
         return "ArrowStreamPandasSerializer"

@@ -378,6 +379,24 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
     Serializer used by Python worker to evaluate Pandas UDFs
     """

+    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
+        super(ArrowStreamPandasUDFSerializer, self) \
+            .__init__(timezone, safecheck, assign_cols_by_name)
+        self._df_for_struct = df_for_struct
+
+    def arrow_to_pandas(self, arrow_column, data_type):
+        from pyspark.sql.types import StructType, \
+            _arrow_column_to_pandas, _check_dataframe_localize_timestamps
+
+        if self._df_for_struct and type(data_type) == StructType:
+            import pandas as pd
+            series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name)
+                      for column, field in zip(arrow_column.flatten(), data_type)]
+            s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone)
+        else:
+            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type)
+        return s
+
     def dump_stream(self, iterator, stream):
         """
         Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
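The struct branch of `arrow_to_pandas` above flattens the Arrow struct column into one child array per field, converts each child to a pandas Series named after its field, and concatenates the Series into a DataFrame. Below is a standalone sketch of that technique outside Spark, assuming pyarrow >= 0.10.0; the field names are illustrative.

```python
import pandas as pd
import pyarrow as pa

# An Arrow struct array standing in for a struct column from a record batch.
arr = pa.array([{'id': i, 'str': str(i)} for i in range(3)],
               type=pa.struct([('id', pa.int64()), ('str', pa.string())]))

# flatten() yields one child array per struct field; each child becomes a
# pandas Series renamed after its field, and concat rebuilds the DataFrame.
series = [child.to_pandas().rename(field.name)
          for child, field in zip(arr.flatten(), arr.type)]
pdf = pd.concat(series, axis=1)
print(pdf)  # two columns: 'id' and 'str'
```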

python/pyspark/sql/tests/test_pandas_udf_scalar.py

Lines changed: 33 additions & 0 deletions
@@ -270,6 +270,7 @@ def test_vectorized_udf_null_array(self):

     def test_vectorized_udf_struct_type(self):
         import pandas as pd
+        import pyarrow as pa

         df = self.spark.range(10)
         return_type = StructType([

@@ -291,6 +292,18 @@ def func(id):
         actual = df.select(g(col('id')).alias('struct')).collect()
         self.assertEqual(expected, actual)

+        struct_f = pandas_udf(lambda x: x, return_type)
+        actual = df.select(struct_f(struct(col('id'), col('id').cast('string').alias('str'))))
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            with QuietTest(self.sc):
+                from py4j.protocol import Py4JJavaError
+                with self.assertRaisesRegexp(
+                        Py4JJavaError,
+                        'Unsupported type in conversion from Arrow'):
+                    self.assertEqual(expected, actual.collect())
+        else:
+            self.assertEqual(expected, actual.collect())
+
     def test_vectorized_udf_struct_complex(self):
         import pandas as pd

@@ -363,6 +376,26 @@ def test_vectorized_udf_chained(self):
         res = df.select(g(f(col('id'))))
         self.assertEquals(df.collect(), res.collect())

+    def test_vectorized_udf_chained_struct_type(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('id', LongType()),
+            StructField('str', StringType())])
+
+        @pandas_udf(return_type)
+        def f(id):
+            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+        g = pandas_udf(lambda x: x, return_type)
+
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
+                             .alias('struct')).collect()
+
+        actual = df.select(g(f(col('id'))).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
     def test_vectorized_udf_wrong_return_type(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(

python/pyspark/sql/types.py

Lines changed: 10 additions & 0 deletions
@@ -1674,6 +1674,16 @@ def from_arrow_type(at):
         if types.is_timestamp(at.value_type):
             raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
         spark_type = ArrayType(from_arrow_type(at.value_type))
+    elif types.is_struct(at):
+        # TODO: remove version check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            raise TypeError("Unsupported type in conversion from Arrow: " + str(at) +
+                            "\nPlease install pyarrow >= 0.10.0 for StructType support.")
+        if any(types.is_struct(field.type) for field in at):
+            raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
+        return StructType(
+            [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+             for field in at])
     else:
         raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
     return spark_type
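To see what the new `elif types.is_struct(at)` branch produces, here is a hedged illustration of the Arrow-to-Spark type mapping, assuming a pyspark build that includes this patch and pyarrow >= 0.10.0:

```python
import pyarrow as pa
from pyspark.sql.types import from_arrow_type

at = pa.struct([('id', pa.int64()), ('str', pa.string())])
# Maps to a Spark StructType with LongType and StringType fields.
print(from_arrow_type(at))

# Nested structs hit the guard above and raise.
nested = pa.struct([('inner', pa.struct([('x', pa.int64())]))])
try:
    from_arrow_type(nested)
except TypeError as e:
    print(e)  # Nested StructType not supported in conversion from Arrow: ...
```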

python/pyspark/worker.py

Lines changed: 5 additions & 1 deletion
@@ -253,7 +253,11 @@ def read_udfs(pickleSer, infile, eval_type):
             "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
             .lower() == "true"

-        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
+        # pandas Series. See SPARK-27240.
+        df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
+                                             df_for_struct)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
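Note that `df_for_struct` is only set for `PythonEvalType.SQL_SCALAR_PANDAS_UDF`, so the DataFrame view of struct arguments applies to Scalar Pandas UDFs alone. From the UDF author's side, the effect looks like the following sketch (reusing `df`, `struct`, `col`, and `pandas_udf` from the example above; the UDF name is illustrative):

```python
@pandas_udf('long')
def struct_len(pdf):
    # With df_for_struct enabled, a struct argument arrives as a pandas
    # DataFrame whose columns are the struct fields ('id' and 'str' here),
    # not as a single pandas Series.
    return pdf['id'] + pdf['str'].str.len()

df.select(struct_len(struct(col('id'), col('id').cast('string').alias('str')))).show()
```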
