@@ -38,106 +38,114 @@ import org.apache.spark.sql.types.DataType
  * @since 1.3.0
  */
 @Stable
-sealed trait UserDefinedFunction {
+case class UserDefinedFunction protected[sql] (
+    f: AnyRef,
+    dataType: DataType,
+    inputTypes: Option[Seq[DataType]]) {
+
+  private var _nameOption: Option[String] = None
+  private var _nullable: Boolean = true
+  private var _deterministic: Boolean = true
+
+  // This is a `var` instead of in the constructor for backward compatibility of this case class.
+  // TODO: revisit this case class in Spark 3.0, and narrow down the public surface.
+  private[sql] var nullableTypes: Option[Seq[Boolean]] = None
 
   /**
    * Returns true when the UDF can return a nullable value.
    *
    * @since 2.3.0
    */
-  def nullable: Boolean
+  def nullable: Boolean = _nullable
 
   /**
    * Returns true iff the UDF is deterministic, i.e. the UDF produces the same output given the same
    * input.
    *
    * @since 2.3.0
    */
-  def deterministic: Boolean
+  def deterministic: Boolean = _deterministic
 
   /**
    * Returns an expression that invokes the UDF, using the given arguments.
    *
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def apply(exprs: Column*): Column
-
-  /**
-   * Updates UserDefinedFunction with a given name.
-   *
-   * @since 2.3.0
-   */
-  def withName(name: String): UserDefinedFunction
-
-  /**
-   * Updates UserDefinedFunction to non-nullable.
-   *
-   * @since 2.3.0
-   */
-  def asNonNullable(): UserDefinedFunction
-
-  /**
-   * Updates UserDefinedFunction to nondeterministic.
-   *
-   * @since 2.3.0
-   */
-  def asNondeterministic(): UserDefinedFunction
-}
-
-private[sql] case class SparkUserDefinedFunction(
-    f: AnyRef,
-    dataType: DataType,
-    inputTypes: Option[Seq[DataType]],
-    nullableTypes: Option[Seq[Boolean]],
-    name: Option[String] = None,
-    nullable: Boolean = true,
-    deterministic: Boolean = true) extends UserDefinedFunction {
-
-  @scala.annotation.varargs
-  override def apply(exprs: Column*): Column = {
+  def apply(exprs: Column*): Column = {
     // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
     // and `nullableTypes` is always set.
+    if (nullableTypes.isEmpty) {
+      nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f))
+    }
     if (inputTypes.isDefined) {
       assert(inputTypes.get.length == nullableTypes.get.length)
     }
 
-    val inputsNullSafe = nullableTypes.getOrElse {
-      ScalaReflection.getParameterTypeNullability(f)
-    }
-
     Column(ScalaUDF(
       f,
       dataType,
       exprs.map(_.expr),
-      inputsNullSafe,
+      nullableTypes.get,
       inputTypes.getOrElse(Nil),
-      udfName = name,
-      nullable = nullable,
-      udfDeterministic = deterministic))
+      udfName = _nameOption,
+      nullable = _nullable,
+      udfDeterministic = _deterministic))
+  }
+
+  private def copyAll(): UserDefinedFunction = {
+    val udf = copy()
+    udf._nameOption = _nameOption
+    udf._nullable = _nullable
+    udf._deterministic = _deterministic
+    udf.nullableTypes = nullableTypes
+    udf
   }
 
-  override def withName(name: String): UserDefinedFunction = {
-    copy(name = Option(name))
+  /**
+   * Updates UserDefinedFunction with a given name.
+   *
+   * @since 2.3.0
+   */
+  def withName(name: String): UserDefinedFunction = {
+    val udf = copyAll()
+    udf._nameOption = Option(name)
+    udf
   }
 
-  override def asNonNullable(): UserDefinedFunction = {
+  /**
+   * Updates UserDefinedFunction to non-nullable.
+   *
+   * @since 2.3.0
+   */
+  def asNonNullable(): UserDefinedFunction = {
     if (!nullable) {
       this
     } else {
-      copy(nullable = false)
+      val udf = copyAll()
+      udf._nullable = false
+      udf
     }
   }
 
-  override def asNondeterministic(): UserDefinedFunction = {
-    if (!deterministic) {
+  /**
+   * Updates UserDefinedFunction to nondeterministic.
+   *
+   * @since 2.3.0
+   */
+  def asNondeterministic(): UserDefinedFunction = {
+    if (!_deterministic) {
       this
     } else {
-      copy(deterministic = false)
+      val udf = copyAll()
+      udf._deterministic = false
+      udf
     }
   }
 }
 
+// We have to use a name different than `UserDefinedFunction` here, to avoid breaking the binary
+// compatibility of the auto-generated UserDefinedFunction object.
 private[sql] object SparkUserDefinedFunction {
 
   def create(
@@ -149,7 +157,8 @@ private[sql] object SparkUserDefinedFunction {
     } else {
       Some(inputSchemas.map(_.get.dataType))
     }
-    val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
-    SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes)
+    val udf = new UserDefinedFunction(f, dataType, inputTypes)
+    udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
+    udf
   }
 }
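For reviewers who want to see the public surface this refactor preserves, here is a minimal usage sketch. It is not part of the patch: it assumes a Spark 2.4-style build with a local `SparkSession`, and the UDF and column names are made up for illustration.

```scala
// Hypothetical end-to-end example. `udf(...)` routes through
// SparkUserDefinedFunction.create, which now instantiates the case class
// directly and fills in the `nullableTypes` var added by this change.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object UdfSurfaceCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("udf-surface-check")
      .getOrCreate()

    // Each fluent call below returns a fresh copy (via `copyAll()` internally),
    // so the intermediate UDF values are never mutated in place.
    val toUpper = udf((s: String) => s.toUpperCase)
      .withName("toUpper")      // name surfaced in query plans
      .asNonNullable()          // declares the UDF never returns null
      .asNondeterministic()     // keeps the optimizer from collapsing repeated calls

    val df = spark.range(3).selectExpr("cast(id AS string) AS s")
    df.select(toUpper(col("s"))).show()

    spark.stop()
  }
}
```

Note the design constraint the diff works around: the extra state lives in private `var`s rather than constructor parameters (so the case class's generated signature stays binary compatible), which means a bare case-class `copy()` would silently drop it; `copyAll()` exists precisely to carry those fields across.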