
Commit e875209

icexelloss authored and HyukjinKwon committed
[SPARK-24624][SQL][PYTHON] Support mixture of Python UDF and Scalar Pandas UDF
## What changes were proposed in this pull request?

This PR adds support for using mixed Python UDFs and Scalar Pandas UDFs, in the following two cases:

(1)

```
from pyspark.sql.functions import col, udf, pandas_udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v') \
    .withColumn('foo', f1(col('v'))) \
    .withColumn('bar', f2(col('v')))
```

Query plan:

```
>>> df.explain(True)
== Parsed Logical Plan ==
'Project [v#2L, foo#5, f2('v) AS bar#9]
+- AnalysisBarrier
      +- Project [v#2L, f1(v#2L) AS foo#5]
         +- Project [id#0L AS v#2L]
            +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int, bar: int
Project [v#2L, foo#5, f2(v#2L) AS bar#9]
+- Project [v#2L, f1(v#2L) AS foo#5]
   +- Project [id#0L AS v#2L]
      +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#0L AS v#2L, f1(id#0L) AS foo#5, f2(id#0L) AS bar#9]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#0L AS v#2L, pythonUDF0#13 AS foo#5, pythonUDF0#14 AS bar#9]
+- ArrowEvalPython [f2(id#0L)], [id#0L, pythonUDF0#13, pythonUDF0#14]
   +- BatchEvalPython [f1(id#0L)], [id#0L, pythonUDF0#13]
      +- *(1) Range (0, 1, step=1, splits=4)
```

(2)

```
from pyspark.sql.functions import udf, pandas_udf

@udf('int')
def f1(x):
    return x + 1

@pandas_udf('int')
def f2(x):
    return x + 1

df = spark.range(0, 1).toDF('v')
df = df.withColumn('foo', f2(f1(df['v'])))
```

Query plan:

```
>>> df.explain(True)
== Parsed Logical Plan ==
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- AnalysisBarrier
      +- Project [v#21L, f1(f2(v#21L)) AS foo#39]
         +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
            +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
               +- Project [id#19L AS v#21L]
                  +- Range (0, 1, step=1, splits=Some(4))

== Analyzed Logical Plan ==
v: bigint, foo: int
Project [v#21L, f2(f1(v#21L)) AS foo#46]
+- Project [v#21L, f1(f2(v#21L)) AS foo#39]
   +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#32]
      +- Project [v#21L, <lambda>(<lambda>(v#21L)) AS foo#25]
         +- Project [id#19L AS v#21L]
            +- Range (0, 1, step=1, splits=Some(4))

== Optimized Logical Plan ==
Project [id#19L AS v#21L, f2(f1(id#19L)) AS foo#46]
+- Range (0, 1, step=1, splits=Some(4))

== Physical Plan ==
*(2) Project [id#19L AS v#21L, pythonUDF0#50 AS foo#46]
+- ArrowEvalPython [f2(pythonUDF0#49)], [id#19L, pythonUDF0#49, pythonUDF0#50]
   +- BatchEvalPython [f1(id#19L)], [id#19L, pythonUDF0#49]
      +- *(1) Range (0, 1, step=1, splits=4)
```

## How was this patch tested?

New tests are added to BatchEvalPythonExecSuite and ScalarPandasUDFTests.

Author: Li Jin <[email protected]>

Closes apache#21650 from icexelloss/SPARK-24624-mix-udf.
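To make the behavior change concrete, here is a minimal, self-contained sketch. The UDF names `plus_one` and `plus_ten` are illustrative only, and pandas plus PyArrow are assumed to be installed. Before this patch, the mixed projection below failed with `Can not mix vectorized and non-vectorized UDFs`; with it, each UDF is extracted into its own eval node, as the physical plans above show:

```
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf

spark = SparkSession.builder.getOrCreate()

@udf('int')
def plus_one(x):
    # Row-at-a-time Python UDF: x arrives as a plain Python int.
    return x + 1

@pandas_udf('int')
def plus_ten(s):
    # Scalar Pandas UDF: s arrives as a pandas.Series.
    return s + 10

df = spark.range(0, 1).toDF('v')

# Previously raised an exception; now planned as BatchEvalPython for
# plus_one and ArrowEvalPython for plus_ten within the same query.
df.withColumn('foo', plus_one(col('v'))) \
  .withColumn('bar', plus_ten(col('v'))) \
  .show()
```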
1 parent 6424b14 · commit e875209

File tree: 4 files changed, +304 −23 lines

python/pyspark/sql/tests.py

Lines changed: 175 additions & 11 deletions
```diff
@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
                     'Result vector from pandas_udf was not the required length'):
                 df.select(raise_exception(col('id'))).collect()
 
-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ def test_type_annotation(self):
         df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
         self.assertEqual(df.first()[0], 0)
 
+    def test_mixed_udf(self):
+        import pandas as pd
+        from pyspark.sql.functions import col, udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of multiple UDFs and Pandas UDFs.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        @pandas_udf('int')
+        def f2(x):
+            assert type(x) == pd.Series
+            return x + 10
+
+        @udf('int')
+        def f3(x):
+            assert type(x) == int
+            return x + 100
+
+        @pandas_udf('int')
+        def f4(x):
+            assert type(x) == pd.Series
+            return x + 1000
+
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(col('f1'))) \
+            .withColumn('f3_f1', f3(col('f1'))) \
+            .withColumn('f4_f1', f4(col('f1'))) \
+            .withColumn('f3_f2', f3(col('f2'))) \
+            .withColumn('f4_f2', f4(col('f2'))) \
+            .withColumn('f4_f3', f4(col('f3'))) \
+            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+        # Test mixed udfs in a single expression
+        df_multi_2 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(f1(col('v')))) \
+            .withColumn('f3_f1', f3(f1(col('v')))) \
+            .withColumn('f4_f1', f4(f1(col('v')))) \
+            .withColumn('f3_f2', f3(f2(col('v')))) \
+            .withColumn('f4_f2', f4(f2(col('v')))) \
+            .withColumn('f4_f3', f4(f3(col('v')))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
+        expected = df \
+            .withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f4', df['v'] + 1000) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f4_f1', df['v'] + 1001) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f4_f2', df['v'] + 1010) \
+            .withColumn('f4_f3', df['v'] + 1100) \
+            .withColumn('f3_f2_f1', df['v'] + 111) \
+            .withColumn('f4_f2_f1', df['v'] + 1011) \
+            .withColumn('f4_f3_f1', df['v'] + 1101) \
+            .withColumn('f4_f3_f2', df['v'] + 1110) \
+            .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
+
+    def test_mixed_udf_and_sql(self):
+        import pandas as pd
+        from pyspark.sql import Column
+        from pyspark.sql.functions import udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        def f2(x):
+            assert type(x) == Column
+            return x + 10
+
+        @pandas_udf('int')
+        def f3(x):
+            assert type(x) == pd.Series
+            return x + 100
+
+        df1 = df.withColumn('f1', f1(df['v'])) \
+            .withColumn('f2', f2(df['v'])) \
+            .withColumn('f3', f3(df['v'])) \
+            .withColumn('f1_f2', f1(f2(df['v']))) \
+            .withColumn('f1_f3', f1(f3(df['v']))) \
+            .withColumn('f2_f1', f2(f1(df['v']))) \
+            .withColumn('f2_f3', f2(f3(df['v']))) \
+            .withColumn('f3_f1', f3(f1(df['v']))) \
+            .withColumn('f3_f2', f3(f2(df['v']))) \
+            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
+        expected = df.withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f1_f2', df['v'] + 11) \
+            .withColumn('f1_f3', df['v'] + 101) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f2_f3', df['v'] + 110) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f1_f2_f3', df['v'] + 111) \
+            .withColumn('f1_f3_f2', df['v'] + 111) \
+            .withColumn('f2_f1_f3', df['v'] + 111) \
+            .withColumn('f2_f3_f1', df['v'] + 111) \
+            .withColumn('f3_f1_f2', df['v'] + 111) \
+            .withColumn('f3_f2_f1', df['v'] + 111)
+
+        self.assertEquals(expected.collect(), df1.collect())
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df):
                                F.col('temp0.key') == F.col('temp1.key'))
         self.assertEquals(res.count(), 5)
 
+    def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+        import pandas as pd
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+        df = self.spark.range(0, 10).toDF('v1')
+        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+        result = df.groupby() \
+            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+                              'sum int',
+                              PandasUDFType.GROUPED_MAP))
+
+        self.assertEquals(result.collect()[0]['sum'], 165)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
```
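As a sanity check on the expected value in `test_mixed_scalar_udfs_followed_by_grouby_apply`: `v1` holds 0 through 9 (sum 45), `v2` adds one to each value (sum 55), and `v3` adds two (sum 65), so the grouped-map sum over all three columns is 45 + 55 + 65 = 165.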

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 30 additions & 12 deletions
```diff
@@ -21,6 +21,7 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -94,36 +95,52 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  */
 object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
-  private def hasPythonUDF(e: Expression): Boolean = {
+  private type EvalType = Int
+  private type EvalTypeChecker = EvalType => Boolean
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
     e.find(PythonUDF.isScalarPythonUDF).isDefined
   }
 
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
     e.children match {
       // single PythonUDF child could be chained and evaluated in Python
-      case Seq(u: PythonUDF) => canEvaluateInPython(u)
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
       // Python UDF can't be evaluated directly in JVM
-      case children => !children.exists(hasPythonUDF)
+      case children => !children.exists(hasScalarPythonUDF)
     }
   }
 
-  private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
-    case e => e.children.flatMap(collectEvaluatableUDF)
+  private def collectEvaluableUDFsFromExpressions(expressions: Seq[Expression]): Seq[PythonUDF] = {
+    // Eval type checker is set once when we find the first evaluable UDF and its value
+    // shouldn't change later.
+    // Used to check if subsequent UDFs are of the same type as the first UDF. (since we can only
+    // extract UDFs of the same eval type)
+    var evalTypeChecker: Option[EvalTypeChecker] = None
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+        && evalTypeChecker.isEmpty =>
+        evalTypeChecker = Some((otherEvalType: EvalType) => otherEvalType == udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+        && evalTypeChecker.get(udf.evalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    // AggregateInPandasExec and FlatMapGroupsInPandas can be evaluated directly in python worker
-    // Therefore we don't need to extract the UDFs
-    case plan: FlatMapGroupsInPandasExec => plan
     case plan: SparkPlan => extract(plan)
   }
 
   /**
    * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
   private def extract(plan: SparkPlan): SparkPlan = {
-    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
@@ -167,7 +184,8 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
         BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
       case _ =>
-        throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+        throw new AnalysisException(
+          "Expected either Scalar Pandas UDFs or Batched UDFs but got both")
     }
 
     attributeMap ++= validUdfs.zip(resultAttrs)
@@ -205,7 +223,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     case filter: FilterExec =>
       val (candidates, nonDeterministic) =
         splitConjunctivePredicates(filter.condition).partition(_.deterministic)
-      val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+      val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
       if (pushDown.nonEmpty) {
         val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
         FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
```
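The key consequence of the stricter `canEvaluateInPython` check (`e.evalType == u.evalType`) is that only UDFs of the same eval type are chained into a single eval node, while a chain that mixes eval types is split across a BatchEvalPython and an ArrowEvalPython node. A minimal sketch of observing this split (illustrative names only; assumes pandas and PyArrow are installed):

```
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf, udf

spark = SparkSession.builder.getOrCreate()

py_inc = udf(lambda x: x + 1, 'int')         # SQL_BATCHED_UDF
pd_inc = pandas_udf(lambda s: s + 1, 'int')  # SQL_SCALAR_PANDAS_UDF

df = spark.range(3).toDF('v')

# Same eval type: both UDFs chain into a single BatchEvalPython node.
df.select(py_inc(py_inc(col('v')))).explain()

# Mixed eval types: the inner Python UDF is evaluated in BatchEvalPython,
# then the outer Pandas UDF is planned on top of it in ArrowEvalPython.
df.select(pd_inc(py_inc(col('v')))).explain()
```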

sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -115,3 +115,10 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
   dataType = BooleanType,
   pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
   udfDeterministic = true)
+
+class MyDummyScalarPandasUDF extends UserDefinedPythonFunction(
+  name = "dummyScalarPandasUDF",
+  func = new DummyUDF,
+  dataType = BooleanType,
+  pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+  udfDeterministic = true)
```
