Commit df4c53e

gengliangwang authored and cloud-fan committed
[SPARK-26673][SQL] File source V2 writes: create framework and migrate ORC

## What changes were proposed in this pull request?

Create a framework for the write path of File Source V2, and migrate the write path of ORC to V2.

Supported:
* Write to file as DataFrame

Not supported:
* Partitioning, which is still under development in the data source V2 project.
* Bucketing, which is still under development in the data source V2 project.
* Catalog.

## How was this patch tested?

Unit test.

Closes apache#23601 from gengliangwang/orc_write.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent b3b62ba commit df4c53e
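As a rough illustration of the supported scope (the output path below is made up): a plain DataFrame write with no partitioning or bucketing is what now runs through the new V2 write framework.

```scala
// Any SparkSession named `spark` (e.g. spark-shell); the path is illustrative.
// With no partitionBy/bucketBy, an ORC save is handled by the V2 write framework
// added in this PR (OrcTable -> FileWriteBuilder -> FileBatchWrite).
spark.range(10).write.format("orc").mode("overwrite").save("/tmp/spark-26673-orc-demo")
```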

File tree

16 files changed (+486, -48 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
Lines changed: 10 additions & 0 deletions

@@ -1440,6 +1440,14 @@ object SQLConf {
     .stringConf
     .createWithDefault("")

+  val USE_V1_SOURCE_WRITER_LIST = buildConf("spark.sql.sources.write.useV1SourceList")
+    .internal()
+    .doc("A comma-separated list of data source short names or fully qualified data source" +
+      " register class names for which data source V2 write paths are disabled. Writes from these" +
+      " sources will fall back to the V1 sources.")
+    .stringConf
+    .createWithDefault("")
+
   val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
     .doc("A comma-separated list of fully qualified data source register class names for which" +
       " StreamWriteSupport is disabled. Writes to these sources will fall back to the V1 Sinks.")
@@ -2026,6 +2034,8 @@ class SQLConf extends Serializable with Logging {

   def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST)

+  def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST)
+
   def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)

   def disabledV2StreamingMicroBatchReaders: String =
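A hedged sketch of how the new flag is meant to be used (the config key comes from the diff above; the session and output path are illustrative). Listing a short name such as `orc` disables the V2 write path for that source, so its writes fall back to the V1 `FileFormat` writer.

```scala
// Illustrative only. Matching is against lower-cased short names or fully
// qualified register class names, per the lookup logic in DataFrameWriter below.
spark.conf.set("spark.sql.sources.write.useV1SourceList", "orc")
spark.range(10).write.format("orc").mode("overwrite").save("/tmp/orc-v1-fallback-demo")

// The default (empty string) keeps the V2 write path enabled for all file sources.
spark.conf.set("spark.sql.sources.write.useV1SourceList", "")
```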

sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
Lines changed: 14 additions & 3 deletions

@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable,
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
@@ -243,8 +243,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     assertNotBucketed("save")

     val session = df.sparkSession
-    val cls = DataSource.lookupDataSource(source, session.sessionState.conf)
-    if (classOf[TableProvider].isAssignableFrom(cls)) {
+    val useV1Sources =
+      session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
+    val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf)
+    val cls = lookupCls.newInstance() match {
+      case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) ||
+        useV1Sources.contains(lookupCls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+        f.fallBackFileFormat
+      case _ => lookupCls
+    }
+    // In Data Source V2 project, partitioning is still under development.
+    // Here we fallback to V1 if partitioning columns are specified.
+    // TODO(SPARK-26778): use V2 implementations when partitioning feature is supported.
+    if (classOf[TableProvider].isAssignableFrom(cls) && partitioningColumns.isEmpty) {
       val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
       val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
         provider, session.sessionState.conf)
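To illustrate the new `partitioningColumns.isEmpty` guard (paths and the column name are made up): a write that specifies partition columns skips the `TableProvider` branch and stays on the V1 path, since partitioning is not yet supported by the V2 write framework.

```scala
import org.apache.spark.sql.functions.col

val df = spark.range(100).withColumn("part", col("id") % 10)

// No partitioning: eligible for the V2 branch above (TableProvider + empty partitioningColumns).
df.write.format("orc").save("/tmp/orc-v2-plain")

// partitionBy makes partitioningColumns non-empty, so the new condition is false
// and the write falls through to the existing V1 DataSource path.
df.write.partitionBy("part").format("orc").save("/tmp/orc-v1-partitioned")
```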

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
Lines changed: 1 addition & 1 deletion

@@ -763,7 +763,7 @@ object DataSource extends Logging {
    * supplied schema is not empty.
    * @param schema
    */
-  private def validateSchema(schema: StructType): Unit = {
+  def validateSchema(schema: StructType): Unit = {
     def hasEmptySchema(schema: StructType): Boolean = {
       schema.size == 0 || schema.find {
         case StructField(_, b: StructType, _, _) => hasEmptySchema(b)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala
Lines changed: 2 additions & 2 deletions

@@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
  * Replace the ORC V2 data source of table in [[InsertIntoTable]] to V1 [[FileFormat]].
  * E.g, with temporary view `t` using [[FileDataSourceV2]], inserting into view `t` fails
  * since there is no corresponding physical plan.
- * SPARK-23817: This is a temporary hack for making current data source V2 work. It should be
- * removed when write path of file data source v2 is finished.
+ * This is a temporary hack for making current data source V2 work. It should be
+ * removed when Catalog support of file data source v2 is finished.
  */
 class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
Lines changed: 4 additions & 2 deletions

@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
 import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.SerializableConfiguration

@@ -37,7 +38,7 @@ import org.apache.spark.util.SerializableConfiguration
 abstract class FileFormatDataWriter(
     description: WriteJobDescription,
     taskAttemptContext: TaskAttemptContext,
-    committer: FileCommitProtocol) {
+    committer: FileCommitProtocol) extends DataWriter[InternalRow] {
   /**
    * Max number of files a single task writes out due to file size. In most cases the number of
    * files written should be very small. This is just a safe guard to protect some really bad
@@ -70,7 +71,7 @@ abstract class FileFormatDataWriter(
    * to the driver and used to update the catalog. Other information will be sent back to the
    * driver too and used to e.g. update the metrics in UI.
    */
-  def commit(): WriteTaskResult = {
+  override def commit(): WriteTaskResult = {
     releaseResources()
     val summary = ExecutedWriteSummary(
       updatedPartitions = updatedPartitions.toSet,
@@ -301,6 +302,7 @@ class WriteJobDescription(

 /** The result of a successful write task. */
 case class WriteTaskResult(commitMsg: TaskCommitMessage, summary: ExecutedWriteSummary)
+  extends WriterCommitMessage

 /**
  * Wrapper class for the metrics of writing data out.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
Lines changed: 1 addition & 1 deletion

@@ -259,7 +259,7 @@ object FileFormatWriter extends Logging {
    * For every registered [[WriteJobStatsTracker]], call `processStats()` on it, passing it
    * the corresponding [[WriteTaskStats]] from all executors.
    */
-  private def processStats(
+  private[datasources] def processStats(
       statsTrackers: Seq[WriteJobStatsTracker],
       statsPerTask: Seq[Seq[WriteTaskStats]])
     : Unit = {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.OutputWriter
 import org.apache.spark.sql.types._

-private[orc] class OrcOutputWriter(
+private[sql] class OrcOutputWriter(
     path: String,
     dataSchema: StructType,
     context: TaskAttemptContext)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala (new file)
Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.hadoop.mapreduce.Job
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult}
+import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.util.SerializableConfiguration
+
+class FileBatchWrite(
+    job: Job,
+    description: WriteJobDescription,
+    committer: FileCommitProtocol)
+  extends BatchWrite with Logging {
+  override def commit(messages: Array[WriterCommitMessage]): Unit = {
+    val results = messages.map(_.asInstanceOf[WriteTaskResult])
+    committer.commitJob(job, results.map(_.commitMsg))
+    logInfo(s"Write Job ${description.uuid} committed.")
+
+    processStats(description.statsTrackers, results.map(_.summary.stats))
+    logInfo(s"Finished processing stats for write job ${description.uuid}.")
+  }
+
+  override def useCommitCoordinator(): Boolean = false
+
+  override def abort(messages: Array[WriterCommitMessage]): Unit = {
+    committer.abortJob(job)
+  }
+
+  override def createBatchWriterFactory(): DataWriterFactory = {
+    val conf = new SerializableConfiguration(job.getConfiguration)
+    FileWriterFactory(description, committer, conf)
+  }
+}
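A minimal sketch (not part of this diff) of the driver-side contract `FileBatchWrite` implements: once the per-task `WriterCommitMessage`s (here, `WriteTaskResult`s) have been collected by the V2 write execution node, the job is finalized through either `commit` or `abort`, which drive the underlying `FileCommitProtocol`. The helper below is hypothetical and only shows the call order.

```scala
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriterCommitMessage}

// Hypothetical helper, only to show the order of calls expected by the BatchWrite API.
def finalizeJob(
    write: BatchWrite,                     // e.g. a FileBatchWrite
    messages: Array[WriterCommitMessage],  // gathered from each task's DataWriter.commit()
    allTasksSucceeded: Boolean): Unit = {
  if (allTasksSucceeded) {
    write.commit(messages)  // FileBatchWrite: committer.commitJob + processStats
  } else {
    write.abort(messages)   // FileBatchWrite: committer.abortJob
  }
}
```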

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
Lines changed: 3 additions & 2 deletions

@@ -20,13 +20,14 @@ import org.apache.hadoop.fs.FileStatus

 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.sources.v2.{SupportsBatchRead, Table}
+import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table}
 import org.apache.spark.sql.types.StructType

 abstract class FileTable(
     sparkSession: SparkSession,
     fileIndex: PartitioningAwareFileIndex,
-    userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead {
+    userSpecifiedSchema: Option[StructType])
+  extends Table with SupportsBatchRead with SupportsBatchWrite {
   def getFileIndex: PartitioningAwareFileIndex = this.fileIndex

   lazy val dataSchema: StructType = userSpecifiedSchema.orElse {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala (new file)
Lines changed: 158 additions & 0 deletions

@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import java.util.UUID
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.Job
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.DataSourceOptions
+import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
+
+abstract class FileWriteBuilder(options: DataSourceOptions)
+  extends WriteBuilder with SupportsSaveMode {
+  private var schema: StructType = _
+  private var queryId: String = _
+  private var mode: SaveMode = _
+
+  override def withInputDataSchema(schema: StructType): WriteBuilder = {
+    this.schema = schema
+    this
+  }
+
+  override def withQueryId(queryId: String): WriteBuilder = {
+    this.queryId = queryId
+    this
+  }
+
+  override def mode(mode: SaveMode): WriteBuilder = {
+    this.mode = mode
+    this
+  }
+
+  override def buildForBatch(): BatchWrite = {
+    validateInputs()
+    val pathName = options.paths().head
+    val path = new Path(pathName)
+    val sparkSession = SparkSession.active
+    val optionsAsScala = options.asMap().asScala.toMap
+    val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala)
+    val job = getJobInstance(hadoopConf, path)
+    val committer = FileCommitProtocol.instantiate(
+      sparkSession.sessionState.conf.fileCommitProtocolClass,
+      jobId = java.util.UUID.randomUUID().toString,
+      outputPath = pathName)
+    lazy val description =
+      createWriteJobDescription(sparkSession, hadoopConf, job, pathName, optionsAsScala)
+
+    val fs = path.getFileSystem(hadoopConf)
+    mode match {
+      case SaveMode.ErrorIfExists if fs.exists(path) =>
+        val qualifiedOutputPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory)
+        throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
+
+      case SaveMode.Ignore if fs.exists(path) =>
+        null
+
+      case SaveMode.Overwrite =>
+        committer.deleteWithJob(fs, path, true)
+        committer.setupJob(job)
+        new FileBatchWrite(job, description, committer)
+
+      case _ =>
+        committer.setupJob(job)
+        new FileBatchWrite(job, description, committer)
+    }
+  }
+
+  /**
+   * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
+   * be put here. For example, user defined output committer can be configured here
+   * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
+   */
+  def prepareWrite(
+      sqlConf: SQLConf,
+      job: Job,
+      options: Map[String, String],
+      dataSchema: StructType): OutputWriterFactory
+
+  private def validateInputs(): Unit = {
+    assert(schema != null, "Missing input data schema")
+    assert(queryId != null, "Missing query ID")
+    assert(mode != null, "Missing save mode")
+    assert(options.paths().length == 1)
+    DataSource.validateSchema(schema)
+  }
+
+  private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
+    val job = Job.getInstance(hadoopConf)
+    job.setOutputKeyClass(classOf[Void])
+    job.setOutputValueClass(classOf[InternalRow])
+    FileOutputFormat.setOutputPath(job, path)
+    job
+  }
+
+  private def createWriteJobDescription(
+      sparkSession: SparkSession,
+      hadoopConf: Configuration,
+      job: Job,
+      pathName: String,
+      options: Map[String, String]): WriteJobDescription = {
+    val caseInsensitiveOptions = CaseInsensitiveMap(options)
+    // Note: prepareWrite has side effect. It sets "job".
+    val outputWriterFactory =
+      prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema)
+    val allColumns = schema.toAttributes
+    val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+    val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
+    // TODO: after partitioning is supported in V2:
+    //   1. filter out partition columns in `dataColumns`.
+    //   2. Don't use Seq.empty for `partitionColumns`.
+    new WriteJobDescription(
+      uuid = UUID.randomUUID().toString,
+      serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
+      outputWriterFactory = outputWriterFactory,
+      allColumns = allColumns,
+      dataColumns = allColumns,
+      partitionColumns = Seq.empty,
+      bucketIdExpression = None,
+      path = pathName,
+      customPartitionLocations = Map.empty,
+      maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
+        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone),
+      statsTrackers = Seq(statsTracker)
+    )
+  }
+}
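For orientation, a hedged sketch of how a concrete file format plugs into this framework: only `prepareWrite` needs to be implemented, while save-mode handling, job setup, and the commit protocol come from `FileWriteBuilder` and `FileBatchWrite`. The class name and file extension below are made up; the ORC builder added by this PR returns a factory that creates `OrcOutputWriter` instances.

```scala
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.types.StructType

// Illustrative subclass of the abstract FileWriteBuilder added by this PR.
class MyFormatWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(options) {
  override def prepareWrite(
      sqlConf: SQLConf,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    // A real format would configure `job` here (output schema, compression, ...).
    new OutputWriterFactory {
      override def getFileExtension(context: TaskAttemptContext): String = ".myformat"

      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter = {
        // The ORC builder returns new OrcOutputWriter(path, dataSchema, context) here.
        throw new UnsupportedOperationException("illustrative sketch only")
      }
    }
  }
}
```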
