Skip to content

Commit 8fcbda9

Browse files
gengliangwanggatorsmile
authored andcommitted
[SPARK-21848][SQL] Add trait UserDefinedExpression to identify user-defined functions
## What changes were proposed in this pull request? Add trait UserDefinedExpression to identify user-defined functions. UDF can be expensive. In optimizer we may need to avoid executing UDF multiple times. E.g. ```scala table.select(UDF as 'a).select('a, ('a + 1) as 'b) ``` If UDF is expensive in this case, optimizer should not collapse the project to ```scala table.select(UDF as 'a, (UDF+1) as 'b) ``` Currently UDF classes like PythonUDF, HiveGenericUDF are not defined in catalyst. This PR is to add a new trait to make it easier to identify user-defined functions. ## How was this patch tested? Unit test Author: Wang Gengliang <[email protected]> Closes apache#19064 from gengliangwang/UDFType.
1 parent 32fa0b8 commit 8fcbda9

File tree

5 files changed

+28
-8
lines changed

5 files changed

+28
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression {
635635
}
636636
}
637637
}
638+
639+
/**
640+
* Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
641+
* and Hive function wrappers.
642+
*/
643+
trait UserDefinedExpression

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ case class ScalaUDF(
4747
udfName: Option[String] = None,
4848
nullable: Boolean = true,
4949
udfDeterministic: Boolean = true)
50-
extends Expression with ImplicitCastInputTypes with NonSQLExpression {
50+
extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
5151

5252
override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
5353

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,11 @@ case class ScalaUDAF(
324324
udaf: UserDefinedAggregateFunction,
325325
mutableAggBufferOffset: Int = 0,
326326
inputAggBufferOffset: Int = 0)
327-
extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
327+
extends ImperativeAggregate
328+
with NonSQLExpression
329+
with Logging
330+
with ImplicitCastInputTypes
331+
with UserDefinedExpression {
328332

329333
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
330334
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.execution.python
1919

2020
import org.apache.spark.api.python.PythonFunction
21-
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
21+
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression}
2222
import org.apache.spark.sql.types.DataType
2323

2424
/**
@@ -29,7 +29,7 @@ case class PythonUDF(
2929
func: PythonFunction,
3030
dataType: DataType,
3131
children: Seq[Expression])
32-
extends Expression with Unevaluable with NonSQLExpression {
32+
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
3333

3434
override def toString: String = s"$name(${children.mkString(", ")})"
3535

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ import org.apache.spark.sql.types._
4242

4343
private[hive] case class HiveSimpleUDF(
4444
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
45-
extends Expression with HiveInspectors with CodegenFallback with Logging {
45+
extends Expression
46+
with HiveInspectors
47+
with CodegenFallback
48+
with Logging
49+
with UserDefinedExpression {
4650

4751
override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
4852

@@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
119123

120124
private[hive] case class HiveGenericUDF(
121125
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
122-
extends Expression with HiveInspectors with CodegenFallback with Logging {
126+
extends Expression
127+
with HiveInspectors
128+
with CodegenFallback
129+
with Logging
130+
with UserDefinedExpression {
123131

124132
override def nullable: Boolean = true
125133

@@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF(
191199
name: String,
192200
funcWrapper: HiveFunctionWrapper,
193201
children: Seq[Expression])
194-
extends Generator with HiveInspectors with CodegenFallback {
202+
extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression {
195203

196204
@transient
197205
protected lazy val function: GenericUDTF = {
@@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction(
303311
isUDAFBridgeRequired: Boolean = false,
304312
mutableAggBufferOffset: Int = 0,
305313
inputAggBufferOffset: Int = 0)
306-
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
314+
extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
315+
with HiveInspectors
316+
with UserDefinedExpression {
307317

308318
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
309319
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

0 commit comments

Comments
 (0)