
Commit 0dd97f6

maropu authored and hvanhovell committed
[SPARK-23595][SQL] ValidateExternalType should support interpreted execution
## What changes were proposed in this pull request?

This PR adds interpreted execution support for `ValidateExternalType`.

## How was this patch tested?

Added tests in `ObjectExpressionsSuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes apache#20757 from maropu/SPARK-23595.
1 parent 074a7f9 commit 0dd97f6
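
For context, a minimal sketch of what the change enables (not part of this commit; the imports and exact package paths are assumptions). Before this patch, calling `eval` on `ValidateExternalType` threw `UnsupportedOperationException`, so only the code-generated path worked; after it, the expression can be evaluated directly in interpreted mode, mirroring the new test added below.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.types.{DoubleType, ObjectType}

// Validate that field "c0" of an external Row is a valid external value for DoubleType.
val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
val expr = ValidateExternalType(
  GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType)

// Interpreted evaluation, no codegen involved:
expr.eval(InternalRow.fromSeq(Seq(Row(1.5))))  // returns 1.5 (the value passes the type check)
expr.eval(InternalRow.fromSeq(Seq(Row(1))))    // throws RuntimeException:
                                               // "java.lang.Integer is not a valid external type for schema of double"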

File tree

4 files changed: +74 / -8 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 13 additions & 0 deletions
@@ -846,6 +846,19 @@ object ScalaReflection extends ScalaReflection {
     }
   }
 
+  def javaBoxedType(dt: DataType): Class[_] = dt match {
+    case _: DecimalType => classOf[Decimal]
+    case BinaryType => classOf[Array[Byte]]
+    case StringType => classOf[UTF8String]
+    case CalendarIntervalType => classOf[CalendarInterval]
+    case _: StructType => classOf[InternalRow]
+    case _: ArrayType => classOf[ArrayType]
+    case _: MapType => classOf[MapType]
+    case udt: UserDefinedType[_] => javaBoxedType(udt.sqlType)
+    case ObjectType(cls) => cls
+    case _ => ScalaReflection.typeBoxedJavaMapping.getOrElse(dt, classOf[java.lang.Object])
+  }
+
   def expressionJavaClasses(arguments: Seq[Expression]): Seq[Class[_]] = {
     if (arguments != Nil) {
       arguments.map(e => dataTypeJavaClass(e.dataType))
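
As a quick illustration (a hypothetical REPL session, not part of the patch), the new `javaBoxedType` helper resolves a Catalyst `DataType` to the boxed Java class that the interpreted `ValidateExternalType` path uses for its `isInstance` check:

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._

ScalaReflection.javaBoxedType(IntegerType)        // java.lang.Integer (via the typeBoxedJavaMapping fallback)
ScalaReflection.javaBoxedType(StringType)         // org.apache.spark.unsafe.types.UTF8String
ScalaReflection.javaBoxedType(DecimalType(10, 0)) // org.apache.spark.sql.types.Decimal
ScalaReflection.javaBoxedType(BinaryType)         // byte[]

Note that `ArrayType` and `MapType` map to the Catalyst type classes themselves, which is presumably why the interpreted check in objects.scala below special-cases arrays (they may arrive as `Array` or `Seq`) rather than relying on this helper.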

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 29 additions & 5 deletions
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 /**
@@ -1672,13 +1673,36 @@ case class ValidateExternalType(child: Expression, expected: DataType)
 
   override def nullable: Boolean = child.nullable
 
-  override def dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
-
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+  override val dataType: DataType = RowEncoder.externalDataTypeForInput(expected)
 
   private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
 
+  private lazy val checkType: (Any) => Boolean = expected match {
+    case _: DecimalType =>
+      (value: Any) => {
+        value.isInstanceOf[java.math.BigDecimal] || value.isInstanceOf[scala.math.BigDecimal] ||
+          value.isInstanceOf[Decimal]
+      }
+    case _: ArrayType =>
+      (value: Any) => {
+        value.getClass.isArray || value.isInstanceOf[Seq[_]]
+      }
+    case _ =>
+      val dataTypeClazz = ScalaReflection.javaBoxedType(dataType)
+      (value: Any) => {
+        dataTypeClazz.isInstance(value)
+      }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    val result = child.eval(input)
+    if (checkType(result)) {
+      result
+    } else {
+      throw new RuntimeException(s"${result.getClass.getName}$errMsg")
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     // Use unnamed reference that doesn't create a local field here to reduce the number of fields
     // because errMsgField is used only when the type doesn't match.
@@ -1691,7 +1715,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
       Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
         .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
     case _: ArrayType =>
-      s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
+      s"$obj.getClass().isArray() || $obj instanceof ${classOf[Seq[_]].getName}"
     case _ =>
      s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
   }
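
A note on the design, with a small standalone sketch (hypothetical names, not Spark API): `checkType` resolves the type check once per expression instance, so per-row `eval` calls only apply a precomputed predicate instead of re-matching on `expected`; changing `dataType` from `def` to `val` likewise appears intended to avoid recomputing `externalDataTypeForInput` on every access.

// Standalone illustration of the "resolve the check once, apply it per value" pattern used by
// checkType above. ExpectedKind and its cases are made-up stand-ins, not Spark types.
sealed trait ExpectedKind
case object DecimalKind extends ExpectedKind
case object ArrayKind extends ExpectedKind
final case class ExactClass(cls: Class[_]) extends ExpectedKind

def makeChecker(expected: ExpectedKind): Any => Boolean = expected match {
  case DecimalKind =>
    (v: Any) => v.isInstanceOf[java.math.BigDecimal] || v.isInstanceOf[scala.math.BigDecimal]
  case ArrayKind =>
    (v: Any) => v != null && (v.getClass.isArray || v.isInstanceOf[Seq[_]])
  case ExactClass(cls) =>
    (v: Any) => cls.isInstance(v)
}

val checkDouble = makeChecker(ExactClass(classOf[java.lang.Double]))
checkDouble(1.5)  // true
checkDouble(1)    // false -> in ValidateExternalType.eval this is the RuntimeException path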

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 31 additions & 2 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 class InvokeTargetClass extends Serializable {
   def filterInt(e: Any): Any = e.asInstanceOf[Int] > 0
@@ -296,7 +296,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
     val getRowField = GetExternalRowField(inputObject, index = 0, fieldName = "c0")
     Seq((Row(1), 1), (Row(3), 3)).foreach { case (input, expected) =>
-      checkEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
+      checkObjectExprEvaluation(getRowField, expected, InternalRow.fromSeq(Seq(input)))
     }
 
     // If an input row or a field are null, a runtime exception will be thrown
@@ -472,6 +472,35 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     val deserializer = toMapExpr.copy(inputData = Literal.create(data))
     checkObjectExprEvaluation(deserializer, expected = data)
   }
+
+  test("SPARK-23595 ValidateExternalType should support interpreted execution") {
+    val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
+    Seq(
+      (true, BooleanType),
+      (2.toByte, ByteType),
+      (5.toShort, ShortType),
+      (23, IntegerType),
+      (61L, LongType),
+      (1.0f, FloatType),
+      (10.0, DoubleType),
+      ("abcd".getBytes, BinaryType),
+      ("abcd", StringType),
+      (BigDecimal.valueOf(10), DecimalType.IntDecimal),
+      (CalendarInterval.fromString("interval 3 day"), CalendarIntervalType),
+      (java.math.BigDecimal.valueOf(10), DecimalType.BigIntDecimal),
+      (Array(3, 2, 1), ArrayType(IntegerType))
+    ).foreach { case (input, dt) =>
+      val validateType = ValidateExternalType(
+        GetExternalRowField(inputObject, index = 0, fieldName = "c0"), dt)
+      checkObjectExprEvaluation(validateType, input, InternalRow.fromSeq(Seq(Row(input))))
+    }
+
+    checkExceptionInExpression[RuntimeException](
+      ValidateExternalType(
+        GetExternalRowField(inputObject, index = 0, fieldName = "c0"), DoubleType),
+      InternalRow.fromSeq(Seq(Row(1))),
+      "java.lang.Integer is not a valid external type for schema of double")
+  }
 }
 
 class TestBean extends Serializable {
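
For a sense of where this expression surfaces in practice, here is a hypothetical end-to-end example (not from this commit; `spark` is assumed to be an active `SparkSession`). `ValidateExternalType` is what produces the "is not a valid external type" error when a `Row` value does not match the schema passed to `createDataFrame`:

import java.util.Arrays

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val schema = StructType(Seq(StructField("c0", DoubleType)))
val rows = Arrays.asList(Row(1))  // an Integer where the schema declares a Double

// Fails at runtime with:
//   java.lang.RuntimeException: java.lang.Integer is not a valid external type for schema of double
spark.createDataFrame(rows, schema).collect()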
