Skip to content

Commit 30ffb53

Browse files
viirya authored and hvanhovell committed
[SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData
## What changes were proposed in this pull request? We don't have a good way to sequentially access `UnsafeArrayData` with a common interface such as `Seq`. An example is `MapObjects`, where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`, and calling `toArray` will copy the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData`, so we can avoid copying the entire array. ## How was this patch tested? Added test. Author: Liang-Chi Hsieh <[email protected]> Closes apache#20984 from viirya/SPARK-23875.
1 parent 05ae747 commit 30ffb53

File tree

3 files changed

+130
-2
lines changed

3 files changed

+130
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ case class MapObjects private(
708708
}
709709
}
710710
case ArrayType(et, _) =>
711-
_.asInstanceOf[ArrayData].array
711+
_.asInstanceOf[ArrayData].toSeq[Any](et)
712712
}
713713

714714
private lazy val mapElements: Seq[_] => Any = customCollectionCls match {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util
1919

2020
import scala.reflect.ClassTag
2121

22+
import org.apache.spark.sql.catalyst.InternalRow
2223
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
23-
import org.apache.spark.sql.types.DataType
24+
import org.apache.spark.sql.types._
2425

2526
object ArrayData {
2627
def toArrayData(input: Any): ArrayData = input match {
@@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
4243

4344
def array: Array[Any]
4445

46+
/**
 * Exposes this `ArrayData` through the `IndexedSeq` interface by wrapping it in an
 * `ArrayDataIndexedSeq`.
 *
 * @param dataType the data type handed to the wrapper together with this array
 * @tparam T the element type of the returned sequence
 * @return a new `ArrayDataIndexedSeq` built from this array and `dataType`
 */
def toSeq[T](dataType: DataType): IndexedSeq[T] = {
  val wrapper = new ArrayDataIndexedSeq[T](this, dataType)
  wrapper
}
48+
4549
def setNullAt(i: Int): Unit
4650

4751
def update(i: Int, value: Any): Unit
@@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
164168
}
165169
}
166170
}
171+
172+
/**
 * An `IndexedSeq` view over an `ArrayData`: elements are fetched from the wrapped array each
 * time they are accessed, using an accessor obtained for `dataType`.
 *
 * Null slots are returned as `null` cast to `T`. If the underlying `ArrayData` holds primitive
 * values and may contain nulls, request an `IndexedSeq[Any]` rather than e.g. `IndexedSeq[Int]`
 * so that the null elements are preserved.
 *
 * @param arrayData the array being wrapped
 * @param dataType the element data type used to obtain the element accessor
 * @tparam T the element type exposed by this sequence
 */
class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] {

  // Reads a single element of `dataType` from a `SpecializedGetters` at a given ordinal.
  private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType)

  /**
   * Returns the element at `idx`, or `null` (cast to `T`) when that slot is null.
   *
   * @throws IndexOutOfBoundsException if `idx` is negative or not less than the array length
   */
  override def apply(idx: Int): T = {
    // Guard clause: reject anything outside [0, numElements).
    if (idx < 0 || idx >= arrayData.numElements()) {
      throw new IndexOutOfBoundsException(
        s"Index $idx must be between 0 and the length of the ArrayData.")
    }

    val element: Any = if (arrayData.isNullAt(idx)) null else accessor(arrayData, idx)
    element.asInstanceOf[T]
  }

  /** Number of elements in the wrapped `ArrayData`. */
  override def length: Int = arrayData.numElements()
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.util
19+
20+
import scala.util.Random
21+
22+
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.sql.RandomDataGenerator
24+
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
25+
import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
26+
import org.apache.spark.sql.types._
27+
28+
/**
 * Checks that the `IndexedSeq` returned by `ArrayData.toSeq` exposes the same elements as the
 * underlying `ArrayData`, for both `UnsafeArrayData` and `GenericArrayData`, across a range of
 * element types with and without null elements.
 */
class ArrayDataIndexedSeqSuite extends SparkFunSuite {

  /**
   * Asserts that a non-null element read back from the array equals the expected value.
   * Float and double values are compared with `equals` (for NaN, etc.); everything else
   * goes through `===`.
   */
  private def assertSameElement(actual: Any, expected: Any, elementDt: DataType): Unit = {
    elementDt match {
      // For NaN, etc.
      case FloatType | DoubleType => assert(actual.equals(expected))
      case _ => assert(actual === expected)
    }
  }

  /**
   * Compares `arrayData` against the expected `array`, first through the `ArrayData` getters
   * and then through the `IndexedSeq` wrapper, and finally verifies that out-of-range indices
   * on the wrapper raise `IndexOutOfBoundsException` with the expected message.
   */
  private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = {
    assert(arrayData.numElements == array.length)

    // `foreach`, not `map`: the loop only runs assertions and its result is not needed.
    array.zipWithIndex.foreach { case (e, i) =>
      if (e != null) {
        assertSameElement(arrayData.get(i, elementDt), e, elementDt)
      } else {
        assert(arrayData.isNullAt(i))
      }
    }

    val seq = arrayData.toSeq[Any](elementDt)
    array.zipWithIndex.foreach { case (e, i) =>
      if (e != null) {
        assertSameElement(seq(i), e, elementDt)
      } else {
        assert(seq(i) == null)
      }
    }

    val expectedMessage = "must be between 0 and the length of the ArrayData."

    // The message check must be wrapped in `assert`; a bare `contains(...)` only produces a
    // Boolean that is discarded and therefore verifies nothing.
    val negativeIndexError = intercept[IndexOutOfBoundsException] {
      seq(-1)
    }
    assert(negativeIndexError.getMessage().contains(expectedMessage))

    val pastEndError = intercept[IndexOutOfBoundsException] {
      seq(seq.length)
    }
    assert(pastEndError.getMessage().contains(expectedMessage))
  }

  /**
   * Registers one test per (array implementation, array type) pair. Random rows are generated
   * from a fixed seed, converted to an unsafe row and back to a safe row, and the array column
   * of each is compared with its own `toArray` output.
   */
  private def testArrayData(): Unit = {
    val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
      DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
      CalendarIntervalType, new ExamplePointUDT())
    val arrayTypes = elementTypes.flatMap { elementType =>
      Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
    }
    val random = new Random(100)
    arrayTypes.foreach { dt =>
      val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
      val row = RandomDataGenerator.randomRow(random, schema)
      val rowConverter = RowEncoder(schema)
      val internalRow = rowConverter.toRow(row)

      val unsafeRowConverter = UnsafeProjection.create(schema)
      val safeRowConverter = FromUnsafeProjection(schema)

      val unsafeRow = unsafeRowConverter(internalRow)
      val safeRow = safeRowConverter(unsafeRow)

      val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData]
      val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]

      val elementType = dt.elementType
      test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) {
        compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType))
      }

      test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) {
        compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType))
      }
    }
  }

  testArrayData()
}

0 commit comments

Comments
 (0)