package org.apache.spark.sql.execution.datasources

+ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
- import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
+ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
@@ -48,7 +51,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
          l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
        if canPruneRelation(hadoopFsRelation) =>
        val (normalizedProjects, normalizedFilters) =
-         normalizeAttributeRefNames(l, projects, filters)
+         normalizeAttributeRefNames(l.output, projects, filters)
        val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)

        // If requestedRootFields includes a nested field, continue. Otherwise,
@@ -76,6 +79,43 @@ object SchemaPruning extends Rule[LogicalPlan] {
        } else {
          op
        }
+
+     case op @ PhysicalOperation(projects, filters,
+         d @ DataSourceV2Relation(table: FileTable, output, _)) if canPruneTable(table) =>
+       val (normalizedProjects, normalizedFilters) =
+         normalizeAttributeRefNames(output, projects, filters)
+       val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
+
+       // If requestedRootFields includes a nested field, continue. Otherwise,
+       // return op
+       if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
+         val dataSchema = table.dataSchema
+         val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)
+
+         // If the data schema is different from the pruned data schema, continue. Otherwise,
+         // return op. We effect this comparison by counting the number of "leaf" fields in
+         // each schema, assuming the fields in prunedDataSchema are a subset of the fields
+         // in dataSchema.
+         if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
+           val prunedFileTable = table match {
+             case o: OrcTable => o.copy(userSpecifiedSchema = Some(prunedDataSchema))
+             case _ =>
+               val message = s"${table.formatName} data source doesn't support schema pruning."
+               throw new AnalysisException(message)
+           }
+
+           val prunedRelationV2 = buildPrunedRelationV2(d, prunedFileTable)
+           val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+
+           buildNewProjection(normalizedProjects, normalizedFilters, prunedRelationV2,
+             projectionOverSchema)
+         } else {
+           op
+         }
+       } else {
+         op
+       }
    }

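Note: the leaf-count comparison in this new branch is worth a worked example. A minimal sketch of the idea, using a made-up schema and a simplified counter (the real countLeaves also handles arrays and maps):

    import org.apache.spark.sql.types._

    // Hypothetical schema: name.first, name.last and address give 3 leaf fields.
    val dataSchema = StructType(Seq(
      StructField("name", StructType(Seq(
        StructField("first", StringType),
        StructField("last", StringType)))),
      StructField("address", StringType)))

    // After pruning for a query that touches only name.first, 1 leaf remains.
    val prunedDataSchema = StructType(Seq(
      StructField("name", StructType(Seq(
        StructField("first", StringType))))))

    // Simplified leaf counter: structs recurse, primitives count as one.
    def leafCount(dt: DataType): Int = dt match {
      case s: StructType => s.fields.map(f => leafCount(f.dataType)).sum
      case _ => 1
    }

    // 3 > 1, so the rule rewrites the plan; equal counts mean nothing was pruned.
    assert(leafCount(dataSchema) == 3 && leafCount(prunedDataSchema) == 1)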
/**
@@ -85,16 +125,22 @@ object SchemaPruning extends Rule[LogicalPlan] {
    fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
      fsRelation.fileFormat.isInstanceOf[OrcFileFormat]

+ /**
+  * Checks to see if the given [[FileTable]] can be pruned. Currently we support ORC v2.
+  */
+ private def canPruneTable(table: FileTable) =
+   table.isInstanceOf[OrcTable]
+
  /**
   * Normalizes the names of the attribute references in the given projects and filters to reflect
   * the names in the given logical relation. This makes it possible to compare attributes and
   * fields by name. Returns a tuple with the normalized projects and filters, respectively.
   */
  private def normalizeAttributeRefNames(
-     logicalRelation: LogicalRelation,
+     output: Seq[AttributeReference],
      projects: Seq[NamedExpression],
      filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
-   val normalizedAttNameMap = logicalRelation.output.map(att => (att.exprId, att.name)).toMap
+   val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
    val normalizedProjects = projects.map(_.transform {
      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
        att.withName(normalizedAttNameMap(att.exprId))
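Why key the map by exprId rather than by name? Under case-insensitive analysis the same attribute can appear in the query with a different spelling than the relation's schema uses, and the pruning logic later compares fields by name. A minimal sketch of the mechanism (the attribute values are illustrative):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.StringType

    // The relation declares "name"; the query spelled it "NAME". withName
    // preserves the exprId, so both refer to the same attribute.
    val relationAttr = AttributeReference("name", StringType)()
    val queryAttr = relationAttr.withName("NAME")

    // Build exprId -> canonical name from the relation output, then rewrite
    // the query-side reference to the relation's spelling.
    val normalizedAttNameMap = Seq(relationAttr).map(att => (att.exprId, att.name)).toMap
    assert(queryAttr.withName(normalizedAttNameMap(queryAttr.exprId)).name == "name")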
@@ -107,11 +153,13 @@ object SchemaPruning extends Rule[LogicalPlan] {
  }

  /**
-  * Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
+  * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`.
   */
  private def buildNewProjection(
-     projects: Seq[NamedExpression], filters: Seq[Expression], prunedRelation: LogicalRelation,
-     projectionOverSchema: ProjectionOverSchema) = {
+     projects: Seq[NamedExpression],
+     filters: Seq[Expression],
+     leafNode: LeafNode,
+     projectionOverSchema: ProjectionOverSchema): Project = {
    // Construct a new target for our projection by rewriting and
    // including the original filters where available
    val projectionChild =
@@ -120,9 +168,9 @@ object SchemaPruning extends Rule[LogicalPlan] {
          case projectionOverSchema(expr) => expr
        })
        val newFilterCondition = projectedFilters.reduce(And)
-       Filter(newFilterCondition, prunedRelation)
+       Filter(newFilterCondition, leafNode)
      } else {
-       prunedRelation
+       leafNode
      }

    // Construct the new projections of our Project by
@@ -145,20 +193,36 @@ object SchemaPruning extends Rule[LogicalPlan] {
  private def buildPrunedRelation(
      outputRelation: LogicalRelation,
      prunedBaseRelation: HadoopFsRelation) = {
+   val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
+   outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
+ }
+
+ /**
+  * Builds a pruned data source V2 relation from the output of the relation and the schema
+  * of the pruned [[FileTable]].
+  */
+ private def buildPrunedRelationV2(
+     outputRelation: DataSourceV2Relation,
+     prunedFileTable: FileTable) = {
+   val prunedOutput = getPrunedOutput(outputRelation.output, prunedFileTable.schema)
+   outputRelation.copy(table = prunedFileTable, output = prunedOutput)
+ }
+
+ // Prune the given output to make it consistent with `requiredSchema`.
+ private def getPrunedOutput(
+     output: Seq[AttributeReference],
+     requiredSchema: StructType): Seq[AttributeReference] = {
    // We need to replace the expression ids of the pruned relation output attributes
    // with the expression ids of the original relation output attributes so that
    // references to the original relation's output are not broken
-   val outputIdMap = outputRelation.output.map(att => (att.name, att.exprId)).toMap
-   val prunedRelationOutput =
-     prunedBaseRelation
-       .schema
-       .toAttributes
-       .map {
-         case att if outputIdMap.contains(att.name) =>
-           att.withExprId(outputIdMap(att.name))
-         case att => att
-       }
-   outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput)
+   val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
+   requiredSchema
+     .toAttributes
+     .map {
+       case att if outputIdMap.contains(att.name) =>
+         att.withExprId(outputIdMap(att.name))
+       case att => att
+     }
  }

  /**
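End-to-end, this change gives the ORC v2 reader the nested-column pruning that the v1 file-source path already applied to Parquet and ORC; any other FileTable reaching the new branch fails with the AnalysisException above. A hedged usage sketch from spark-shell (the path and columns are made up, and the config key assumes the nested-pruning flag defined in SQLConf):

    // With nested schema pruning on, selecting one leaf of a struct should
    // produce a scan whose ReadSchema contains only that leaf.
    spark.conf.set("spark.sql.optimizer.nestedSchemaPruning.enabled", "true")

    val contacts = spark.read.orc("/tmp/contacts")   // hypothetical ORC data
    contacts.select("name.first").explain()
    // Expect something like: ReadSchema: struct<name:struct<first:string>>
    // rather than the full struct<name:struct<first:string,last:string>,...>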