
Commit 1423024

Robert Kruszewski authored and committed
Revert "[SPARK-26323][SQL] Scala UDF should still check input types even if some inputs are of type Any"
This reverts commit 72a572f.
1 parent b0d256d commit 1423024
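
For context on the behavior being restored: with this revert, a Scala UDF skips input type checking entirely whenever any of its input schemas cannot be derived by reflection, rather than checking the derivable inputs and skipping only the `Any`-typed ones. A minimal sketch of the two paths (illustrative user code, not part of the diff; Spark 2.4-era `functions.udf` API):

// Hypothetical example: `Seq[Any]` has no derivable Catalyst schema, so
// `ScalaReflection.schemaFor` fails for that parameter. After this revert,
// the whole UDF gets inputTypes = None and no implicit casts or input type
// checks are applied to ANY of its arguments.
import org.apache.spark.sql.functions.udf

val untyped = udf((xs: Seq[Any], n: Int) => xs.take(n).size)

// Fully typed parameters yield schemas for every input, so type coercion
// and null checks apply as usual.
val typed = udf((xs: Seq[Int], n: Int) => xs.take(n).size)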

7 files changed: 184 additions, 175 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 12 deletions

@@ -882,18 +882,7 @@ object TypeCoercion {

       case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
         val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
-          // Currently Scala UDF will only expect `AnyDataType` at top level, so this trick works.
-          // In the future we should create types like `AbstractArrayType`, so that Scala UDF can
-          // accept inputs of array type of arbitrary element type.
-          if (expected == AnyDataType) {
-            in
-          } else {
-            implicitCast(
-              in,
-              udfInputToCastType(in.dataType, expected.asInstanceOf[DataType])
-            ).getOrElse(in)
-          }
-
+          implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
         }
         udf.withNewChildren(children)
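
To illustrate the reverted rule (a hedged sketch, not part of the diff): every `expected` entry is now a concrete `DataType`, so each UDF child that can be implicitly cast to it is wrapped in a `Cast`, and anything else is left unchanged:

// Hypothetical example: a ScalaUDF declared over LongType applied to an
// IntegerType attribute has its child rewritten to Cast(i, LongType) by the
// rule above; when no implicit cast exists, getOrElse(in) keeps the original
// child untouched.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types.{IntegerType, LongType}

val i = AttributeReference("i", IntegerType)()
val coerced = Cast(i, LongType)  // effective child after coercion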

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.types.{AbstractDataType, DataType}
+import org.apache.spark.sql.types.DataType

 /**
  * User-defined function.
@@ -48,7 +48,7 @@ case class ScalaUDF(
     dataType: DataType,
     children: Seq[Expression],
     inputsNullSafe: Seq[Boolean],
-    inputTypes: Seq[AbstractDataType] = Nil,
+    inputTypes: Seq[DataType] = Nil,
     udfName: Option[String] = None,
     nullable: Boolean = true,
     udfDeterministic: Boolean = true)
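
A sketch of constructing the expression with the reverted signature, mirroring the call site in UserDefinedFunction.scala later in this commit (illustrative values; `inputTypes` now only carries concrete `DataType`s, and the `Nil` default disables input type checking):

import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUDF}
import org.apache.spark.sql.types.{IntegerType, LongType}

val e = ScalaUDF(
  (x: Long) => x + 1,            // the wrapped Scala function
  LongType,                      // dataType: declared result type
  Seq(Literal(1, IntegerType)),  // children
  Seq(true),                     // inputsNullSafe
  Seq(LongType))                 // inputTypes: Seq[DataType]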

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 1 addition & 1 deletion

@@ -96,7 +96,7 @@ private[sql] object TypeCollection {
 /**
  * An `AbstractDataType` that matches any concrete data types.
  */
-protected[sql] object AnyDataType extends AbstractDataType with Serializable {
+protected[sql] object AnyDataType extends AbstractDataType {

   // Note that since AnyDataType matches any concrete types, defaultConcreteType should never
   // be invoked.
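
For reference, the remainder of the object (truncated by the hunk) follows from the comment above; a sketch assuming the standard `AbstractDataType` contract (`defaultConcreteType`, `acceptsType`, `simpleString`):

// AnyDataType matches every concrete type, so acceptsType is always true and
// defaultConcreteType is not expected to be invoked.
protected[sql] object AnyDataType extends AbstractDataType {
  override private[sql] def defaultConcreteType: DataType =
    throw new UnsupportedOperationException
  override private[sql] def simpleString: String = "any"
  override private[sql] def acceptsType(other: DataType): Boolean = true
}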

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 120 additions & 96 deletions
Large diffs are not rendered by default.

sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala

Lines changed: 33 additions & 24 deletions

@@ -20,8 +20,8 @@ package org.apache.spark.sql.expressions
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.types.{AnyDataType, DataType}
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.types.DataType

 /**
  * A user-defined function. To create one, use the `udf` functions in `functions`.
@@ -88,59 +88,68 @@ sealed abstract class UserDefinedFunction {
 private[sql] case class SparkUserDefinedFunction(
     f: AnyRef,
     dataType: DataType,
-    inputSchemas: Seq[Option[ScalaReflection.Schema]],
+    inputTypes: Option[Seq[DataType]],
+    nullableTypes: Option[Seq[Boolean]],
     name: Option[String] = None,
     nullable: Boolean = true,
     deterministic: Boolean = true) extends UserDefinedFunction {

   @scala.annotation.varargs
   override def apply(exprs: Column*): Column = {
-    Column(createScalaUDF(exprs.map(_.expr)))
-  }
-
-  private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = {
-    // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type
-    // check and null check for them.
-    val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))
+    // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
+    // and `nullableTypes` is always set.
+    if (inputTypes.isDefined) {
+      assert(inputTypes.get.length == nullableTypes.get.length)
+    }

-    val inputsNullSafe = if (inputSchemas.isEmpty) {
-      // This is for backward compatibility of `functions.udf(AnyRef, DataType)`. We need to
-      // do reflection of the lambda function object and see if its arguments are nullable or not.
-      // This doesn't work for Scala 2.12 and we should consider removing this workaround, as Spark
-      // uses Scala 2.12 by default since 3.0.
+    val inputsNullSafe = nullableTypes.getOrElse {
       ScalaReflection.getParameterTypeNullability(f)
-    } else {
-      inputSchemas.map(_.map(_.nullable).getOrElse(true))
     }

-    ScalaUDF(
+    Column(ScalaUDF(
       f,
       dataType,
-      exprs,
+      exprs.map(_.expr),
       inputsNullSafe,
-      inputTypes,
+      inputTypes.getOrElse(Nil),
       udfName = name,
       nullable = nullable,
-      udfDeterministic = deterministic)
+      udfDeterministic = deterministic))
   }

-  override def withName(name: String): SparkUserDefinedFunction = {
+  override def withName(name: String): UserDefinedFunction = {
     copy(name = Option(name))
   }

-  override def asNonNullable(): SparkUserDefinedFunction = {
+  override def asNonNullable(): UserDefinedFunction = {
     if (!nullable) {
       this
     } else {
       copy(nullable = false)
     }
   }

-  override def asNondeterministic(): SparkUserDefinedFunction = {
+  override def asNondeterministic(): UserDefinedFunction = {
     if (!deterministic) {
       this
     } else {
       copy(deterministic = false)
     }
   }
 }
+
+private[sql] object SparkUserDefinedFunction {
+
+  def create(
+      f: AnyRef,
+      dataType: DataType,
+      inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
+    val inputTypes = if (inputSchemas.contains(None)) {
+      None
+    } else {
+      Some(inputSchemas.map(_.get.dataType))
+    }
+    val nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
+    SparkUserDefinedFunction(f, dataType, inputTypes, nullableTypes)
+  }
+}
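
Worked example of what `create` derives (values are illustrative, not from the diff): a single underivable input schema collapses `inputTypes` to `None`, disabling type checks for every input, while per-input nullability falls back to `true`:

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.IntegerType

val inputSchemas: Seq[Option[ScalaReflection.Schema]] = Seq(
  Some(ScalaReflection.Schema(IntegerType, nullable = false)),
  None)  // e.g. a parameter typed `Any`

val inputTypes =
  if (inputSchemas.contains(None)) None
  else Some(inputSchemas.map(_.get.dataType))                // here: None
val nullableTypes =
  Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))  // Some(Seq(false, true))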

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 27 additions & 25 deletions

@@ -3979,7 +3979,7 @@ object functions {
       |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
       |  val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
       |  val inputSchemas = $inputSchemas
-      |  val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+      |  val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
       |  if (nullable) udf else udf.asNonNullable()
       |}""".stripMargin)
 }
@@ -4002,7 +4002,7 @@ object functions {
       | */
       |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
       |  val func = f$anyCast.call($anyParams)
-      |  SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
+      |  SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
       |}""".stripMargin)
 }

@@ -4024,7 +4024,7 @@ object functions {
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4040,7 +4040,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4056,7 +4056,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4072,7 +4072,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4088,7 +4088,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4104,7 +4104,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4120,7 +4120,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4136,7 +4136,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4152,7 +4152,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4168,7 +4168,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4184,7 +4184,7 @@ object functions {
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
     val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
-    val udf = SparkUserDefinedFunction(f, dataType, inputSchemas)
+    val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }

@@ -4203,7 +4203,7 @@ object functions {
    */
   def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF0[Any]].call()
-    SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
+    SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None))
   }

   /**
@@ -4217,7 +4217,7 @@ object functions {
    */
   def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(1)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None))
   }

   /**
@@ -4231,7 +4231,7 @@ object functions {
    */
   def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(2)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None))
   }

   /**
@@ -4245,7 +4245,7 @@ object functions {
    */
   def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(3)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None))
   }

   /**
@@ -4259,7 +4259,7 @@ object functions {
    */
   def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(4)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None))
   }

   /**
@@ -4273,7 +4273,7 @@ object functions {
    */
   def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(5)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None))
   }

   /**
@@ -4287,7 +4287,7 @@ object functions {
    */
   def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(6)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None))
   }

   /**
@@ -4301,7 +4301,7 @@ object functions {
    */
   def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(7)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None))
   }

   /**
@@ -4315,7 +4315,7 @@ object functions {
    */
   def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(8)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None))
   }

   /**
@@ -4329,7 +4329,7 @@ object functions {
    */
   def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(9)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None))
   }

   /**
@@ -4343,7 +4343,7 @@ object functions {
    */
   def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(10)(None))
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None))
   }

   // scalastyle:on parameter.number
@@ -4362,7 +4362,9 @@ object functions {
    * @since 2.0.0
    */
   def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
-    SparkUserDefinedFunction(f, dataType, inputSchemas = Nil)
+    // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently
+    // unavailable. We may need to create type-safe overloaded versions of udf() methods.
+    SparkUserDefinedFunction(f, dataType, inputTypes = None, nullableTypes = None)
   }

   /**
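
Usage contrast for the two entry points above (a hedged sketch, not part of the diff): the typed overloads derive `inputSchemas` from TypeTags and route through `SparkUserDefinedFunction.create`, while the untyped `udf(AnyRef, DataType)` overload has no schema information and therefore constructs the function with `inputTypes = None`:

import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.IntegerType

// Typed: schemas derived, input casts and null checks apply.
val inc = udf((i: Int) => i + 1)

// Untyped: no schemas, no input type checks; nullability falls back to
// reflection on the lambda object.
val incAny = udf((i: Any) => i.asInstanceOf[Int] + 1, IntegerType)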
