Commit 8198ea5

icexelloss authored and cloud-fan committed
[SPARK-24721][SQL] Exclude Python UDFs filters in FileSourceStrategy
## What changes were proposed in this pull request?

This PR excludes Python UDF filters in FileSourceStrategy so that they don't cause the ExtractPythonUDF rule to throw an exception. It doesn't make sense to pass Python UDF filters to FileSourceStrategy anyway, because they cannot be used as push-down filters.

## How was this patch tested?

Added a new regression test.

Closes apache#22104 from icexelloss/SPARK-24721-udf-filter.

Authored-by: Li Jin <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent dac099d commit 8198ea5
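For context, here is a minimal, hypothetical sketch (not part of the commit) of the scenario this change targets, modeled on the regression test added in python/pyspark/sql/tests.py: a Python UDF used as a filter over a file-source DataFrame. Such a predicate cannot be pushed down into the file scan, so FileSourceStrategy should never receive it. The SparkSession setup and temporary path below are assumptions for illustration only.

```python
# Minimal sketch, assuming a local PySpark installation. The session setup and
# the temporary path are illustrative; only the filter pattern mirrors the
# regression test added by this commit.
import tempfile

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf

spark = SparkSession.builder.master("local[1]").getOrCreate()

demo_path = tempfile.mkdtemp()
spark.range(3).write.mode("overwrite").format("csv").save(demo_path)
df = spark.read.option("inferSchema", True).csv(demo_path).toDF("i")

# A Python UDF predicate: it can never act as a data-source push-down filter,
# so it has to be evaluated after the scan instead of being handed to
# FileSourceStrategy. Per the commit message, this pattern previously could
# make the ExtractPythonUDF rule throw.
keep_none = udf(lambda x: False, "boolean")
print(df.filter(keep_none(col("i"))).count())  # expected: 0
```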

File tree

11 files changed (+164, -19 lines)

python/pyspark/sql/tests.py

Lines changed: 94 additions & 0 deletions

@@ -68,8 +68,16 @@
     # If Arrow version requirement is not satisfied, skip related tests.
     _pyarrow_requirement_message = _exception_message(e)

+_test_not_compiled_message = None
+try:
+    from pyspark.sql.utils import require_test_compiled
+    require_test_compiled()
+except Exception as e:
+    _test_not_compiled_message = _exception_message(e)
+
 _have_pandas = _pandas_requirement_message is None
 _have_pyarrow = _pyarrow_requirement_message is None
+_test_compiled = _test_not_compiled_message is None

 from pyspark import SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row

@@ -3367,6 +3375,47 @@ def test_ignore_column_of_all_nulls(self):
         finally:
             shutil.rmtree(path)

+    # SPARK-24721
+    @unittest.skipIf(not _test_compiled, _test_not_compiled_message)
+    def test_datasource_with_udf(self):
+        from pyspark.sql.functions import udf, lit, col
+
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(1).write.mode("overwrite").format('csv').save(path)
+            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
+            datasource_df = self.spark.read \
+                .format("org.apache.spark.sql.sources.SimpleScanSource") \
+                .option('from', 0).option('to', 1).load().toDF('i')
+            datasource_v2_df = self.spark.read \
+                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+                .load().toDF('i', 'j')
+
+            c1 = udf(lambda x: x + 1, 'int')(lit(1))
+            c2 = udf(lambda x: x + 1, 'int')(col('i'))
+
+            f1 = udf(lambda x: False, 'boolean')(lit(1))
+            f2 = udf(lambda x: False, 'boolean')(col('i'))
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c1)
+                expected = df.withColumn('c', lit(2))
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c2)
+                expected = df.withColumn('c', col('i') + 1)
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                for f in [f1, f2]:
+                    result = df.filter(f)
+                    self.assertEquals(0, result.count())
+        finally:
+            shutil.rmtree(path)
+
     def test_repr_behaviors(self):
         import re
         pattern = re.compile(r'^ *\|', re.MULTILINE)

@@ -5269,6 +5318,51 @@ def f3(x):

         self.assertEquals(expected.collect(), df1.collect())

+    # SPARK-24721
+    @unittest.skipIf(not _test_compiled, _test_not_compiled_message)
+    def test_datasource_with_udf(self):
+        # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
+        # This needs to be a separate test because Arrow dependency is optional
+        import pandas as pd
+        import numpy as np
+        from pyspark.sql.functions import pandas_udf, lit, col
+
+        path = tempfile.mkdtemp()
+        shutil.rmtree(path)
+
+        try:
+            self.spark.range(1).write.mode("overwrite").format('csv').save(path)
+            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
+            datasource_df = self.spark.read \
+                .format("org.apache.spark.sql.sources.SimpleScanSource") \
+                .option('from', 0).option('to', 1).load().toDF('i')
+            datasource_v2_df = self.spark.read \
+                .format("org.apache.spark.sql.sources.v2.SimpleDataSourceV2") \
+                .load().toDF('i', 'j')
+
+            c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
+            c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))
+
+            f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
+            f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c1)
+                expected = df.withColumn('c', lit(2))
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                result = df.withColumn('c', c2)
+                expected = df.withColumn('c', col('i') + 1)
+                self.assertEquals(expected.collect(), result.collect())
+
+            for df in [filesource_df, datasource_df, datasource_v2_df]:
+                for f in [f1, f2]:
+                    result = df.filter(f)
+                    self.assertEquals(0, result.count())
+        finally:
+            shutil.rmtree(path)
+

 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

python/pyspark/sql/utils.py

Lines changed: 19 additions & 0 deletions

@@ -152,6 +152,25 @@ def require_minimum_pyarrow_version():
                           "your version was %s." % (minimum_pyarrow_version, pyarrow.__version__))


+def require_test_compiled():
+    """ Raise Exception if test classes are not compiled
+    """
+    import os
+    import glob
+    try:
+        spark_home = os.environ['SPARK_HOME']
+    except KeyError:
+        raise RuntimeError('SPARK_HOME is not defined in environment')
+
+    test_class_path = os.path.join(
+        spark_home, 'sql', 'core', 'target', '*', 'test-classes')
+    paths = glob.glob(test_class_path)
+
+    if len(paths) == 0:
+        raise RuntimeError(
+            "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path)
+
+
 class ForeachBatchFunction(object):
     """
     This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 0 additions & 1 deletion

@@ -89,7 +89,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {

   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
-    python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala

Lines changed: 3 additions & 2 deletions

@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
 import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning
-import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
+import org.apache.spark.sql.execution.python.{ExtractPythonUDFFromAggregate, ExtractPythonUDFs}

 class SparkOptimizer(
     catalog: SessionCatalog,

@@ -31,7 +31,8 @@ class SparkOptimizer(

   override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
     Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
-    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("Extract Python UDFs", Once,
+      Seq(ExtractPythonUDFFromAggregate, ExtractPythonUDFs): _*) :+
     Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+
     Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning)) ++
     postHocOptimizationBatches :+

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ class SparkPlanner(
   override def strategies: Seq[Strategy] =
     experimentalMethods.extraStrategies ++
       extraPlanningStrategies ++ (
+      PythonEvals ::
       DataSourceV2Strategy ::
       FileSourceStrategy ::
       DataSourceStrategy(conf) ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 15 additions & 0 deletions

@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableS
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.execution.python._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
 import org.apache.spark.sql.internal.SQLConf

@@ -517,6 +518,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }

+  /**
+   * Strategy to convert EvalPython logical operator to physical operator.
+   */
+  object PythonEvals extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case ArrowEvalPython(udfs, output, child) =>
+        ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
+      case BatchEvalPython(udfs, output, child) =>
+        BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
+      case _ =>
+        Nil
+    }
+  }
+
   object BasicOperators extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 8 additions & 1 deletion

@@ -23,6 +23,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types.StructType

@@ -57,7 +58,13 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
 }

 /**
- * A physical plan that evaluates a [[PythonUDF]],
+ * A logical plan that evaluates a [[PythonUDF]].
+ */
+case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
+  extends UnaryNode
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]].
  */
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
   extends EvalPythonExec(udfs, output, child) {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala

Lines changed: 7 additions & 0 deletions

@@ -25,9 +25,16 @@ import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.{StructField, StructType}

+/**
+ * A logical plan that evaluates a [[PythonUDF]]
+ */
+case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
+  extends UnaryNode
+
 /**
  * A physical plan that evaluates a [[PythonUDF]]
  */

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 13 additions & 14 deletions

@@ -24,9 +24,8 @@ import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}


 /**

@@ -93,7 +92,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 * This has the limitation that the input to the Python UDF is not allowed include attributes from
 * multiple child operators.
 */
-object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
+object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {

   private type EvalType = Int
   private type EvalTypeChecker = EvalType => Boolean

@@ -132,14 +131,14 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     expressions.flatMap(collectEvaluableUDFs)
   }

-  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    case plan: SparkPlan => extract(plan)
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case plan: LogicalPlan => extract(plan)
   }

   /**
   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
   */
-  private def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: LogicalPlan): LogicalPlan = {
     val udfs = collectEvaluableUDFsFromExpressions(plan.expressions)
       // ignore the PythonUDF that come from second/third aggregate, which is not used
       .filter(udf => udf.references.subsetOf(plan.inputSet))

@@ -151,7 +150,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
     val prunedChildren = plan.children.map { child =>
       val allNeededOutput = inputsForPlan.intersect(child.outputSet).toSeq
       if (allNeededOutput.length != child.output.length) {
-        ProjectExec(allNeededOutput, child)
+        Project(allNeededOutput, child)
       } else {
         child
       }

@@ -180,9 +179,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
        _.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
      ) match {
        case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
-         ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+         ArrowEvalPython(vectorizedUdfs, child.output ++ resultAttrs, child)
        case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
-         BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
+         BatchEvalPython(plainUdfs, child.output ++ resultAttrs, child)
        case _ =>
          throw new AnalysisException(
            "Expected either Scalar Pandas UDFs or Batched UDFs but got both")

@@ -209,7 +208,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       val newPlan = extract(rewritten)
       if (newPlan.output != plan.output) {
         // Trim away the new UDF value if it was only used for filtering or something.
-        ProjectExec(plan.output, newPlan)
+        Project(plan.output, newPlan)
       } else {
         newPlan
       }

@@ -218,15 +217,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

   // Split the original FilterExec to two FilterExecs. Only push down the first few predicates
   // that are all deterministic.
-  private def trySplitFilter(plan: SparkPlan): SparkPlan = {
+  private def trySplitFilter(plan: LogicalPlan): LogicalPlan = {
     plan match {
-      case filter: FilterExec =>
+      case filter: Filter =>
        val (candidates, nonDeterministic) =
          splitConjunctivePredicates(filter.condition).partition(_.deterministic)
        val (pushDown, rest) = candidates.partition(!hasScalarPythonUDF(_))
        if (pushDown.nonEmpty) {
-         val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
-         FilterExec((rest ++ nonDeterministic).reduceLeft(And), newChild)
+         val newChild = Filter(pushDown.reduceLeft(And), filter.child)
+         Filter((rest ++ nonDeterministic).reduceLeft(And), newChild)
        } else {
          filter
        }

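To illustrate the filter split that trySplitFilter now performs on logical Filter nodes, here is a hedged Python sketch. It reuses the hypothetical spark session and CSV-backed df from the sketch near the top of this page; the plan outline in the comment is an approximation for illustration, not output copied from Spark.

```python
# Hypothetical illustration of the trySplitFilter behavior shown in the diff
# above: a conjunction mixing a plain deterministic predicate with a Python UDF
# predicate. Assumes `spark` and `df` from the earlier sketch.
from pyspark.sql.functions import col, udf

is_small = udf(lambda x: x < 10, "boolean")  # Python UDF: not push-down capable
mixed = df.filter((col("i") >= 0) & is_small(col("i")))

# Roughly, the optimized plan splits the filter in two:
#   Filter <udf result>               <- stays above the Python evaluation
#     BatchEvalPython [<lambda>(i)]
#       Filter (i >= 0)               <- deterministic part, pushed toward the scan
#         FileScan csv
mixed.explain()
```
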
sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,8 @@ import org.apache.spark.sql.types._

 class DefaultSource extends SimpleScanSource

+// This class is used by pyspark tests. If this class is modified/moved, make sure pyspark
+// tests still pass.
 class SimpleScanSource extends RelationProvider {
   override def createRelation(
       sqlContext: SQLContext,
