
Commit 27fc536

BryanCutler authored and cloud-fan committed
[SPARK-21190][PYSPARK] Python Vectorized UDFs
This PR adds vectorized UDFs to the Python API.

**Proposed API**

Introduce a flag to turn on vectorization for a defined UDF, for example:

```
@pandas_udf(DoubleType())
def plus(a, b):
    return a + b
```

or

```
plus = pandas_udf(lambda a, b: a + b, DoubleType())
```

Usage is the same as for normal UDFs.

**0-parameter UDFs**

pandas_udf functions can declare an optional `**kwargs`; when evaluated, it will contain a key "size" that gives the required length of the output. For example:

```
@pandas_udf(LongType())
def f0(**kwargs):
    return pd.Series(1).repeat(kwargs["size"])

df.select(f0())
```

Added new unit tests in pyspark.sql that are enabled if pyarrow and Pandas are available.

- [x] Fix support for promoted types with null values
- [ ] Discuss 0-param UDF API (use of kwargs)
- [x] Add tests for chained UDFs
- [ ] Discuss behavior when pyarrow not installed / enabled
- [ ] Cleanup pydoc and add user docs

Author: Bryan Cutler <[email protected]>
Author: Takuya UESHIN <[email protected]>

Closes apache#18659 from BryanCutler/arrow-vectorized-udfs-SPARK-21404.
1 parent 8f130ad commit 27fc536
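
To make the proposed API concrete, here is a minimal end-to-end sketch of the lambda form from the commit message. It assumes a Spark build with this patch applied plus pandas and pyarrow installed; the session, DataFrame, and column names are illustrative and not part of the PR.

```python
# Minimal usage sketch of the proposed pandas_udf API (lambda form).
# Assumes a Spark build with this patch, plus pandas and pyarrow installed;
# the DataFrame and column names below are illustrative only.
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.appName("pandas-udf-sketch").getOrCreate()
df = spark.createDataFrame([(1.0, 4.0), (2.0, 5.0), (3.0, 6.0)], ["a", "b"])

# The wrapped function receives pandas.Series arguments and must return a
# Series of the same length; each batch is evaluated in one vectorized call.
plus = pandas_udf(lambda a, b: a + b, DoubleType())

df.select(plus(df.a, df.b).alias("a_plus_b")).show()
```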

13 files changed: +666 -173 lines


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 17 additions & 5 deletions

```diff
@@ -83,10 +83,23 @@ private[spark] case class PythonFunction(
  */
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
+/**
+ * Enumerate the type of command that will be sent to the Python worker
+ */
+private[spark] object PythonEvalType {
+  val NON_UDF = 0
+  val SQL_BATCHED_UDF = 1
+  val SQL_PANDAS_UDF = 2
+}
+
 private[spark] object PythonRunner {
   def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
     new PythonRunner(
-      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+      Seq(ChainedPythonFunctions(Seq(func))),
+      bufferSize,
+      reuse_worker,
+      PythonEvalType.NON_UDF,
+      Array(Array(0)))
   }
 }

@@ -100,7 +113,7 @@ private[spark] class PythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    isUDF: Boolean,
+    evalType: Int,
     argOffsets: Array[Array[Int]])
   extends Logging {

@@ -309,8 +322,8 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        if (isUDF) {
-          dataOut.writeInt(1)
+        dataOut.writeInt(evalType)
+        if (evalType != PythonEvalType.NON_UDF) {
           dataOut.writeInt(funcs.length)
           funcs.zip(argOffsets).foreach { case (chained, offsets) =>
             dataOut.writeInt(offsets.length)

@@ -324,7 +337,6 @@ private[spark] class PythonRunner(
           }
         }
       } else {
-        dataOut.writeInt(0)
         val command = funcs.head.funcs.head.command
         dataOut.writeInt(command.length)
         dataOut.write(command)
```

python/pyspark/serializers.py

Lines changed: 63 additions & 2 deletions

```diff
@@ -81,6 +81,12 @@ class SpecialLengths(object):
     NULL = -5
 
 
+class PythonEvalType(object):
+    NON_UDF = 0
+    SQL_BATCHED_UDF = 1
+    SQL_PANDAS_UDF = 2
+
+
 class Serializer(object):
 
     def dump_stream(self, iterator, stream):

@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
     Serializes an Arrow stream.
     """
 
-    def dumps(self, obj):
-        raise NotImplementedError
+    def dumps(self, batch):
+        import pyarrow as pa
+        import io
+        sink = io.BytesIO()
+        writer = pa.RecordBatchFileWriter(sink, batch.schema)
+        writer.write_batch(batch)
+        writer.close()
+        return sink.getvalue()
 
     def loads(self, obj):
         import pyarrow as pa

@@ -199,6 +211,55 @@ def __repr__(self):
         return "ArrowSerializer"
 
 
+class ArrowPandasSerializer(ArrowSerializer):
+    """
+    Serializes Pandas.Series as Arrow data.
+    """
+
+    def __init__(self):
+        super(ArrowPandasSerializer, self).__init__()
+
+    def dumps(self, series):
+        """
+        Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
+        a list of series accompanied by an optional pyarrow type to coerce the data to.
+        """
+        import pyarrow as pa
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or \
+                (len(series) == 2 and isinstance(series[1], pa.DataType)):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+        # If a nullable integer series has been promoted to floating point with NaNs, need to cast
+        # NOTE: this is not necessary with Arrow >= 0.7
+        def cast_series(s, t):
+            if t is None or s.dtype == t.to_pandas_dtype():
+                return s
+            else:
+                return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
+
+        arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+        batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
+        return super(ArrowPandasSerializer, self).dumps(batch)
+
+    def loads(self, obj):
+        """
+        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
+        followed by a dictionary containing length of the loaded batches.
+        """
+        import pyarrow as pa
+        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
+        batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
+        # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
+        num_rows = sum((batch.num_rows for batch in batches))
+        table = pa.Table.from_batches(batches)
+        return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
+
+    def __repr__(self):
+        return "ArrowPandasSerializer"
+
+
 class BatchedSerializer(Serializer):
 
     """
```

python/pyspark/sql/functions.py

Lines changed: 37 additions & 12 deletions

```diff
@@ -2044,7 +2044,7 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType, name=None):
+    def __init__(self, func, returnType, name=None, vectorized=False):
         if not callable(func):
             raise TypeError(
                 "Not a function or callable (__call__ is not defined): "

@@ -2058,6 +2058,7 @@ def __init__(self, func, returnType, name=None):
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
+        self._vectorized = vectorized
 
     @property
     def returnType(self):

@@ -2089,7 +2090,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt)
+            self._name, wrapped_func, jdt, self._vectorized)
         return judf
 
     def __call__(self, *cols):

@@ -2123,6 +2124,22 @@ def wrapper(*args):
         return wrapper
 
 
+def _create_udf(f, returnType, vectorized):
+
+    def _udf(f, returnType=StringType(), vectorized=vectorized):
+        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+        return udf_obj._wrapped()
+
+    # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
+    else:
+        return _udf(f=f, returnType=returnType, vectorized=vectorized)
+
+
 @since(1.3)
 def udf(f=None, returnType=StringType()):
     """Creates a :class:`Column` expression representing a user defined function (UDF).

@@ -2154,18 +2171,26 @@ def udf(f=None, returnType=StringType()):
     |  8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    def _udf(f, returnType=StringType()):
-        udf_obj = UserDefinedFunction(f, returnType)
-        return udf_obj._wrapped()
+    return _create_udf(f, returnType=returnType, vectorized=False)
 
-    # decorator @udf, @udf() or @udf(dataType())
-    if f is None or isinstance(f, (str, DataType)):
-        # If DataType has been passed as a positional argument
-        # for decorator use it as a returnType
-        return_type = f or returnType
-        return functools.partial(_udf, returnType=return_type)
+
+@since(2.3)
+def pandas_udf(f=None, returnType=StringType()):
+    """
+    Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
+    `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.
+
+    :param f: python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+    # TODO: doctest
+    """
+    import inspect
+    # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
+    if inspect.getargspec(f).keywords is None:
+        return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
     else:
-        return _udf(f=f, returnType=returnType)
+        return _create_udf(f, returnType=returnType, vectorized=True)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
```
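
Following up on the 0-parameter form from the commit message, here is a minimal sketch of how `pandas_udf` composes with `**kwargs`. As before, this assumes the patched build with pandas and pyarrow installed; the input DataFrame is arbitrary, only its row count matters.

```python
# Sketch of the 0-parameter form: the UDF takes no column arguments, so it
# reads the required output length from kwargs["size"]. Assumes the patched
# build plus pandas and pyarrow; the DataFrame below is illustrative only.
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import LongType

spark = SparkSession.builder.getOrCreate()
df = spark.range(4)

# Because the lambda declares **kwargs, pandas_udf passes it through unwrapped
# and the worker supplies the "size" key at evaluation time.
f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs["size"]), LongType())

df.select(f0().alias("ones")).show()
```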
