Skip to content
This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 3b049ab

Browse files
Feng Liugatorsmile
authored andcommitted
[SPARK-22003][SQL] support array column in vectorized reader with UDF
## What changes were proposed in this pull request? The UDF needs to deserialize the `UnsafeRow`. When the column type is Array, the `get` method from the `ColumnVector`, which is used by the vectorized reader, is called, but this method is not implemented. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Feng Liu <[email protected]> Closes apache#19230 from liufengdb/fix_array_open.
1 parent 894a756 commit 3b049ab

File tree

2 files changed

+242
-62
lines changed

2 files changed

+242
-62
lines changed

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java

Lines changed: 41 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -100,72 +100,16 @@ public ArrayData copy() {
100100
public Object[] array() {
101101
DataType dt = data.dataType();
102102
Object[] list = new Object[length];
103-
104-
if (dt instanceof BooleanType) {
105-
for (int i = 0; i < length; i++) {
106-
if (!data.isNullAt(offset + i)) {
107-
list[i] = data.getBoolean(offset + i);
108-
}
109-
}
110-
} else if (dt instanceof ByteType) {
111-
for (int i = 0; i < length; i++) {
112-
if (!data.isNullAt(offset + i)) {
113-
list[i] = data.getByte(offset + i);
114-
}
115-
}
116-
} else if (dt instanceof ShortType) {
117-
for (int i = 0; i < length; i++) {
118-
if (!data.isNullAt(offset + i)) {
119-
list[i] = data.getShort(offset + i);
120-
}
121-
}
122-
} else if (dt instanceof IntegerType) {
123-
for (int i = 0; i < length; i++) {
124-
if (!data.isNullAt(offset + i)) {
125-
list[i] = data.getInt(offset + i);
126-
}
127-
}
128-
} else if (dt instanceof FloatType) {
129-
for (int i = 0; i < length; i++) {
130-
if (!data.isNullAt(offset + i)) {
131-
list[i] = data.getFloat(offset + i);
132-
}
133-
}
134-
} else if (dt instanceof DoubleType) {
103+
try {
135104
for (int i = 0; i < length; i++) {
136105
if (!data.isNullAt(offset + i)) {
137-
list[i] = data.getDouble(offset + i);
106+
list[i] = get(i, dt);
138107
}
139108
}
140-
} else if (dt instanceof LongType) {
141-
for (int i = 0; i < length; i++) {
142-
if (!data.isNullAt(offset + i)) {
143-
list[i] = data.getLong(offset + i);
144-
}
145-
}
146-
} else if (dt instanceof DecimalType) {
147-
DecimalType decType = (DecimalType)dt;
148-
for (int i = 0; i < length; i++) {
149-
if (!data.isNullAt(offset + i)) {
150-
list[i] = getDecimal(i, decType.precision(), decType.scale());
151-
}
152-
}
153-
} else if (dt instanceof StringType) {
154-
for (int i = 0; i < length; i++) {
155-
if (!data.isNullAt(offset + i)) {
156-
list[i] = getUTF8String(i).toString();
157-
}
158-
}
159-
} else if (dt instanceof CalendarIntervalType) {
160-
for (int i = 0; i < length; i++) {
161-
if (!data.isNullAt(offset + i)) {
162-
list[i] = getInterval(i);
163-
}
164-
}
165-
} else {
166-
throw new UnsupportedOperationException("Type " + dt);
109+
return list;
110+
} catch(Exception e) {
111+
throw new RuntimeException("Could not get the array", e);
167112
}
168-
return list;
169113
}
170114

171115
@Override
@@ -237,7 +181,42 @@ public MapData getMap(int ordinal) {
237181

238182
@Override
239183
public Object get(int ordinal, DataType dataType) {
240-
throw new UnsupportedOperationException();
184+
if (dataType instanceof BooleanType) {
185+
return getBoolean(ordinal);
186+
} else if (dataType instanceof ByteType) {
187+
return getByte(ordinal);
188+
} else if (dataType instanceof ShortType) {
189+
return getShort(ordinal);
190+
} else if (dataType instanceof IntegerType) {
191+
return getInt(ordinal);
192+
} else if (dataType instanceof LongType) {
193+
return getLong(ordinal);
194+
} else if (dataType instanceof FloatType) {
195+
return getFloat(ordinal);
196+
} else if (dataType instanceof DoubleType) {
197+
return getDouble(ordinal);
198+
} else if (dataType instanceof StringType) {
199+
return getUTF8String(ordinal);
200+
} else if (dataType instanceof BinaryType) {
201+
return getBinary(ordinal);
202+
} else if (dataType instanceof DecimalType) {
203+
DecimalType t = (DecimalType) dataType;
204+
return getDecimal(ordinal, t.precision(), t.scale());
205+
} else if (dataType instanceof DateType) {
206+
return getInt(ordinal);
207+
} else if (dataType instanceof TimestampType) {
208+
return getLong(ordinal);
209+
} else if (dataType instanceof ArrayType) {
210+
return getArray(ordinal);
211+
} else if (dataType instanceof StructType) {
212+
return getStruct(ordinal, ((StructType)dataType).fields().length);
213+
} else if (dataType instanceof MapType) {
214+
return getMap(ordinal);
215+
} else if (dataType instanceof CalendarIntervalType) {
216+
return getInterval(ordinal);
217+
} else {
218+
throw new UnsupportedOperationException("Datatype not supported " + dataType);
219+
}
241220
}
242221

243222
@Override
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.vectorized
19+
20+
import org.scalatest.BeforeAndAfterEach
21+
22+
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.sql.catalyst.util.ArrayData
24+
import org.apache.spark.sql.types._
25+
import org.apache.spark.unsafe.types.UTF8String
26+
27+
class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
28+
29+
var testVector: WritableColumnVector = _
30+
31+
private def allocate(capacity: Int, dt: DataType): WritableColumnVector = {
32+
new OnHeapColumnVector(capacity, dt)
33+
}
34+
35+
override def afterEach(): Unit = {
36+
testVector.close()
37+
}
38+
39+
test("boolean") {
40+
testVector = allocate(10, BooleanType)
41+
(0 until 10).foreach { i =>
42+
testVector.appendBoolean(i % 2 == 0)
43+
}
44+
45+
val array = new ColumnVector.Array(testVector)
46+
47+
(0 until 10).foreach { i =>
48+
assert(array.get(i, BooleanType) === (i % 2 == 0))
49+
}
50+
}
51+
52+
test("byte") {
53+
testVector = allocate(10, ByteType)
54+
(0 until 10).foreach { i =>
55+
testVector.appendByte(i.toByte)
56+
}
57+
58+
val array = new ColumnVector.Array(testVector)
59+
60+
(0 until 10).foreach { i =>
61+
assert(array.get(i, ByteType) === (i.toByte))
62+
}
63+
}
64+
65+
test("short") {
66+
testVector = allocate(10, ShortType)
67+
(0 until 10).foreach { i =>
68+
testVector.appendShort(i.toShort)
69+
}
70+
71+
val array = new ColumnVector.Array(testVector)
72+
73+
(0 until 10).foreach { i =>
74+
assert(array.get(i, ShortType) === (i.toShort))
75+
}
76+
}
77+
78+
test("int") {
79+
testVector = allocate(10, IntegerType)
80+
(0 until 10).foreach { i =>
81+
testVector.appendInt(i)
82+
}
83+
84+
val array = new ColumnVector.Array(testVector)
85+
86+
(0 until 10).foreach { i =>
87+
assert(array.get(i, IntegerType) === i)
88+
}
89+
}
90+
91+
test("long") {
92+
testVector = allocate(10, LongType)
93+
(0 until 10).foreach { i =>
94+
testVector.appendLong(i)
95+
}
96+
97+
val array = new ColumnVector.Array(testVector)
98+
99+
(0 until 10).foreach { i =>
100+
assert(array.get(i, LongType) === i)
101+
}
102+
}
103+
104+
test("float") {
105+
testVector = allocate(10, FloatType)
106+
(0 until 10).foreach { i =>
107+
testVector.appendFloat(i.toFloat)
108+
}
109+
110+
val array = new ColumnVector.Array(testVector)
111+
112+
(0 until 10).foreach { i =>
113+
assert(array.get(i, FloatType) === i.toFloat)
114+
}
115+
}
116+
117+
test("double") {
118+
testVector = allocate(10, DoubleType)
119+
(0 until 10).foreach { i =>
120+
testVector.appendDouble(i.toDouble)
121+
}
122+
123+
val array = new ColumnVector.Array(testVector)
124+
125+
(0 until 10).foreach { i =>
126+
assert(array.get(i, DoubleType) === i.toDouble)
127+
}
128+
}
129+
130+
test("string") {
131+
testVector = allocate(10, StringType)
132+
(0 until 10).map { i =>
133+
val utf8 = s"str$i".getBytes("utf8")
134+
testVector.appendByteArray(utf8, 0, utf8.length)
135+
}
136+
137+
val array = new ColumnVector.Array(testVector)
138+
139+
(0 until 10).foreach { i =>
140+
assert(array.get(i, StringType) === UTF8String.fromString(s"str$i"))
141+
}
142+
}
143+
144+
test("binary") {
145+
testVector = allocate(10, BinaryType)
146+
(0 until 10).map { i =>
147+
val utf8 = s"str$i".getBytes("utf8")
148+
testVector.appendByteArray(utf8, 0, utf8.length)
149+
}
150+
151+
val array = new ColumnVector.Array(testVector)
152+
153+
(0 until 10).foreach { i =>
154+
val utf8 = s"str$i".getBytes("utf8")
155+
assert(array.get(i, BinaryType) === utf8)
156+
}
157+
}
158+
159+
test("array") {
160+
val arrayType = ArrayType(IntegerType, true)
161+
testVector = allocate(10, arrayType)
162+
163+
val data = testVector.arrayData()
164+
var i = 0
165+
while (i < 6) {
166+
data.putInt(i, i)
167+
i += 1
168+
}
169+
170+
// Populate it with arrays [0], [1, 2], [], [3, 4, 5]
171+
testVector.putArray(0, 0, 1)
172+
testVector.putArray(1, 1, 2)
173+
testVector.putArray(2, 3, 0)
174+
testVector.putArray(3, 3, 3)
175+
176+
val array = new ColumnVector.Array(testVector)
177+
178+
assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0))
179+
assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2))
180+
assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int])
181+
assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5))
182+
}
183+
184+
test("struct") {
185+
val schema = new StructType().add("int", IntegerType).add("double", DoubleType)
186+
testVector = allocate(10, schema)
187+
val c1 = testVector.getChildColumn(0)
188+
val c2 = testVector.getChildColumn(1)
189+
c1.putInt(0, 123)
190+
c2.putDouble(0, 3.45)
191+
c1.putInt(1, 456)
192+
c2.putDouble(1, 5.67)
193+
194+
val array = new ColumnVector.Array(testVector)
195+
196+
assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123)
197+
assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45)
198+
assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456)
199+
assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67)
200+
}
201+
}

0 commit comments

Comments
 (0)