Commit dd59a4b

dongjoon-hyun authored and cloud-fan committed
[SPARK-22712][SQL] Use buildReaderWithPartitionValues in native OrcFileFormat
## What changes were proposed in this pull request?

To support vectorization in the native OrcFileFormat later, we need to use `buildReaderWithPartitionValues` instead of `buildReader`, as ParquetFileFormat does. This PR replaces `buildReader` with `buildReaderWithPartitionValues`.

## How was this patch tested?

Passes Jenkins with the existing test cases.

Author: Dongjoon Hyun <[email protected]>

Closes #19907 from dongjoon-hyun/SPARK-ORC-BUILD-READER.
1 parent beb717f commit dd59a4b

File tree

1 file changed (+12, -3 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala

Lines changed: 12 additions & 3 deletions
@@ -36,6 +36,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -124,7 +125,7 @@ class OrcFileFormat
     true
   }
 
-  override def buildReader(
+  override def buildReaderWithPartitionValues(
       sparkSession: SparkSession,
       dataSchema: StructType,
       partitionSchema: StructType,
@@ -167,9 +168,17 @@ class OrcFileFormat
         val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
         Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => iter.close()))
 
-        val unsafeProjection = UnsafeProjection.create(requiredSchema)
+        val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+        val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
         val deserializer = new OrcDeserializer(dataSchema, requiredSchema, requestedColIds)
-        iter.map(value => unsafeProjection(deserializer.deserialize(value)))
+
+        if (partitionSchema.length == 0) {
+          iter.map(value => unsafeProjection(deserializer.deserialize(value)))
+        } else {
+          val joinedRow = new JoinedRow()
+          iter.map(value =>
+            unsafeProjection(joinedRow(deserializer.deserialize(value), file.partitionValues)))
+        }
       }
     }
   }
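To make the new branch concrete, here is a minimal standalone sketch (not part of the commit) of the pattern the patch borrows from ParquetFileFormat: deserialize a data row, attach the file's partition values with `JoinedRow`, and flatten the pair into one UnsafeRow by projecting over the combined schema. The object name, schemas, and literal rows are invented for illustration; `dataRow` and `partitionValues` stand in for `deserializer.deserialize(value)` and `file.partitionValues`; and everything used here is Catalyst-internal API, so treat it as a sketch rather than supported usage.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types._

object PartitionValuesSketch {
  def main(args: Array[String]): Unit = {
    // "id" plays the role of a column stored in the ORC file; "p" is a partition column
    // that is absent from the file and must come from the file path instead.
    val requiredSchema = StructType(Seq(StructField("id", LongType)))
    val partitionSchema = StructType(Seq(StructField("p", IntegerType)))

    // Concatenate data columns and partition columns, mirroring the patch's
    // requiredSchema.toAttributes ++ partitionSchema.toAttributes.
    val fullSchema = (requiredSchema ++ partitionSchema).map { f =>
      AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
    }

    // Identity projection over the full schema: it binds each attribute to its ordinal
    // in the joined input row and emits a single flat UnsafeRow.
    val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema)

    val dataRow = InternalRow(42L)        // stands in for deserializer.deserialize(value)
    val partitionValues = InternalRow(7)  // stands in for file.partitionValues

    val joinedRow = new JoinedRow()
    val result = unsafeProjection(joinedRow(dataRow, partitionValues))
    println(s"id = ${result.getLong(0)}, p = ${result.getInt(1)}")  // id = 42, p = 7
  }
}
```

Binding `fullSchema` to itself is what lets the generated projection read ordinals straight out of the `JoinedRow`, so data columns land first and partition columns last in the resulting UnsafeRow; the `partitionSchema.length == 0` fast path in the patch simply skips the join when there are no partition columns to append.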
