Skip to content

Commit 7858e53

Browse files
viirya authored and BryanCutler committed
[SPARK-28323][SQL][PYTHON] PythonUDF should be able to use in join condition
## What changes were proposed in this pull request? There is a bug in `ExtractPythonUDFs` that produces wrong result attributes. It causes a failure when using `PythonUDF`s among multiple child plans, e.g., join. An example is using `PythonUDF`s in join condition. ```python >>> left = spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)]) >>> right = spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)]) >>> f = udf(lambda a: a, IntegerType()) >>> df = left.join(right, [f("a") == f("b"), left.a1 == right.b1]) >>> df.collect() 19/07/10 12:20:49 ERROR Executor: Exception in task 5.0 in stage 0.0 (TID 5) java.lang.ArrayIndexOutOfBoundsException: 1 at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.genericGet(rows.scala:201) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.getAs(rows.scala:35) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt(rows.scala:36) at org.apache.spark.sql.catalyst.expressions.BaseGenericInternalRow.isNullAt$(rows.scala:36) at org.apache.spark.sql.catalyst.expressions.GenericInternalRow.isNullAt(rows.scala:195) at org.apache.spark.sql.catalyst.expressions.JoinedRow.isNullAt(JoinedRow.scala:70) ... ``` ## How was this patch tested? Added test. Closes apache#25091 from viirya/SPARK-28323. Authored-by: Liang-Chi Hsieh <[email protected]> Signed-off-by: Bryan Cutler <[email protected]>
1 parent 579edf4 commit 7858e53

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

python/pyspark/sql/tests/test_udf.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -197,6 +197,8 @@ def test_udf_in_join_condition(self):
197197
left = self.spark.createDataFrame([Row(a=1)])
198198
right = self.spark.createDataFrame([Row(b=1)])
199199
f = udf(lambda a, b: a == b, BooleanType())
200+
# The udf uses attributes from both sides of join, so it is pulled out as Filter +
201+
# Cross join.
200202
df = left.join(right, f("a", "b"))
201203
with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
202204
df.collect()
@@ -243,6 +245,14 @@ def runWithJoinType(join_type, type_string):
243245
runWithJoinType("leftanti", "LeftAnti")
244246
runWithJoinType("leftsemi", "LeftSemi")
245247

248+
def test_udf_as_join_condition(self):
249+
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
250+
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
251+
f = udf(lambda a: a, IntegerType())
252+
253+
df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
254+
self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
255+
246256
def test_udf_without_arguments(self):
247257
self.spark.catalog.registerFunction("foo", lambda: "bar")
248258
[row] = self.spark.sql("SELECT foo()").collect()

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
179179
validUdfs.forall(PythonUDF.isScalarPythonUDF),
180180
"Can only extract scalar vectorized udf or sql batch udf")
181181

182-
val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
182+
val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
183183
AttributeReference(s"pythonUDF$i", u.dataType)()
184184
}
185185

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
2828
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
2929
import org.apache.spark.sql.execution.{BinaryExecNode, SortExec}
3030
import org.apache.spark.sql.execution.joins._
31+
import org.apache.spark.sql.execution.python.BatchEvalPythonExec
3132
import org.apache.spark.sql.internal.SQLConf
3233
import org.apache.spark.sql.test.SharedSQLContext
3334
import org.apache.spark.sql.types.StructType
@@ -969,4 +970,28 @@ class JoinSuite extends QueryTest with SharedSQLContext {
969970
Seq(Row(0.0d, 0.0/0.0)))))
970971
}
971972
}
973+
974+
test("SPARK-28323: PythonUDF should be able to use in join condition") {
975+
import IntegratedUDFTestUtils._
976+
977+
assume(shouldTestPythonUDFs)
978+
979+
val pythonTestUDF = TestPythonUDF(name = "udf")
980+
981+
val left = Seq((1, 2), (2, 3)).toDF("a", "b")
982+
val right = Seq((1, 2), (3, 4)).toDF("c", "d")
983+
val df = left.join(right, pythonTestUDF($"a") === pythonTestUDF($"c"))
984+
985+
val joinNode = df.queryExecution.executedPlan.find(_.isInstanceOf[BroadcastHashJoinExec])
986+
assert(joinNode.isDefined)
987+
988+
// There are two PythonUDFs which use attribute from left and right of join, individually.
989+
// So two PythonUDFs should be evaluated before the join operator, at left and right side.
990+
val pythonEvals = joinNode.get.collect {
991+
case p: BatchEvalPythonExec => p
992+
}
993+
assert(pythonEvals.size == 2)
994+
995+
checkAnswer(df, Row(1, 2, 1, 2) :: Nil)
996+
}
972997
}

0 commit comments

Comments
 (0)