Skip to content

Commit fb64523

Browse files
kiszkRobert Kruszewski
authored andcommitted
[SPARK-23582][SQL] StaticInvoke should support interpreted execution
## What changes were proposed in this pull request? This pr added interpreted execution for `StaticInvoke`. ## How was this patch tested? Added tests in `ObjectExpressionsSuite`. Author: Kazuaki Ishizaki <[email protected]> Closes apache#20753 from kiszk/SPARK-23582.
1 parent 63ea877 commit fb64523

File tree

2 files changed

+77
-3
lines changed

2 files changed

+77
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
3535
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
3636
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
3737
import org.apache.spark.sql.types._
38+
import org.apache.spark.util.Utils
3839

3940
/**
4041
* Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
@@ -217,12 +218,21 @@ case class StaticInvoke(
217218
returnNullable: Boolean = true) extends InvokeLike {
218219

219220
val objectName = staticObject.getName.stripSuffix("$")
221+
val cls = if (staticObject.getName == objectName) {
222+
staticObject
223+
} else {
224+
Utils.classForName(objectName)
225+
}
220226

221227
override def nullable: Boolean = needNullCheck || returnNullable
222228
override def children: Seq[Expression] = arguments
223229

224-
override def eval(input: InternalRow): Any =
225-
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
230+
lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
231+
@transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*)
232+
233+
override def eval(input: InternalRow): Any = {
234+
invoke(null, method, arguments, input, dataType)
235+
}
226236

227237
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
228238
val javaType = CodeGenerator.javaType(dataType)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.sql.{Date, Timestamp}
21+
2022
import scala.collection.JavaConverters._
2123
import scala.reflect.ClassTag
2224

@@ -28,9 +30,11 @@ import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
2830
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2931
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
3032
import org.apache.spark.sql.catalyst.expressions.objects._
31-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
33+
import org.apache.spark.sql.catalyst.util._
34+
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{SQLDate, SQLTimestamp}
3235
import org.apache.spark.sql.internal.SQLConf
3336
import org.apache.spark.sql.types._
37+
import org.apache.spark.unsafe.types.UTF8String
3438

3539
class InvokeTargetClass extends Serializable {
3640
def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
@@ -93,6 +97,66 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
9397
UnsafeProjection) // TODO(hvanhovell) revert this when SPARK-23587 is fixed
9498
}
9599

100+
test("SPARK-23582: StaticInvoke should support interpreted execution") {
101+
Seq((classOf[java.lang.Boolean], "true", true),
102+
(classOf[java.lang.Byte], "1", 1.toByte),
103+
(classOf[java.lang.Short], "257", 257.toShort),
104+
(classOf[java.lang.Integer], "12345", 12345),
105+
(classOf[java.lang.Long], "12345678", 12345678.toLong),
106+
(classOf[java.lang.Float], "12.34", 12.34.toFloat),
107+
(classOf[java.lang.Double], "1.2345678", 1.2345678)
108+
).foreach { case (cls, arg, expected) =>
109+
checkObjectExprEvaluation(StaticInvoke(cls, ObjectType(cls), "valueOf",
110+
Seq(BoundReference(0, ObjectType(classOf[java.lang.String]), true))),
111+
expected, InternalRow.fromSeq(Seq(arg)))
112+
}
113+
114+
// Return null when null argument is passed with propagateNull = true
115+
val stringCls = classOf[java.lang.String]
116+
checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf",
117+
Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = true),
118+
null, InternalRow.fromSeq(Seq(null)))
119+
checkObjectExprEvaluation(StaticInvoke(stringCls, ObjectType(stringCls), "valueOf",
120+
Seq(BoundReference(0, ObjectType(classOf[Object]), true)), propagateNull = false),
121+
"null", InternalRow.fromSeq(Seq(null)))
122+
123+
// test no argument
124+
val clCls = classOf[java.lang.ClassLoader]
125+
checkObjectExprEvaluation(StaticInvoke(clCls, ObjectType(clCls), "getSystemClassLoader", Nil),
126+
ClassLoader.getSystemClassLoader, InternalRow.empty)
127+
// test more than one argument
128+
val intCls = classOf[java.lang.Integer]
129+
checkObjectExprEvaluation(StaticInvoke(intCls, ObjectType(intCls), "compare",
130+
Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, false))),
131+
0, InternalRow.fromSeq(Seq(7, 7)))
132+
133+
Seq((DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", ObjectType(classOf[Timestamp]),
134+
new Timestamp(77777), DateTimeUtils.fromJavaTimestamp(new Timestamp(77777))),
135+
(DateTimeUtils.getClass, DateType, "fromJavaDate", ObjectType(classOf[Date]),
136+
new Date(88888888), DateTimeUtils.fromJavaDate(new Date(88888888))),
137+
(classOf[UTF8String], StringType, "fromString", ObjectType(classOf[String]),
138+
"abc", UTF8String.fromString("abc")),
139+
(Decimal.getClass, DecimalType(38, 0), "fromDecimal", ObjectType(classOf[Any]),
140+
BigInt(88888888), Decimal.fromDecimal(BigInt(88888888))),
141+
(Decimal.getClass, DecimalType.SYSTEM_DEFAULT,
142+
"apply", ObjectType(classOf[java.math.BigInteger]),
143+
new java.math.BigInteger("88888888"), Decimal.apply(new java.math.BigInteger("88888888"))),
144+
(classOf[ArrayData], ArrayType(IntegerType), "toArrayData", ObjectType(classOf[Any]),
145+
Array[Int](1, 2, 3), ArrayData.toArrayData(Array[Int](1, 2, 3))),
146+
(classOf[UnsafeArrayData], ArrayType(IntegerType, false),
147+
"fromPrimitiveArray", ObjectType(classOf[Array[Int]]),
148+
Array[Int](1, 2, 3), UnsafeArrayData.fromPrimitiveArray(Array[Int](1, 2, 3))),
149+
(DateTimeUtils.getClass, ObjectType(classOf[Date]),
150+
"toJavaDate", ObjectType(classOf[SQLDate]), 77777, DateTimeUtils.toJavaDate(77777)),
151+
(DateTimeUtils.getClass, ObjectType(classOf[Timestamp]),
152+
"toJavaTimestamp", ObjectType(classOf[SQLTimestamp]),
153+
88888888.toLong, DateTimeUtils.toJavaTimestamp(88888888))
154+
).foreach { case (cls, dataType, methodName, argType, arg, expected) =>
155+
checkObjectExprEvaluation(StaticInvoke(cls, dataType, methodName,
156+
Seq(BoundReference(0, argType, true))), expected, InternalRow.fromSeq(Seq(arg)))
157+
}
158+
}
159+
96160
test("SPARK-23583: Invoke should support interpreted execution") {
97161
val targetObject = new InvokeTargetClass
98162
val funcClass = classOf[InvokeTargetClass]

0 commit comments

Comments
 (0)