Skip to content

Commit c26b092

Browse files
maryannxue authored and gatorsmile committed
[SPARK-24891][SQL] Fix HandleNullInputsForUDF rule
## What changes were proposed in this pull request?

The `HandleNullInputsForUDF` rule would add a new `If` node every time it was applied. As a result, a plan that was analyzed once differed from the same plan analyzed twice (or more), which caused issues such as plans failing to match in the cache manager. The solution is to mark the arguments as null-checked — that is, to add a "KnownNotNull" node (spelled `KnowNotNull` in the code of this commit) above those arguments — when placing the UDF under an `If` node, because the UDF will clearly not be called when any of those arguments is null.

## How was this patch tested?

Added new tests in sql/UDFSuite and AnalysisSuite.

Author: maryannxue <[email protected]>

Closes apache#21851 from maryannxue/spark-24891.
1 parent 15fff79 commit c26b092

File tree

4 files changed

+94
-10
lines changed

4 files changed

+94
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.OuterScopes
3030
import org.apache.spark.sql.catalyst.expressions._
3131
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
3232
import org.apache.spark.sql.catalyst.expressions.aggregate._
33-
import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects}
33+
import org.apache.spark.sql.catalyst.expressions.objects._
3434
import org.apache.spark.sql.catalyst.plans._
3535
import org.apache.spark.sql.catalyst.plans.logical._
3636
import org.apache.spark.sql.catalyst.rules._
@@ -2145,14 +2145,24 @@ class Analyzer(
21452145
val parameterTypes = ScalaReflection.getParameterTypes(func)
21462146
assert(parameterTypes.length == inputs.length)
21472147

2148+
// TODO: skip null handling for not-nullable primitive inputs after we can completely
2149+
// trust the `nullable` information.
2150+
// (cls, expr) => cls.isPrimitive && expr.nullable
2151+
val needsNullCheck = (cls: Class[_], expr: Expression) =>
2152+
cls.isPrimitive && !expr.isInstanceOf[KnowNotNull]
21482153
val inputsNullCheck = parameterTypes.zip(inputs)
2149-
// TODO: skip null handling for not-nullable primitive inputs after we can completely
2150-
// trust the `nullable` information.
2151-
// .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
2152-
.filter { case (cls, _) => cls.isPrimitive }
2154+
.filter { case (cls, expr) => needsNullCheck(cls, expr) }
21532155
.map { case (_, expr) => IsNull(expr) }
21542156
.reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
2155-
inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
2157+
// Once we add an `If` check above the udf, it is safe to mark those checked inputs
2158+
// as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
2159+
// branch of `If` will be called if any of these checked inputs is null. Thus we can
2160+
// prevent this rule from being applied repeatedly.
2161+
val newInputs = parameterTypes.zip(inputs).map{ case (cls, expr) =>
2162+
if (needsNullCheck(cls, expr)) KnowNotNull(expr) else expr }
2163+
inputsNullCheck
2164+
.map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
2165+
.getOrElse(udf)
21562166
}
21572167
}
21582168
}
Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions
19+
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
22+
import org.apache.spark.sql.types.DataType
23+
24+
case class KnowNotNull(child: Expression) extends UnaryExpression {
25+
override def nullable: Boolean = false
26+
override def dataType: DataType = child.dataType
27+
28+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
29+
child.genCode(ctx).copy(isNull = FalseLiteral)
30+
}
31+
32+
override def eval(input: InternalRow): Any = {
33+
child.eval(input)
34+
}
35+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -316,15 +316,16 @@ class AnalysisSuite extends AnalysisTest with Matchers {
316316

317317
// only primitive parameter needs special null handling
318318
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
319-
val expected2 = If(IsNull(double), nullResult, udf2)
319+
val expected2 =
320+
If(IsNull(double), nullResult, udf2.copy(children = string :: KnowNotNull(double) :: Nil))
320321
checkUDF(udf2, expected2)
321322

322323
// special null handling should apply to all primitive parameters
323324
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
324325
val expected3 = If(
325326
IsNull(short) || IsNull(double),
326327
nullResult,
327-
udf3)
328+
udf3.copy(children = KnowNotNull(short) :: KnowNotNull(double) :: Nil))
328329
checkUDF(udf3, expected3)
329330

330331
// we can skip special null handling for primitive parameters that are not nullable
@@ -336,10 +337,19 @@ class AnalysisSuite extends AnalysisTest with Matchers {
336337
val expected4 = If(
337338
IsNull(short),
338339
nullResult,
339-
udf4)
340+
udf4.copy(children = KnowNotNull(short) :: double.withNullability(false) :: Nil))
340341
// checkUDF(udf4, expected4)
341342
}
342343

344+
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
345+
val a = testRelation.output(0)
346+
val func = (x: Int, y: Int) => x + y
347+
val udf1 = ScalaUDF(func, IntegerType, a :: a :: Nil)
348+
val udf2 = ScalaUDF(func, IntegerType, a :: udf1 :: Nil)
349+
val plan = Project(Alias(udf2, "")() :: Nil, testRelation)
350+
comparePlans(plan.analyze, plan.analyze.analyze)
351+
}
352+
343353
test("SPARK-11863 mixture of aliases and real columns in order by clause - tpcds 19,55,71") {
344354
val a = testRelation2.output(0)
345355
val c = testRelation2.output(2)

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql
2020
import org.apache.spark.sql.api.java._
2121
import org.apache.spark.sql.catalyst.plans.logical.Project
2222
import org.apache.spark.sql.execution.command.ExplainCommand
23-
import org.apache.spark.sql.functions.udf
23+
import org.apache.spark.sql.functions.{lit, udf}
2424
import org.apache.spark.sql.test.SharedSQLContext
2525
import org.apache.spark.sql.test.SQLTestData._
2626
import org.apache.spark.sql.types.{DataTypes, DoubleType}
@@ -324,4 +324,33 @@ class UDFSuite extends QueryTest with SharedSQLContext {
324324
assert(outputStream.toString.contains("UDF:f(a._1 AS `_1`)"))
325325
}
326326
}
327+
328+
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
329+
val udf1 = udf({(x: Int, y: Int) => x + y})
330+
val df = spark.range(0, 3).toDF("a")
331+
.withColumn("b", udf1($"a", udf1($"a", lit(10))))
332+
.withColumn("c", udf1($"a", lit(null)))
333+
val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed
334+
335+
comparePlans(df.logicalPlan, plan)
336+
checkAnswer(
337+
df,
338+
Seq(
339+
Row(0, 10, null),
340+
Row(1, 12, null),
341+
Row(2, 14, null)))
342+
}
343+
344+
test("SPARK-24891 Fix HandleNullInputsForUDF rule - with table") {
345+
withTable("x") {
346+
Seq((1, "2"), (2, "4")).toDF("a", "b").write.format("json").saveAsTable("x")
347+
sql("insert into table x values(3, null)")
348+
sql("insert into table x values(null, '4')")
349+
spark.udf.register("f", (a: Int, b: String) => a + b)
350+
val df = spark.sql("SELECT f(a, b) FROM x")
351+
val plan = spark.sessionState.executePlan(df.logicalPlan).analyzed
352+
comparePlans(df.logicalPlan, plan)
353+
checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null)))
354+
}
355+
}
327356
}

0 commit comments

Comments
 (0)