
Commit 1fd59c1

srowen authored and cloud-fan committed
[WIP][SPARK-25044][SQL] (take 2) Address translation of LMF closure primitive args to Object in Scala 2.12
## What changes were proposed in this pull request?

Alternative take on apache#22063 that does not introduce udfInternal. Resolve issue with inferring func types in 2.12 by instead using info captured when UDF is registered -- capturing which types are nullable (i.e. not primitive).

## How was this patch tested?

Existing tests.

Closes apache#22259 from srowen/SPARK-25044.2.

Authored-by: Sean Owen <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 82c18c2 commit 1fd59c1
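
The key observation behind this take: whether each argument can accept null is known statically from the UDF's Scala type parameters, so it can be captured once when the UDF is registered instead of being recovered from the closure object by reflection at analysis time. A minimal sketch of that idea, assuming the existing `ScalaReflection.schemaFor` catalyst API (the `nullableOf` helper is illustrative, not part of this commit):

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.ScalaReflection

// Illustrative helper: a type is nullable exactly when it is not a primitive.
def nullableOf[T: TypeTag]: Boolean = ScalaReflection.schemaFor[T].nullable

// For a (String, Double) => String UDF, captured when the UDF is defined:
val nullableTypes = Seq(nullableOf[String], nullableOf[Double]) // Seq(true, false)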

File tree

9 files changed: +133 −116 lines


project/MimaExcludes.scala

Lines changed: 5 additions & 0 deletions
@@ -36,6 +36,11 @@ object MimaExcludes {

   // Exclude rules for 2.4.x
   lazy val v24excludes = v23excludes ++ Seq(
+    // [SPARK-25044] Address translation of LMF closure primitive args to Object in Scala 2.12
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"),
+
     // [SPARK-24296][CORE] Replicate large blocks as a stream.
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"),
     // [SPARK-23528] Add numIter to ClusteringSummary

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 0 additions & 9 deletions
@@ -932,15 +932,6 @@ trait ScalaReflection {
     tpe.dealias.erasure.typeSymbol.asClass.fullName
   }

-  /**
-   * Returns classes of input parameters of scala function object.
-   */
-  def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
-    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
-    assert(methods.length == 1)
-    methods.head.getParameterTypes
-  }
-
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
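
For context, here is a self-contained sketch (not part of the patch) of why the removed reflection approach breaks on Scala 2.12, where closures compile to LambdaMetafactory (LMF) lambdas whose visible `apply` method can erase primitive parameters to `Object`:

object LmfDemo {
  def main(args: Array[String]): Unit = {
    val f = (i: Int, j: Long) => "x"
    // Mirrors the deleted getParameterTypes logic.
    val methods = f.getClass.getMethods
      .filter(m => m.getName == "apply" && !m.isBridge)
    // On Scala 2.11 (anonymous-class closures) this prints "int, long";
    // on 2.12 the LMF-generated class can report "class java.lang.Object"
    // twice, so the analyzer can no longer tell which args were primitive.
    methods.foreach(m => println(m.getParameterTypes.mkString(", ")))
  }
}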

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 28 additions & 22 deletions
@@ -2149,28 +2149,34 @@ class Analyzer(

       case p => p transformExpressionsUp {

-        case udf @ ScalaUDF(func, _, inputs, _, _, _, _) =>
-          val parameterTypes = ScalaReflection.getParameterTypes(func)
-          assert(parameterTypes.length == inputs.length)
-
-          // TODO: skip null handling for not-nullable primitive inputs after we can completely
-          // trust the `nullable` information.
-          // (cls, expr) => cls.isPrimitive && expr.nullable
-          val needsNullCheck = (cls: Class[_], expr: Expression) =>
-            cls.isPrimitive && !expr.isInstanceOf[KnownNotNull]
-          val inputsNullCheck = parameterTypes.zip(inputs)
-            .filter { case (cls, expr) => needsNullCheck(cls, expr) }
-            .map { case (_, expr) => IsNull(expr) }
-            .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
-          // Once we add an `If` check above the udf, it is safe to mark those checked inputs
-          // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
-          // branch of `If` will be called if any of these checked inputs is null. Thus we can
-          // prevent this rule from being applied repeatedly.
-          val newInputs = parameterTypes.zip(inputs).map { case (cls, expr) =>
-            if (needsNullCheck(cls, expr)) KnownNotNull(expr) else expr }
-          inputsNullCheck
-            .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
-            .getOrElse(udf)
+        case udf @ ScalaUDF(func, _, inputs, _, _, _, _, nullableTypes) =>
+          if (nullableTypes.isEmpty) {
+            // If no nullability info is available, do nothing. No fields will be specially
+            // checked for null in the plan. If nullability info is incorrect, the results
+            // of the UDF could be wrong.
+            udf
+          } else {
+            // Otherwise, add special handling of null for fields that can't accept null.
+            // The result of operations like this, when passed null, is generally to return null.
+            assert(nullableTypes.length == inputs.length)
+
+            // TODO: skip null handling for not-nullable primitive inputs after we can completely
+            // trust the `nullable` information.
+            val inputsNullCheck = nullableTypes.zip(inputs)
+              .filter { case (nullable, _) => !nullable }
+              .map { case (_, expr) => IsNull(expr) }
+              .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
+            // Once we add an `If` check above the udf, it is safe to mark those checked inputs
+            // as not nullable (i.e., wrap them with `KnownNotNull`), because the null-returning
+            // branch of `If` will be called if any of these checked inputs is null. Thus we can
+            // prevent this rule from being applied repeatedly.
+            val newInputs = nullableTypes.zip(inputs).map { case (nullable, expr) =>
+              if (nullable) expr else KnownNotNull(expr)
+            }
+            inputsNullCheck
+              .map(If(_, Literal.create(null, udf.dataType), udf.copy(children = newInputs)))
+              .getOrElse(udf)
+          }
       }
     }
   }
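
The rewritten rule's net effect, sketched with plain strings standing in for catalyst `Expression`s (illustrative only, not the rule's actual types):

// A (s: String, d: Double) UDF: nullableTypes = Seq(true, false).
val nullableTypes = Seq(true, false)
val inputs = Seq("s", "d")

// Only inputs whose parameter type cannot accept null are guarded.
val guard = nullableTypes.zip(inputs)
  .filter { case (nullable, _) => !nullable }
  .map { case (_, expr) => s"IsNull($expr)" }
  .reduceLeftOption((a, b) => s"Or($a, $b)")

// Some(IsNull(d)): the analyzer then produces
// If(IsNull(d), Literal(null), udf(s, KnownNotNull(d))).
println(guard)

When `nullableTypes` is empty (no info captured at registration), the rule now leaves the UDF untouched rather than guessing from the closure's runtime class.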

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 5 additions & 2 deletions
@@ -39,6 +39,7 @@ import org.apache.spark.sql.types.DataType
  * @param nullable  True if the UDF can return null value.
  * @param udfDeterministic  True if the UDF is deterministic. Deterministic UDF returns same result
  *                          each time it is invoked with a particular input.
+ * @param nullableTypes which of the inputTypes are nullable (i.e. not primitive)
  */
 case class ScalaUDF(
     function: AnyRef,
@@ -47,7 +48,8 @@ case class ScalaUDF(
     inputTypes: Seq[DataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
-    udfDeterministic: Boolean = true)
+    udfDeterministic: Boolean = true,
+    nullableTypes: Seq[Boolean] = Nil)
   extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {

   // The constructor for SPARK 2.1 and 2.2
@@ -58,7 +60,8 @@ case class ScalaUDF(
       inputTypes: Seq[DataType],
       udfName: Option[String]) = {
     this(
-      function, dataType, children, inputTypes, udfName, nullable = true, udfDeterministic = true)
+      function, dataType, children, inputTypes, udfName, nullable = true,
+      udfDeterministic = true, nullableTypes = Nil)
   }

   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
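
A hedged usage sketch of the widened constructor (the literal values are illustrative; imports match the diff's package):

import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
import org.apache.spark.sql.types.{DoubleType, StringType}

// (String, Double) => String: the String slot may be null, the primitive
// Double slot may not, hence nullableTypes = true :: false :: Nil.
val udf = ScalaUDF(
  function = (s: String, d: Double) => s + d,
  dataType = StringType,
  children = Literal("a") :: Literal(1.0) :: Nil,
  inputTypes = StringType :: DoubleType :: Nil,
  nullableTypes = true :: false :: Nil)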

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 0 additions & 17 deletions
@@ -261,23 +261,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
     }
   }

-  test("get parameter type from a function object") {
-    val primitiveFunc = (i: Int, j: Long) => "x"
-    val primitiveTypes = getParameterTypes(primitiveFunc)
-    assert(primitiveTypes.forall(_.isPrimitive))
-    assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))
-
-    val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
-    val boxedTypes = getParameterTypes(boxedFunc)
-    assert(boxedTypes.forall(!_.isPrimitive))
-    assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long]))
-
-    val anyFunc = (i: Any, j: AnyRef) => "x"
-    val anyTypes = getParameterTypes(anyFunc)
-    assert(anyTypes.forall(!_.isPrimitive))
-    assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
-  }
-
   test("SPARK-15062: Get correct serializer for List[_]") {
     val list = List(1, 2, 3)
     val serializer = serializerFor[List[Int]](BoundReference(

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 6 additions & 3 deletions
@@ -317,13 +317,15 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     checkUDF(udf1, expected1)

     // only primitive parameter needs special null handling
-    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
+    val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
+      nullableTypes = true :: false :: Nil)
     val expected2 =
       If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
     checkUDF(udf2, expected2)

     // special null handling should apply to all primitive parameters
-    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
+    val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
+      nullableTypes = false :: false :: Nil)
     val expected3 = If(
       IsNull(short) || IsNull(double),
       nullResult,
@@ -335,7 +337,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
     val udf4 = ScalaUDF(
       (s: Short, d: Double) => "x",
       StringType,
-      short :: double.withNullability(false) :: Nil)
+      short :: double.withNullability(false) :: Nil,
+      nullableTypes = false :: false :: Nil)
     val expected4 = If(
       IsNull(short),
       nullResult,
