Skip to content

Commit 0bcf7e4

Browse files
viirya authored and Robert Kruszewski committed
[SPARK-23587][SQL] Add interpreted execution for MapObjects expression
## What changes were proposed in this pull request?

Add interpreted execution for the `MapObjects` expression.

## How was this patch tested?

Added unit test.

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#20771 from viirya/SPARK-23587.
1 parent b72b848 commit 0bcf7e4

File tree

2 files changed

+165
-12
lines changed

2 files changed

+165
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
1919

2020
import java.lang.reflect.Modifier
2121

22+
import scala.collection.JavaConverters._
2223
import scala.collection.mutable.Builder
2324
import scala.language.existentials
2425
import scala.reflect.ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
501502
value: String,
502503
isNull: String,
503504
dataType: DataType,
504-
nullable: Boolean = true) extends LeafExpression
505-
with Unevaluable with NonSQLExpression {
505+
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
506+
507+
// Interpreted execution of `LambdaVariable` always gets the element at index 0 of the input row.
508+
override def eval(input: InternalRow): Any = {
509+
assert(input.numFields == 1,
510+
"The input row of interpreted LambdaVariable should have only 1 field.")
511+
input.get(0, dataType)
512+
}
506513

507514
override def genCode(ctx: CodegenContext): ExprCode = {
508515
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
509516
}
517+
518+
// This won't be called as `genCode` is overridden; we override it only to make
519+
// `LambdaVariable` non-abstract.
520+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
510521
}
511522

512523
/**
@@ -599,8 +610,92 @@ case class MapObjects private(
599610

600611
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
601612

602-
override def eval(input: InternalRow): Any =
603-
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
613+
// Data of a UserDefinedType is actually stored using the data type of its sqlType.
614+
// When we want to apply MapObjects to such data, we have to use that sqlType.
615+
lazy private val inputDataType = inputData.dataType match {
616+
case u: UserDefinedType[_] => u.sqlType
617+
case _ => inputData.dataType
618+
}
619+
620+
private def executeFuncOnCollection(inputCollection: Seq[_]): Iterator[_] = {
621+
val row = new GenericInternalRow(1)
622+
inputCollection.toIterator.map { element =>
623+
row.update(0, element)
624+
lambdaFunction.eval(row)
625+
}
626+
}
627+
628+
private lazy val convertToSeq: Any => Seq[_] = inputDataType match {
629+
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
630+
_.asInstanceOf[Seq[_]]
631+
case ObjectType(cls) if cls.isArray =>
632+
_.asInstanceOf[Array[_]].toSeq
633+
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
634+
_.asInstanceOf[java.util.List[_]].asScala
635+
case ObjectType(cls) if cls == classOf[Object] =>
636+
(inputCollection) => {
637+
if (inputCollection.getClass.isArray) {
638+
inputCollection.asInstanceOf[Array[_]].toSeq
639+
} else {
640+
inputCollection.asInstanceOf[Seq[_]]
641+
}
642+
}
643+
case ArrayType(et, _) =>
644+
_.asInstanceOf[ArrayData].array
645+
}
646+
647+
private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
648+
case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
649+
// Scala sequence
650+
executeFuncOnCollection(_).toSeq
651+
case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
652+
// Scala set
653+
executeFuncOnCollection(_).toSet
654+
case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
655+
// Java list
656+
if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
657+
cls == classOf[java.util.AbstractSequentialList[_]]) {
658+
// Specifying non-concrete implementations of `java.util.List`
659+
executeFuncOnCollection(_).toSeq.asJava
660+
} else {
661+
val constructors = cls.getConstructors()
662+
val intParamConstructor = constructors.find { constructor =>
663+
constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
664+
}
665+
val noParamConstructor = constructors.find { constructor =>
666+
constructor.getParameterCount == 0
667+
}
668+
669+
val constructor = intParamConstructor.map { intConstructor =>
670+
(len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
671+
}.getOrElse {
672+
(_: Int) => noParamConstructor.get.newInstance()
673+
}
674+
675+
// Specifying concrete implementations of `java.util.List`
676+
(inputs) => {
677+
val results = executeFuncOnCollection(inputs)
678+
val builder = constructor(inputs.length).asInstanceOf[java.util.List[Any]]
679+
results.foreach(builder.add(_))
680+
builder
681+
}
682+
}
683+
case None =>
684+
// array
685+
x => new GenericArrayData(executeFuncOnCollection(x).toArray)
686+
case Some(cls) =>
687+
throw new RuntimeException(s"class `${cls.getName}` is not supported by `MapObjects` as " +
688+
"resulting collection.")
689+
}
690+
691+
override def eval(input: InternalRow): Any = {
692+
val inputCollection = inputData.eval(input)
693+
694+
if (inputCollection == null) {
695+
return null
696+
}
697+
mapElements(convertToSeq(inputCollection))
698+
}
604699

605700
override def dataType: DataType =
606701
customCollectionCls.map(ObjectType.apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
647742
case _ => ""
648743
}
649744

650-
// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
651-
// When we want to apply MapObjects on it, we have to use it.
652-
val inputDataType = inputData.dataType match {
653-
case p: PythonUserDefinedType => p.sqlType
654-
case _ => inputData.dataType
655-
}
656-
657745
// `MapObjects` generates a while loop to traverse the elements of the input collection. We
658746
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
659747
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import scala.collection.JavaConverters._
2021
import scala.reflect.ClassTag
2122

2223
import org.apache.spark.{SparkConf, SparkFunSuite}
@@ -25,7 +26,7 @@ import org.apache.spark.sql.Row
2526
import org.apache.spark.sql.catalyst.InternalRow
2627
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2728
import org.apache.spark.sql.catalyst.expressions.objects._
28-
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
29+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
2930
import org.apache.spark.sql.types._
3031

3132

@@ -135,6 +136,70 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
135136
}
136137
}
137138

139+
test("SPARK-23587: MapObjects should support interpreted execution") {
140+
def testMapObjects(collection: Any, collectionCls: Class[_], inputType: DataType): Unit = {
141+
val function = (lambda: Expression) => Add(lambda, Literal(1))
142+
val elementType = IntegerType
143+
val expected = Seq(2, 3, 4)
144+
145+
val inputObject = BoundReference(0, inputType, nullable = true)
146+
val optClass = Option(collectionCls)
147+
val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
148+
val row = InternalRow.fromSeq(Seq(collection))
149+
val result = mapObj.eval(row)
150+
151+
collectionCls match {
152+
case null =>
153+
assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
154+
case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
155+
assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
156+
case s if classOf[Seq[_]].isAssignableFrom(s) =>
157+
assert(result.asInstanceOf[Seq[_]].toSeq == expected)
158+
case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
159+
assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
160+
}
161+
}
162+
163+
val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
164+
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
165+
classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
166+
classOf[java.util.Stack[Int]], null)
167+
168+
val list = new java.util.ArrayList[Int]()
169+
list.add(1)
170+
list.add(2)
171+
list.add(3)
172+
val arrayData = new GenericArrayData(Array(1, 2, 3))
173+
val vector = new java.util.Vector[Int]()
174+
vector.add(1)
175+
vector.add(2)
176+
vector.add(3)
177+
val stack = new java.util.Stack[Int]()
178+
stack.add(1)
179+
stack.add(2)
180+
stack.add(3)
181+
182+
Seq(
183+
(Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
184+
(Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
185+
(Seq(1, 2, 3), ObjectType(classOf[Object])),
186+
(Array(1, 2, 3), ObjectType(classOf[Object])),
187+
(list, ObjectType(classOf[java.util.List[Int]])),
188+
(vector, ObjectType(classOf[java.util.Vector[Int]])),
189+
(stack, ObjectType(classOf[java.util.Stack[Int]])),
190+
(arrayData, ArrayType(IntegerType))
191+
).foreach { case (collection, inputType) =>
192+
customCollectionClasses.foreach(testMapObjects(collection, _, inputType))
193+
194+
// Unsupported custom collection class
195+
val errMsg = intercept[RuntimeException] {
196+
testMapObjects(collection, classOf[scala.collection.Map[Int, Int]], inputType)
197+
}.getMessage()
198+
assert(errMsg.contains("`scala.collection.Map` is not supported by `MapObjects` " +
199+
"as resulting collection."))
200+
}
201+
}
202+
138203
test("SPARK-23592: DecodeUsingSerializer should support interpreted execution") {
139204
val cls = classOf[java.lang.Integer]
140205
val inputObject = BoundReference(0, ObjectType(classOf[Array[Byte]]), nullable = true)

0 commit comments

Comments
 (0)