Skip to content

Commit fd990a9

Browse files
viirya authored and hvanhovell committed
[SPARK-23873][SQL] Use accessors in interpreted LambdaVariable
## What changes were proposed in this pull request? Currently, interpreted execution of `LambdaVariable` just uses `InternalRow.get` to access element. We should use specified accessors if possible. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <[email protected]> Closes apache#20981 from viirya/SPARK-23873.
1 parent 0461482 commit fd990a9

File tree

5 files changed

+75
-23
lines changed

5 files changed

+75
-23
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst
1919

2020
import org.apache.spark.sql.catalyst.expressions._
2121
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
22-
import org.apache.spark.sql.types.{DataType, Decimal, StructType}
22+
import org.apache.spark.sql.types._
2323
import org.apache.spark.unsafe.types.UTF8String
2424

2525
/**
@@ -119,4 +119,28 @@ object InternalRow {
119119
case v: MapData => v.copy()
120120
case _ => value
121121
}
122+
123+
/**
124+
* Returns an accessor for an `InternalRow` with given data type. The returned accessor
125+
* actually takes a `SpecializedGetters` input because it can be generalized to other classes
126+
* that implements `SpecializedGetters` (e.g., `ArrayData`) too.
127+
*/
128+
def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
129+
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
130+
case ByteType => (input, ordinal) => input.getByte(ordinal)
131+
case ShortType => (input, ordinal) => input.getShort(ordinal)
132+
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
133+
case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
134+
case FloatType => (input, ordinal) => input.getFloat(ordinal)
135+
case DoubleType => (input, ordinal) => input.getDouble(ordinal)
136+
case StringType => (input, ordinal) => input.getUTF8String(ordinal)
137+
case BinaryType => (input, ordinal) => input.getBinary(ordinal)
138+
case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
139+
case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
140+
case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
141+
case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
142+
case _: MapType => (input, ordinal) => input.getMap(ordinal)
143+
case u: UserDefinedType[_] => getAccessor(u.sqlType)
144+
case _ => (input, ordinal) => input.get(ordinal, dataType)
145+
}
122146
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,14 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
3333

3434
override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"
3535

36+
private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
37+
3638
// Use special getter for primitive types (for UnsafeRow)
3739
override def eval(input: InternalRow): Any = {
38-
if (input.isNullAt(ordinal)) {
40+
if (nullable && input.isNullAt(ordinal)) {
3941
null
4042
} else {
41-
dataType match {
42-
case BooleanType => input.getBoolean(ordinal)
43-
case ByteType => input.getByte(ordinal)
44-
case ShortType => input.getShort(ordinal)
45-
case IntegerType | DateType => input.getInt(ordinal)
46-
case LongType | TimestampType => input.getLong(ordinal)
47-
case FloatType => input.getFloat(ordinal)
48-
case DoubleType => input.getDouble(ordinal)
49-
case StringType => input.getUTF8String(ordinal)
50-
case BinaryType => input.getBinary(ordinal)
51-
case CalendarIntervalType => input.getInterval(ordinal)
52-
case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale)
53-
case t: StructType => input.getStruct(ordinal, t.size)
54-
case _: ArrayType => input.getArray(ordinal)
55-
case _: MapType => input.getMap(ordinal)
56-
case _ => input.get(ordinal, dataType)
57-
}
43+
accessor(input, ordinal)
5844
}
5945
}
6046

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -560,11 +560,17 @@ case class LambdaVariable(
560560
dataType: DataType,
561561
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
562562

563+
private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
564+
563565
// Interpreted execution of `LambdaVariable` always get the 0-index element from input row.
564566
override def eval(input: InternalRow): Any = {
565567
assert(input.numFields == 1,
566568
"The input row of interpreted LambdaVariable should have only 1 field.")
567-
input.get(0, dataType)
569+
if (nullable && input.isNullAt(0)) {
570+
null
571+
} else {
572+
accessor(input, 0)
573+
}
568574
}
569575

570576
override def genCode(ctx: CodegenContext): ExprCode = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
7070
* Check the equality between result of expression and expected value, it will handle
7171
* Array[Byte], Spread[Double], MapData and Row.
7272
*/
73-
protected def checkResult(result: Any, expected: Any, dataType: DataType): Boolean = {
73+
protected def checkResult(result: Any, expected: Any, exprDataType: DataType): Boolean = {
74+
val dataType = UserDefinedType.sqlType(exprDataType)
75+
7476
(result, expected) match {
7577
case (result: Array[Byte], expected: Array[Byte]) =>
7678
java.util.Arrays.equals(result, expected)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ import java.sql.{Date, Timestamp}
2121

2222
import scala.collection.JavaConverters._
2323
import scala.reflect.ClassTag
24+
import scala.util.Random
2425

2526
import org.apache.spark.{SparkConf, SparkFunSuite}
2627
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
27-
import org.apache.spark.sql.Row
28+
import org.apache.spark.sql.{RandomDataGenerator, Row}
2829
import org.apache.spark.sql.catalyst.InternalRow
2930
import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
30-
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
31+
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, ExpressionEncoder, RowEncoder}
3132
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
3233
import org.apache.spark.sql.catalyst.expressions.objects._
3334
import org.apache.spark.sql.catalyst.util._
@@ -381,6 +382,39 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
381382
checkEvaluation(decodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
382383
}
383384
}
385+
386+
test("LambdaVariable should support interpreted execution") {
387+
def genSchema(dt: DataType): Seq[StructType] = {
388+
Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
389+
StructType(StructField("col_1", dt, nullable = true) :: Nil))
390+
}
391+
392+
val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
393+
DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
394+
CalendarIntervalType, new ExamplePointUDT())
395+
val arrayTypes = elementTypes.flatMap { elementType =>
396+
Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
397+
}
398+
val mapTypes = elementTypes.flatMap { elementType =>
399+
Seq(MapType(elementType, elementType, false), MapType(elementType, elementType, true))
400+
}
401+
val structTypes = elementTypes.flatMap { elementType =>
402+
Seq(StructType(StructField("col1", elementType, false) :: Nil),
403+
StructType(StructField("col1", elementType, true) :: Nil))
404+
}
405+
406+
val testTypes = elementTypes ++ arrayTypes ++ mapTypes ++ structTypes
407+
val random = new Random(100)
408+
testTypes.foreach { dt =>
409+
genSchema(dt).map { schema =>
410+
val row = RandomDataGenerator.randomRow(random, schema)
411+
val rowConverter = RowEncoder(schema)
412+
val internalRow = rowConverter.toRow(row)
413+
val lambda = LambdaVariable("dummy", "dummuIsNull", schema(0).dataType, schema(0).nullable)
414+
checkEvaluationWithoutCodegen(lambda, internalRow.get(0, schema(0).dataType), internalRow)
415+
}
416+
}
417+
}
384418
}
385419

386420
class TestBean extends Serializable {

0 commit comments

Comments
 (0)