Skip to content

Commit e595072

Browse files
author
Robert Kruszewski
committed
Revert "[SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction)"
This reverts commit 39617cb.
1 parent 8735a08 commit e595072

File tree

3 files changed

+65
-57
lines changed

3 files changed

+65
-57
lines changed

project/MimaExcludes.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ object MimaExcludes {
186186
// [SPARK-26616][MLlib] Expose document frequency in IDFModel
187187
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"),
188188
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf")
189-
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.expressions.UserDefinedFunction")
190189
)
191190

192191
// Exclude rules for 2.4.x

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 64 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -38,106 +38,114 @@ import org.apache.spark.sql.types.DataType
3838
* @since 1.3.0
3939
*/
4040
@Stable
41-
sealed trait UserDefinedFunction {
41+
case class UserDefinedFunction protected[sql] (
42+
f: AnyRef,
43+
dataType: DataType,
44+
inputTypes: Option[Seq[DataType]]) {
45+
46+
private var _nameOption: Option[String] = None
47+
private var _nullable: Boolean = true
48+
private var _deterministic: Boolean = true
49+
50+
// This is a `var` instead of in the constructor for backward compatibility of this case class.
51+
// TODO: revisit this case class in Spark 3.0, and narrow down the public surface.
52+
private[sql] var nullableTypes: Option[Seq[Boolean]] = None
4253

4354
/**
4455
* Returns true when the UDF can return a nullable value.
4556
*
4657
* @since 2.3.0
4758
*/
48-
def nullable: Boolean
59+
def nullable: Boolean = _nullable
4960

5061
/**
5162
* Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same
5263
* input.
5364
*
5465
* @since 2.3.0
5566
*/
56-
def deterministic: Boolean
67+
def deterministic: Boolean = _deterministic
5768

5869
/**
5970
* Returns an expression that invokes the UDF, using the given arguments.
6071
*
6172
* @since 1.3.0
6273
*/
6374
@scala.annotation.varargs
64-
def apply(exprs: Column*): Column
65-
66-
/**
67-
* Updates UserDefinedFunction with a given name.
68-
*
69-
* @since 2.3.0
70-
*/
71-
def withName(name: String): UserDefinedFunction
72-
73-
/**
74-
* Updates UserDefinedFunction to non-nullable.
75-
*
76-
* @since 2.3.0
77-
*/
78-
def asNonNullable(): UserDefinedFunction
79-
80-
/**
81-
* Updates UserDefinedFunction to nondeterministic.
82-
*
83-
* @since 2.3.0
84-
*/
85-
def asNondeterministic(): UserDefinedFunction
86-
}
87-
88-
private[sql] case class SparkUserDefinedFunction(
89-
f: AnyRef,
90-
dataType: DataType,
91-
inputTypes: Option[Seq[DataType]],
92-
nullableTypes: Option[Seq[Boolean]],
93-
name: Option[String] = None,
94-
nullable: Boolean = true,
95-
deterministic: Boolean = true) extends UserDefinedFunction {
96-
97-
@scala.annotation.varargs
98-
override def apply(exprs: Column*): Column = {
75+
def apply(exprs: Column*): Column = {
9976
// TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
10077
// and `nullableTypes` is always set.
78+
if (nullableTypes.isEmpty) {
79+
nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f))
80+
}
10181
if (inputTypes.isDefined) {
10282
assert(inputTypes.get.length == nullableTypes.get.length)
10383
}
10484

105-
val inputsNullSafe = nullableTypes.getOrElse {
106-
ScalaReflection.getParameterTypeNullability(f)
107-
}
108-
10985
Column(ScalaUDF(
11086
f,
11187
dataType,
11288
exprs.map(_.expr),
113-
inputsNullSafe,
89+
nullableTypes.get,
11490
inputTypes.getOrElse(Nil),
115-
udfName = name,
116-
nullable = nullable,
117-
udfDeterministic = deterministic))
91+
udfName = _nameOption,
92+
nullable = _nullable,
93+
udfDeterministic = _deterministic))
94+
}
95+
96+
private def copyAll(): UserDefinedFunction = {
97+
val udf = copy()
98+
udf._nameOption = _nameOption
99+
udf._nullable = _nullable
100+
udf._deterministic = _deterministic
101+
udf.nullableTypes = nullableTypes
102+
udf
118103
}
119104

120-
override def withName(name: String): UserDefinedFunction = {
121-
copy(name = Option(name))
105+
/**
106+
* Updates UserDefinedFunction with a given name.
107+
*
108+
* @since 2.3.0
109+
*/
110+
def withName(name: String): UserDefinedFunction = {
111+
val udf = copyAll()
112+
udf._nameOption = Option(name)
113+
udf
122114
}
123115

124-
override def asNonNullable(): UserDefinedFunction = {
116+
/**
117+
* Updates UserDefinedFunction to non-nullable.
118+
*
119+
* @since 2.3.0
120+
*/
121+
def asNonNullable(): UserDefinedFunction = {
125122
if (!nullable) {
126123
this
127124
} else {
128-
copy(nullable = false)
125+
val udf = copyAll()
126+
udf._nullable = false
127+
udf
129128
}
130129
}
131130

132-
override def asNondeterministic(): UserDefinedFunction = {
133-
if (!deterministic) {
131+
/**
132+
* Updates UserDefinedFunction to nondeterministic.
133+
*
134+
* @since 2.3.0
135+
*/
136+
def asNondeterministic(): UserDefinedFunction = {
137+
if (!_deterministic) {
134138
this
135139
} else {
136-
copy(deterministic = false)
140+
val udf = copyAll()
141+
udf._deterministic = false
142+
udf
137143
}
138144
}
139145
}
140146

147+
// We have to use a name different from `UserDefinedFunction` here, to avoid breaking the binary
148+
// compatibility of the auto-generated UserDefinedFunction object.
141149
private[sql] object SparkUserDefinedFunction {
142150

143151
def create(
@@ -149,7 +157,8 @@ private[sql] object SparkUserDefinedFunction {
149157
} else {
150158
Some(inputSchemas.map(_.get.dataType))
151159
}
152-
val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
153-
SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes)
160+
val udf = new UserDefinedFunction(f, dataType, inputTypes)
161+
udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
162+
udf
154163
}
155164
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4364,7 +4364,7 @@ object functions {
43644364
def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
43654365
// TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently
43664366
// unavailable. We may need to create type-safe overloaded versions of udf() methods.
4367-
SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None)
4367+
new UserDefinedFunction(f, dataType, inputTypes = None)
43684368
}
43694369

43704370
/**

0 commit comments

Comments
 (0)