Commit 4177292

gengliangwang authored and cloud-fan committed
[SPARK-27435][SQL] Support schema pruning in ORC V2
## What changes were proposed in this pull request?

Currently, the optimization rule `SchemaPruning` only works for Parquet/ORC V1. We should have the same optimization in ORC V2.

## How was this patch tested?

Unit test

Closes apache#24338 from gengliangwang/schemaPruningForV2.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 0745333 commit 4177292
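
For orientation (not part of the commit): a rough sketch of the kind of query this rule now optimizes on the V2 path, assuming a hypothetical local session and a throwaway `/tmp/contacts` path. With nested schema pruning enabled and `orc` removed from the V1 source list, selecting a single nested field should leave the ORC scan reading only that leaf.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Hypothetical session and path, for illustration only.
val spark = SparkSession.builder()
  .master("local[*]")
  .config(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key, "true")
  .config(SQLConf.USE_V1_SOURCE_READER_LIST.key, "")  // empty list: route "orc" through the V2 reader
  .getOrCreate()

// Write a small dataset with a nested struct column.
spark.range(1)
  .selectExpr("id", "named_struct('first', 'Jane', 'last', 'Doe') AS name")
  .write.mode("overwrite").format("orc").save("/tmp/contacts")

// Only name.first is requested, so after this rule the V2 ORC scan should read
// struct<name:struct<first:string>> rather than the full file schema.
spark.read.format("orc").load("/tmp/contacts").select("name.first").explain(true)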

4 files changed: +139 −23 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala

Lines changed: 84 additions & 20 deletions
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
@@ -48,7 +51,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
           l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
         if canPruneRelation(hadoopFsRelation) =>
         val (normalizedProjects, normalizedFilters) =
-          normalizeAttributeRefNames(l, projects, filters)
+          normalizeAttributeRefNames(l.output, projects, filters)
         val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
 
         // If requestedRootFields includes a nested field, continue. Otherwise,
@@ -76,6 +79,43 @@ object SchemaPruning extends Rule[LogicalPlan] {
         } else {
           op
         }
+
+      case op @ PhysicalOperation(projects, filters,
+          d @ DataSourceV2Relation(table: FileTable, output, _)) if canPruneTable(table) =>
+        val (normalizedProjects, normalizedFilters) =
+          normalizeAttributeRefNames(output, projects, filters)
+        val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
+
+        // If requestedRootFields includes a nested field, continue. Otherwise,
+        // return op
+        if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
+          val dataSchema = table.dataSchema
+          val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)
+
+          // If the data schema is different from the pruned data schema, continue. Otherwise,
+          // return op. We effect this comparison by counting the number of "leaf" fields in
+          // each schemata, assuming the fields in prunedDataSchema are a subset of the fields
+          // in dataSchema.
+          if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
+            val prunedFileTable = table match {
+              case o: OrcTable => o.copy(userSpecifiedSchema = Some(prunedDataSchema))
+              case _ =>
+                val message = s"${table.formatName} data source doesn't support schema pruning."
+                throw new AnalysisException(message)
+            }
+
+
+            val prunedRelationV2 = buildPrunedRelationV2(d, prunedFileTable)
+            val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+
+            buildNewProjection(normalizedProjects, normalizedFilters, prunedRelationV2,
+              projectionOverSchema)
+          } else {
+            op
+          }
+        } else {
+          op
+        }
     }
 
   /**
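
The inner `if` in the hunk above compares the original and pruned schemas by counting leaf fields. `countLeaves` already exists elsewhere in SchemaPruning.scala and is untouched by this diff; as a rough sketch of the idea only (not the actual helper), a leaf count over Spark SQL types could look like:

import org.apache.spark.sql.types._

// Illustrative sketch: count primitive leaves, descending through structs, arrays and maps.
def leafCount(dataType: DataType): Int = dataType match {
  case StructType(fields) => fields.map(f => leafCount(f.dataType)).sum
  case ArrayType(elementType, _) => leafCount(elementType)
  case MapType(keyType, valueType, _) => leafCount(keyType) + leafCount(valueType)
  case _ => 1
}

// A pruned struct<name:struct<first:string>> counts 1 leaf, while the full
// struct<id:bigint,name:struct<first:string,last:string>> counts 3, so pruning proceeds.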
@@ -85,16 +125,22 @@ object SchemaPruning extends Rule[LogicalPlan] {
     fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
       fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
 
+  /**
+   * Checks to see if the given [[FileTable]] can be pruned. Currently we support ORC v2.
+   */
+  private def canPruneTable(table: FileTable) =
+    table.isInstanceOf[OrcTable]
+
   /**
    * Normalizes the names of the attribute references in the given projects and filters to reflect
    * the names in the given logical relation. This makes it possible to compare attributes and
    * fields by name. Returns a tuple with the normalized projects and filters, respectively.
    */
   private def normalizeAttributeRefNames(
-      logicalRelation: LogicalRelation,
+      output: Seq[AttributeReference],
       projects: Seq[NamedExpression],
       filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
-    val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap
+    val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
     val normalizedProjects = projects.map(_.transform {
       case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
         att.withName(normalizedAttNameMap(att.exprId))
@@ -107,11 +153,13 @@ object SchemaPruning extends Rule[LogicalPlan] {
   }
 
   /**
-   * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
+   * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`.
    */
   private def buildNewProjection(
-      projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation,
-      projectionOverSchema: ProjectionOverSchema) = {
+      projects: Seq[NamedExpression],
+      filters: Seq[Expression],
+      leafNode: LeafNode,
+      projectionOverSchema: ProjectionOverSchema): Project = {
     // Construct a new target for our projection by rewriting and
     // including the original filters where available
     val projectionChild =
@@ -120,9 +168,9 @@ object SchemaPruning extends Rule[LogicalPlan] {
         case projectionOverSchema(expr) => expr
       })
       val newFilterCondition = projectedFilters.reduce(And)
-      Filter(newFilterCondition, prunedRelation)
+      Filter(newFilterCondition, leafNode)
     } else {
-      prunedRelation
+      leafNode
     }
 
     // Construct the new projections of our Project by
@@ -145,20 +193,36 @@ object SchemaPruning extends Rule[LogicalPlan] {
   private def buildPrunedRelation(
       outputRelation: LogicalRelation,
       prunedBaseRelation: HadoopFsRelation) = {
+    val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
+    outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
+  }
+
+  /**
+   * Builds a pruned data source V2 relation from the output of the relation and the schema
+   * of the pruned [[FileTable]].
+   */
+  private def buildPrunedRelationV2(
+      outputRelation: DataSourceV2Relation,
+      prunedFileTable: FileTable) = {
+    val prunedOutput = getPrunedOutput(outputRelation.output, prunedFileTable.schema)
+    outputRelation.copy(table = prunedFileTable, output = prunedOutput)
+  }
+
+  // Prune the given output to make it consistent with `requiredSchema`.
+  private def getPrunedOutput(
+      output: Seq[AttributeReference],
+      requiredSchema: StructType): Seq[AttributeReference] = {
     // We need to replace the expression ids of the pruned relation output attributes
     // with the expression ids of the original relation output attributes so that
     // references to the original relation's output are not broken
-    val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap
-    val prunedRelationOutput =
-      prunedBaseRelation
-        .schema
-        .toAttributes
-        .map {
-          case att if outputIdMap.contains(att.name) =>
-            att.withExprId(outputIdMap(att.name))
-          case att => att
-        }
-    outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput)
+    val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
+    requiredSchema
+      .toAttributes
+      .map {
+        case att if outputIdMap.contains(att.name) =>
+          att.withExprId(outputIdMap(att.name))
+        case att => att
+      }
   }
 
   /**
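
As on the V1 path, the rebuilt output must keep the original expression ids so that plans referencing the old relation still resolve. A small hypothetical illustration of that remapping (not code from this patch), using plain catalyst types:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types._

// Hypothetical original output and a pruned schema that keeps only "name".
val original = Seq(
  AttributeReference("id", LongType)(),
  AttributeReference("name", StringType)())
val requiredSchema = StructType(Seq(StructField("name", StringType)))

val outputIdMap = original.map(att => (att.name, att.exprId)).toMap
val prunedOutput = requiredSchema.map { field =>
  val fresh = AttributeReference(field.name, field.dataType, field.nullable)()
  if (outputIdMap.contains(fresh.name)) fresh.withExprId(outputIdMap(fresh.name)) else fresh
}
// prunedOutput contains a single "name" attribute that carries the exprId of the
// original "name", mirroring what getPrunedOutput does with requiredSchema.toAttributes.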

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -407,7 +407,7 @@ abstract class SchemaPruningSuite
     }
   }
 
-  private val schemaEquality = new Equality[StructType] {
+  protected val schemaEquality = new Equality[StructType] {
     override def areEqual(a: StructType, b: Any): Boolean =
       b match {
         case otherType: StructType => a.sameType(otherType)
@@ -422,7 +422,7 @@ abstract class SchemaPruningSuite
     df.collect()
   }
 
-  private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+  protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
     val fileSourceScanSchemata =
       df.queryExecution.executedPlan.collect {
         case scan: FileSourceScanExec => scan.requiredSchema

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala renamed to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
 import org.apache.spark.sql.internal.SQLConf
 
-class OrcSchemaPruningSuite extends SchemaPruningSuite {
+class OrcV1SchemaPruningSuite extends SchemaPruningSuite {
   override protected val dataSourceName: String = "orc"
   override protected val vectorizedReaderEnabledKey: String =
     SQLConf.ORC_VECTORIZED_READER_ENABLED.key
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
+import org.apache.spark.sql.internal.SQLConf
+
+class OrcV2SchemaPruningSuite extends SchemaPruningSuite {
+  override protected val dataSourceName: String = "orc"
+  override protected val vectorizedReaderEnabledKey: String =
+    SQLConf.ORC_VECTORIZED_READER_ENABLED.key
+
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
+
+  override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
+    val fileSourceScanSchemata =
+      df.queryExecution.executedPlan.collect {
+        case BatchScanExec(_, scan: OrcScan) => scan.readDataSchema
+      }
+    assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
+      s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
+        s"but expected $expectedSchemaCatalogStrings")
+    fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
+      case (scanSchema, expectedScanSchemaCatalogString) =>
+        val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
+        implicit val equality = schemaEquality
+        assert(scanSchema === expectedScanSchema)
+    }
+  }
+}
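
The override matches `BatchScanExec(_, scan: OrcScan)` instead of `FileSourceScanExec`, since the V2 path plans file scans as batch scans. A hypothetical call from a test body in the shared suite (not part of this patch) looks the same for both suites:

// Assuming nested contact data has already been written as ORC and registered as "contacts"
// (hypothetical names; the shared SchemaPruningSuite drives this helper for V1 and V2 alike):
val query = spark.table("contacts").select("name.first")
checkScanSchemata(query, "struct<name:struct<first:string>>")
// V1 asserts on FileSourceScanExec.requiredSchema; with this override, V2 asserts on
// OrcScan.readDataSchema behind BatchScanExec, so both check the same pruned schema.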
