Commit b0d256d
Author: Robert Kruszewski

Revert "[SPARK-26580][SQL] remove Scala 2.11 hack for Scala UDF"

This reverts commit 1f1d98c.

1 parent 33c1522

5 files changed, +40 -21 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 17 additions & 0 deletions
@@ -960,6 +960,23 @@ trait ScalaReflection extends Logging {
     tpe.dealias.erasure.typeSymbol.asClass.fullName
   }
 
+  /**
+   * Returns the nullability of the input parameter types of the scala function object.
+   *
+   * Note that this only works with Scala 2.11, and the information returned may be inaccurate if
+   * used with a different Scala version.
+   */
+  def getParameterTypeNullability(func: AnyRef): Seq[Boolean] = {
+    if (!Properties.versionString.contains("2.11")) {
+      logWarning(s"Scala ${Properties.versionString} cannot get type nullability correctly via " +
+        "reflection, thus Spark cannot add proper input null check for UDF. To avoid this " +
+        "problem, use the typed UDF interfaces instead.")
+    }
+    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
+    assert(methods.length == 1)
+    methods.head.getParameterTypes.map(!_.isPrimitive)
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
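For context, a minimal standalone sketch of the reflection trick this revert restores. It is not Spark code (the NullabilityDemo object is hypothetical): it finds the single non-bridge `apply` method on a function object and treats primitive parameter types as non-nullable, which is why it is only reliable on Scala 2.11, where anonymous functions compile to dedicated classes whose `apply` keeps the primitive signature.

// Hypothetical demo, not part of the commit. On Scala 2.11, a lambda such as
// `(x: Int) => x` compiles to a class whose non-bridge `apply` takes a
// primitive int, so the reflection below can recover nullability. On
// Scala 2.12+, lambdas compile via LambdaMetafactory and the result may be
// missing or inaccurate, which is what the logWarning above is about.
object NullabilityDemo {
  def parameterTypeNullability(func: AnyRef): Seq[Boolean] = {
    val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
    methods.head.getParameterTypes.map(!_.isPrimitive).toSeq
  }

  def main(args: Array[String]): Unit = {
    // A primitive Int parameter is reported as non-nullable...
    println(parameterTypeNullability((x: Int) => x + 1))       // Seq(false) on Scala 2.11
    // ...while a reference-type parameter is nullable.
    println(parameterTypeNullability((s: String) => s.length)) // Seq(true)
  }
}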

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 12 additions & 0 deletions
@@ -54,6 +54,18 @@ case class ScalaUDF(
     udfDeterministic: Boolean = true)
   extends Expression with NonSQLExpression with UserDefinedExpression {
 
+  // The constructor for SPARK 2.1 and 2.2
+  def this(
+      function: AnyRef,
+      dataType: DataType,
+      children: Seq[Expression],
+      inputTypes: Seq[DataType],
+      udfName: Option[String]) = {
+    this(
+      function, dataType, children, ScalaReflection.getParameterTypeNullability(function),
+      inputTypes, udfName, nullable = true, udfDeterministic = true)
+  }
+
   override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
 
   override def toString: String =
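For reference, a hedged sketch of a call site written against the Spark 2.1/2.2-era signature that the auxiliary constructor above keeps working. The legacyUdfExpr helper and the "plusOne" name are made up for illustration:

import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.types.IntegerType

// The five-argument form resolves to the auxiliary constructor, which derives
// input nullability via ScalaReflection.getParameterTypeNullability and
// defaults nullable and udfDeterministic to true.
def legacyUdfExpr(child: Expression): ScalaUDF =
  new ScalaUDF(
    (x: Int) => x + 1, // function: AnyRef
    IntegerType,       // dataType
    Seq(child),        // children
    Seq(IntegerType),  // inputTypes
    Some("plusOne"))   // udfName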

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 11 additions & 1 deletion
@@ -102,7 +102,17 @@ private[sql] case class SparkUserDefinedFunction(
     // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type
     // check and null check for them.
     val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))
-    val inputsNullSafe = inputSchemas.map(_.map(_.nullable).getOrElse(true))
+
+    val inputsNullSafe = if (inputSchemas.isEmpty) {
+      // This is for backward compatibility of `functions.udf(AnyRef, DataType)`. We need to
+      // do reflection of the lambda function object and see if its arguments are nullable or not.
+      // This doesn't work for Scala 2.12 and we should consider removing this workaround, as Spark
+      // uses Scala 2.12 by default since 3.0.
+      ScalaReflection.getParameterTypeNullability(f)
+    } else {
+      inputSchemas.map(_.map(_.nullable).getOrElse(true))
+    }
+
     ScalaUDF(
       f,
       dataType,
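A hedged usage sketch of the two udf entry points that drive the branch above (names are illustrative; assumes a Spark SQL dependency on the classpath):

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

// Typed API: input schemas are derived from the type tags, so nullability
// comes from the schemas and the else branch runs.
val typedUdf = udf((x: Int) => x + 1)

// Untyped API: functions.udf(AnyRef, DataType) supplies no input schemas, so
// the if branch reflects on the closure, which is only reliable on Scala 2.11.
val untypedUdf = udf((x: Int) => x + 1, IntegerType)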

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 0 additions & 7 deletions
@@ -4355,13 +4355,6 @@ object functions {
    * By default the returned UDF is deterministic. To change it to nondeterministic, call the
    * API `UserDefinedFunction.asNondeterministic()`.
    *
-   * Note that, although the Scala closure can have primitive-type function argument, it doesn't
-   * work well with null values. Because the Scala closure is passed in as Any type, there is no
-   * type information for the function arguments. Without the type information, Spark may blindly
-   * pass null to the Scala closure with primitive-type argument, and the closure will see the
-   * default value of the Java type for the null argument, e.g. `udf((x: Int) => x, IntegerType)`,
-   * the result is 0 for null input.
-   *
    * @param f A closure in Scala
    * @param dataType The output data type of the UDF
    *
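The deleted note documented what happens when no null check is added for a primitive-type argument; it is removed here because the restored Scala 2.11 reflection hack lets Spark add that check again. The raw JVM behavior the note described can be reproduced outside Spark, for example in the Scala REPL (a standalone sketch, not Spark code):

// Calling a primitive-arg closure through the erased Any interface: a null
// argument is unboxed to the Java default value instead of raising an error.
val f = ((x: Int) => x).asInstanceOf[Any => Any]
println(f(null)) // prints 0: null unboxes to Int's default value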

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 0 additions & 13 deletions
@@ -423,19 +423,6 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("SPARK-25044 Verify null input handling for primitive types - with udf(Any, DataType)") {
-    val f = udf((x: Int) => x, IntegerType)
-    checkAnswer(
-      Seq(new Integer(1), null).toDF("x").select(f($"x")),
-      Row(1) :: Row(0) :: Nil)
-
-    val f2 = udf((x: Double) => x, DoubleType)
-    checkAnswer(
-      Seq(new java.lang.Double(1.1), null).toDF("x").select(f2($"x")),
-      Row(1.1) :: Row(0.0) :: Nil)
-
-  }
-
   test("SPARK-26308: udf with decimal") {
     val df1 = spark.createDataFrame(
       sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),