This repository was archived by the owner on Sep 18, 2023. It is now read-only.

Commit 6661b7b

[NSE-931] Reuse partition vectors for arrow scan (#935)
* reuse partition vectors
* remove extra setValueCount and fix read null value
1 parent: 59dfede

File tree: 3 files changed, +56 / -6 lines changed
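
The three files below all serve one change: partition-column vectors are now allocated once up front, reused for every batch, and closed by a task-completion listener instead of being rebuilt for every batch. The following minimal Scala sketch pieces that lifecycle together from the diff for orientation; the parameter names and the recordBatches iterator are illustrative placeholders, and the ArrowUtils import path is inferred from the file path in this commit rather than quoted from the repository.

import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowUtils // package inferred from the file path below
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

def scanWithReusedPartitionVectors(
    batchSize: Int,
    partitionSchema: StructType,
    partitionValues: InternalRow,
    requiredSchema: StructType,
    recordBatches: Iterator[ArrowRecordBatch]): Iterator[ColumnarBatch] = {
  // Allocate the partition-column vectors once, not once per batch.
  val partitionVectors =
    ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, partitionValues)

  // Release the shared vectors exactly once, when the task completes.
  SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => {
    partitionVectors.foreach(_.close())
  })

  // Every batch reuses the same partition vectors; loadBatch retains them,
  // so closing an individual ColumnarBatch does not free the shared buffers.
  recordBatches.map(batch => ArrowUtils.loadBatch(batch, requiredSchema, partitionVectors))
}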

arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/arrow/ArrowFileFormat.scala

Lines changed: 9 additions & 4 deletions
@@ -40,13 +40,12 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.execution.datasources.v2.arrow.{SparkMemoryUtils, SparkVectorUtils}
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils.UnsafeItr
-import org.apache.spark.sql.execution.datasources.v2.arrow.SparkVectorUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.sql.vectorized.ColumnarBatch;

 class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializable {

@@ -175,11 +174,17 @@ class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializab
         factory.close()
       }))

+      val partitionVectors =
+        ArrowUtils.loadPartitionColumns(batchSize, partitionSchema, file.partitionValues)
+
+      SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => {
+        partitionVectors.foreach(_.close())
+      })
+
       val itr = itrList
         .toIterator
         .flatMap(itr => itr.asScala)
-        .map(batch => ArrowUtils.loadBatch(batch, file.partitionValues, partitionSchema,
-          requiredSchema))
+        .map(batch => ArrowUtils.loadBatch(batch, requiredSchema, partitionVectors))
       new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
     }
   }

arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowPartitionReaderFactory.scala

Lines changed: 10 additions & 2 deletions
@@ -26,11 +26,13 @@ import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
 import org.apache.arrow.dataset.scanner.ScanOptions
 import org.apache.arrow.vector.types.pojo.Schema

+import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.v2.FilePartitionReaderFactory
+import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType

@@ -99,11 +101,17 @@ case class ArrowPartitionReaderFactory(
     val vsrItrList = taskList
       .map(task => task.execute())

+    val partitionVectors = ArrowUtils.loadPartitionColumns(
+      batchSize, readPartitionSchema, partitionedFile.partitionValues)
+
+    SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit]((_: TaskContext) => {
+      partitionVectors.foreach(_.close())
+    })
+
     val batchItr = vsrItrList
       .toIterator
       .flatMap(itr => itr.asScala)
-      .map(batch => ArrowUtils.loadBatch(batch, partitionedFile.partitionValues,
-        readPartitionSchema, readDataSchema))
+      .map(batch => ArrowUtils.loadBatch(batch, readDataSchema, partitionVectors))

     new PartitionReader[ColumnarBatch] {
       val holder = new ColumnarBatchRetainer()

arrow-data-source/standard/src/main/scala/com/intel/oap/spark/sql/execution/datasources/v2/arrow/ArrowUtils.scala

Lines changed: 37 additions & 0 deletions
@@ -88,6 +88,43 @@ object ArrowUtils {
     SparkSchemaUtils.toArrowSchema(t, SparkSchemaUtils.getLocalTimezoneID())
   }

+  def loadPartitionColumns(
+      rowCount: Int,
+      partitionSchema: StructType,
+      partitionValues: InternalRow): Array[ArrowWritableColumnVector] = {
+    val partitionColumns = ArrowWritableColumnVector.allocateColumns(rowCount, partitionSchema)
+    (0 until partitionColumns.length).foreach(i => {
+      ArrowColumnVectorUtils.populate(partitionColumns(i), partitionValues, i)
+      partitionColumns(i).setValueCount(rowCount)
+      partitionColumns(i).setIsConstant()
+    })
+    partitionColumns
+  }
+
+  def loadBatch(
+      input: ArrowRecordBatch,
+      dataSchema: StructType,
+      partitionVectors: Array[ArrowWritableColumnVector]): ColumnarBatch = {
+    val rowCount: Int = input.getLength
+
+    val vectors = try {
+      ArrowWritableColumnVector.loadColumns(rowCount, toArrowSchema(dataSchema), input)
+    } finally {
+      input.close()
+    }
+
+    val batch = new ColumnarBatch(
+      vectors.map(_.asInstanceOf[ColumnVector]) ++
+        partitionVectors
+          .map { vector =>
+            // retain() must be called on the vector whenever it is reused.
+            vector.retain()
+            vector.asInstanceOf[ColumnVector]
+          },
+      rowCount)
+    batch
+  }
+
   def toArrowField(t: StructField): Field = {
     SparkSchemaUtils.toArrowField(
       t.name, t.dataType, t.nullable, SparkSchemaUtils.getLocalTimezoneID())
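
A note on the retain() call above, offered as a reading of the diff rather than repository documentation: assuming ArrowWritableColumnVector reference-counts its buffers and close() drops one reference, the extra reference taken per batch is balanced by the close() the consumer eventually performs on that ColumnarBatch, while the allocation-time reference is only dropped by the task-completion listeners registered in the two callers above. A hypothetical consumer, for illustration:

// `recordBatch`, `readDataSchema`, and `partitionVectors` as in the callers above.
val batch = ArrowUtils.loadBatch(recordBatch, readDataSchema, partitionVectors) // retains the shared vectors
try {
  // ... read columns/rows from `batch` ...
} finally {
  batch.close() // releases only this batch's reference, not the shared allocation
}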
